From 4fc8f262b68d1539d28ee080dc284fcdc9ae0ae0 Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Tue, 9 Mar 2021 14:57:40 -0500 Subject: [PATCH 001/139] open PR From ee57184516a3de2d8df6ee3a7e84958a8fd9a486 Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Fri, 2 Apr 2021 10:26:33 -0400 Subject: [PATCH 002/139] catalog port --- src/catalog/catalog_accessor.cpp | 2 +- src/catalog/database_catalog.cpp | 9 +++---- src/catalog/postgres/pg_proc_impl.cpp | 27 ++++++++++++--------- src/include/catalog/catalog_accessor.h | 2 +- src/include/catalog/database_catalog.h | 2 +- src/include/catalog/postgres/pg_proc.h | 2 +- src/include/catalog/postgres/pg_proc_impl.h | 2 +- test/catalog/catalog_test.cpp | 6 ++--- 8 files changed, 27 insertions(+), 25 deletions(-) diff --git a/src/catalog/catalog_accessor.cpp b/src/catalog/catalog_accessor.cpp index 1d04fa92ec..9d8cccf8b6 100644 --- a/src/catalog/catalog_accessor.cpp +++ b/src/catalog/catalog_accessor.cpp @@ -177,7 +177,7 @@ proc_oid_t CatalogAccessor::CreateProcedure(const std::string &procname, languag namespace_oid_t procns, const std::vector &args, const std::vector &arg_types, const std::vector &all_arg_types, - const std::vector &arg_modes, + const std::vector &arg_modes, type_oid_t rettype, const std::string &src, bool is_aggregate) { return dbc_->CreateProcedure(txn_, procname, language_oid, procns, args, arg_types, all_arg_types, arg_modes, rettype, src, is_aggregate); diff --git a/src/catalog/database_catalog.cpp b/src/catalog/database_catalog.cpp index 25b57dd61b..6aef17f6ee 100644 --- a/src/catalog/database_catalog.cpp +++ b/src/catalog/database_catalog.cpp @@ -428,14 +428,13 @@ proc_oid_t DatabaseCatalog::CreateProcedure(common::ManagedPointer &args, const std::vector &arg_types, const std::vector &all_arg_types, - const std::vector &arg_modes, + const std::vector &arg_modes, type_oid_t rettype, const std::string &src, bool is_aggregate) { if (!TryLock(txn)) return INVALID_PROC_OID; proc_oid_t oid = proc_oid_t{next_oid_++}; - return pg_proc_.CreateProcedure(txn, oid, procname, language_oid, procns, args, arg_types, all_arg_types, arg_modes, - rettype, src, is_aggregate) - ? oid - : INVALID_PROC_OID; + // TODO(Kyle): Why did Tanuj have his own implementation here? + const auto result = pg_proc_.CreateProcedure(txn, oid, procname, language_oid, procns, args, arg_types, all_arg_types, arg_modes, rettype, src, is_aggregate); + return result ? oid : INVALID_PROC_OID; } bool DatabaseCatalog::DropProcedure(const common::ManagedPointer txn, diff --git a/src/catalog/postgres/pg_proc_impl.cpp b/src/catalog/postgres/pg_proc_impl.cpp index 1aeec0fecf..d65b626649 100644 --- a/src/catalog/postgres/pg_proc_impl.cpp +++ b/src/catalog/postgres/pg_proc_impl.cpp @@ -77,7 +77,7 @@ bool PgProcImpl::CreateProcedure(const common::ManagedPointer &args, const std::vector &arg_types, const std::vector &all_arg_types, - const std::vector &arg_modes, const type_oid_t rettype, + const std::vector &arg_modes, const type_oid_t rettype, const std::string &src, const bool is_aggregate) { NOISEPAGE_ASSERT(args.size() < UINT16_MAX, "Number of arguments must fit in a SMALLINT"); @@ -89,11 +89,13 @@ bool PgProcImpl::CreateProcedure(const common::ManagedPointer arg_name_vec; + std::vector arg_name_vec{}; arg_name_vec.reserve(args.size() * sizeof(storage::VarlenEntry)); - for (auto &arg : args) { - arg_name_vec.push_back(arg); - } + std::copy(args.cbegin(), args.cend(), std::back_inserter(arg_name_vec)); + // arg_name_vec.reserve(args.size() * ); + // for (auto &arg : args) { + // arg_name_vec.push_back(arg); + // } const auto arg_names_varlen = storage::StorageUtil::CreateVarlen(args); const auto arg_types_varlen = storage::StorageUtil::CreateVarlen(arg_types); @@ -107,9 +109,11 @@ bool PgProcImpl::CreateProcedure(const common::ManagedPointerGetProjectedRowInitializer(); auto name_pri = procs_name_index_->GetProjectedRowInitializer(); - byte *const buffer = common::AllocationUtil::AllocateAligned(name_pri.ProjectedRowSize()); + auto buffer = std::unique_ptr(common::AllocationUtil::AllocateAligned(name_pri.ProjectedRowSize())); // Insert into pg_proc_name_index. { - auto name_pr = name_pri.InitializeRow(buffer); + auto name_pr = name_pri.InitializeRow(buffer.get()); auto name_map = procs_name_index_->GetKeyOidToOffsetMap(); name_pr->Set(name_map[indexkeycol_oid_t(1)], procns, false); name_pr->Set(name_map[indexkeycol_oid_t(2)], name_varlen, false); if (auto result = procs_name_index_->Insert(txn, *name_pr, tuple_slot); !result) { - delete[] buffer; return false; } } // Insert into pg_proc_oid_index. { - auto oid_pr = oid_pri.InitializeRow(buffer); + auto oid_pr = oid_pri.InitializeRow(buffer.get()); oid_pr->Set(0, oid, false); bool UNUSED_ATTRIBUTE result = procs_oid_index_->InsertUnique(txn, *oid_pr, tuple_slot); NOISEPAGE_ASSERT(result, "Oid insertion should be unique"); } - delete[] buffer; return true; } @@ -448,12 +451,12 @@ void PgProcImpl::BootstrapProcs(const common::ManagedPointernext_oid_++}, "nprunnersemitint", PgLanguage::INTERNAL_LANGUAGE_OID, PgNamespace::NAMESPACE_DEFAULT_NAMESPACE_OID, {"num_tuples", "num_cols", "num_int_cols", "num_real_cols"}, {INT, INT, INT, INT}, {INT, INT, INT, INT}, - {PgProc::ArgModes::IN, PgProc::ArgModes::IN, PgProc::ArgModes::IN, PgProc::ArgModes::IN}, INT, "", false); + {PgProc::ArgMode::IN, PgProc::ArgMode::IN, PgProc::ArgMode::IN, PgProc::ArgMode::IN}, INT, "", false); CreateProcedure( txn, proc_oid_t{dbc->next_oid_++}, "nprunnersemitreal", PgLanguage::INTERNAL_LANGUAGE_OID, PgNamespace::NAMESPACE_DEFAULT_NAMESPACE_OID, {"num_tuples", "num_cols", "num_int_cols", "num_real_cols"}, {INT, INT, INT, INT}, {INT, INT, INT, INT}, - {PgProc::ArgModes::IN, PgProc::ArgModes::IN, PgProc::ArgModes::IN, PgProc::ArgModes::IN}, REAL, "", false); + {PgProc::ArgMode::IN, PgProc::ArgMode::IN, PgProc::ArgMode::IN, PgProc::ArgMode::IN}, REAL, "", false); CreateProcedure(txn, proc_oid_t{dbc->next_oid_++}, "nprunnersdummyint", PgLanguage::INTERNAL_LANGUAGE_OID, PgNamespace::NAMESPACE_DEFAULT_NAMESPACE_OID, {}, {}, {}, {}, INT, "", false); CreateProcedure(txn, proc_oid_t{dbc->next_oid_++}, "nprunnersdummyreal", PgLanguage::INTERNAL_LANGUAGE_OID, diff --git a/src/include/catalog/catalog_accessor.h b/src/include/catalog/catalog_accessor.h index b193649836..d3cbc4b8dc 100644 --- a/src/include/catalog/catalog_accessor.h +++ b/src/include/catalog/catalog_accessor.h @@ -320,7 +320,7 @@ class EXPORT CatalogAccessor { proc_oid_t CreateProcedure(const std::string &procname, language_oid_t language_oid, namespace_oid_t procns, const std::vector &args, const std::vector &arg_types, const std::vector &all_arg_types, - const std::vector &arg_modes, type_oid_t rettype, + const std::vector &arg_modes, type_oid_t rettype, const std::string &src, bool is_aggregate); /** diff --git a/src/include/catalog/database_catalog.h b/src/include/catalog/database_catalog.h index 867280145f..5c51cc7714 100644 --- a/src/include/catalog/database_catalog.h +++ b/src/include/catalog/database_catalog.h @@ -162,7 +162,7 @@ class DatabaseCatalog { proc_oid_t CreateProcedure(common::ManagedPointer txn, const std::string &procname, language_oid_t language_oid, namespace_oid_t procns, const std::vector &args, const std::vector &arg_types, const std::vector &all_arg_types, - const std::vector &arg_modes, type_oid_t rettype, + const std::vector &arg_modes, type_oid_t rettype, const std::string &src, bool is_aggregate); /** @brief Drop the specified procedure. @see PgProcImpl::DropProcedure */ bool DropProcedure(common::ManagedPointer txn, proc_oid_t proc); diff --git a/src/include/catalog/postgres/pg_proc.h b/src/include/catalog/postgres/pg_proc.h index b948709152..11a6d2a173 100644 --- a/src/include/catalog/postgres/pg_proc.h +++ b/src/include/catalog/postgres/pg_proc.h @@ -26,7 +26,7 @@ class PgProcImpl; class PgProc { public: /** The type of the argument to the procedure. */ - enum class ArgModes : char { + enum class ArgMode : char { IN = 'i', ///< Input argument. OUT = 'o', ///< Output argument. INOUT = 'b', ///< Both input and output argument. diff --git a/src/include/catalog/postgres/pg_proc_impl.h b/src/include/catalog/postgres/pg_proc_impl.h index ce3627339a..8c1303243c 100644 --- a/src/include/catalog/postgres/pg_proc_impl.h +++ b/src/include/catalog/postgres/pg_proc_impl.h @@ -103,7 +103,7 @@ class PgProcImpl { const std::string &procname, language_oid_t language_oid, namespace_oid_t procns, const std::vector &args, const std::vector &arg_types, const std::vector &all_arg_types, - const std::vector &arg_modes, type_oid_t rettype, + const std::vector &arg_modes, type_oid_t rettype, const std::string &src, bool is_aggregate); /** diff --git a/test/catalog/catalog_test.cpp b/test/catalog/catalog_test.cpp index 5c42be117c..8ec8e3729b 100644 --- a/test/catalog/catalog_test.cpp +++ b/test/catalog/catalog_test.cpp @@ -138,9 +138,9 @@ TEST_F(CatalogTests, ProcTest) { std::vector arg_types = {accessor->GetTypeOidFromTypeId(type::TypeId::INTEGER), accessor->GetTypeOidFromTypeId(type::TypeId::BOOLEAN), accessor->GetTypeOidFromTypeId(type::TypeId::SMALLINT)}; - std::vector arg_modes = {catalog::postgres::PgProc::ArgModes::IN, - catalog::postgres::PgProc::ArgModes::IN, - catalog::postgres::PgProc::ArgModes::IN}; + std::vector arg_modes = {catalog::postgres::PgProc::ArgMode::IN, + catalog::postgres::PgProc::ArgMode::IN, + catalog::postgres::PgProc::ArgMode::IN}; auto src = "int sample(arg1, arg2, arg3){return 2;}"; auto proc_oid = From ddad51a51404fb5a6ff5a9752c7bebd6c5e59421 Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Fri, 2 Apr 2021 11:18:18 -0400 Subject: [PATCH 003/139] port tpl tests --- build-support/run_tpl_tests.py | 75 +++++++++++++++++-------------- sample_tpl/agg-lambda.tpl | 81 ++++++++++++++++++++++++++++++++++ sample_tpl/call-lambda.tpl | 43 ++++++++++++++++++ sample_tpl/param-lambda.tpl | 11 +++++ sample_tpl/struct-lambda.tpl | 21 +++++++++ sample_tpl/tpl_tests.txt | 4 ++ 6 files changed, 201 insertions(+), 34 deletions(-) create mode 100644 sample_tpl/agg-lambda.tpl create mode 100644 sample_tpl/call-lambda.tpl create mode 100644 sample_tpl/param-lambda.tpl create mode 100644 sample_tpl/struct-lambda.tpl diff --git a/build-support/run_tpl_tests.py b/build-support/run_tpl_tests.py index ea1fdf9257..c594ee2c73 100755 --- a/build-support/run_tpl_tests.py +++ b/build-support/run_tpl_tests.py @@ -1,15 +1,25 @@ #!/usr/bin/env python3 -import argparse import os -import subprocess import sys +import argparse +import subprocess -VM_TARGET_STRING = 'VM main() returned: ' -ADAPTIVE_TARGET_STRING = 'ADAPTIVE main() returned: ' -JIT_TARGET_STRING = 'JIT main() returned: ' -TARGET_STRINGS = [VM_TARGET_STRING, ADAPTIVE_TARGET_STRING, JIT_TARGET_STRING] +# Exit codes +EXIT_SUCCESS = 0 +EXIT_FAILURE = 1 + +# String prefixed to VM execution tests +VM_TARGET_STRING = "VM main() returned: " +# String prefixed to ADAPTIVE execution tests +ADAPTIVE_TARGET_STRING = "ADAPTIVE main() returned: " + +# String prefixed to JIT execution tests +JIT_TARGET_STRING = "JIT main() returned: " + +# Collection of all target strings +TARGET_STRINGS = [VM_TARGET_STRING, ADAPTIVE_TARGET_STRING, JIT_TARGET_STRING] def run(tpl_bin, tpl_file, is_sql): args = [tpl_bin] @@ -18,11 +28,7 @@ def run(tpl_bin, tpl_file, is_sql): args.append(tpl_file) proc = subprocess.run(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE) result = [] - #print("tpl_file stdout:") - #print(proc.stdout.decode('utf-8')) - #print("tpl_file stderr:") - #print(proc.stderr.decode('utf-8')) - for line in reversed(proc.stdout.decode('utf-8').split('\n')): + for line in reversed(proc.stdout.decode("utf-8").split("\n")): if "ERROR" in line or "error" in line: return [] for target_string in TARGET_STRINGS: @@ -31,51 +37,52 @@ def run(tpl_bin, tpl_file, is_sql): result.append(line[idx + len(target_string):]) return result - def check(tpl_bin, tpl_folder, tpl_tests_file, build_dir): os.chdir(build_dir) with open(tpl_tests_file) as tpl_tests: num_tests, failed = 0, set() - print('Tests:') + print("Tests:") for line in tpl_tests: line = line.strip() - if not line or line[0] == '#': + if not line or line[0] == "#": continue - tpl_file, sql, expected_output = [x.strip() for x in line.split(',')] + tpl_file, sql, expected_output = [x.strip() for x in line.split(",")] + is_sql = sql.lower() == "true" res = run(tpl_bin, os.path.join(tpl_folder, tpl_file), is_sql) num_tests += 1 - report = 'PASS' + report = "PASS" if not res: - report = 'ERR' + report = "ERR" failed.add(tpl_file) - elif len(res) != 3 or not all(output == expected_output for output in res): - report = 'FAIL [expect: {}, actual: {}]'.format(expected_output, - res) + elif len(res) != 3 or not all(output == expected_output for output in res): + report = "FAIL [expect: {}, actual: {}]".format(expected_output, res) failed.add(tpl_file) - print('\t{}: {}'.format(tpl_file, report)) - print('{}/{} tests passed.'.format(num_tests - len(failed), num_tests)) + print("\t{}: {}".format(tpl_file, report)) + + print("{}/{} tests passed.".format(num_tests - len(failed), num_tests)) if len(failed) > 0: - print('{} failed:'.format(len(failed))) + print("{} failed:".format(len(failed))) for fail in failed: - print('\t{}'.format(fail)) - sys.exit(-1) + print("\t{}".format(fail)) + return EXIT_FAILURE + return EXIT_SUCCESS def main(): parser = argparse.ArgumentParser() - parser.add_argument('-b', dest='tpl_bin', help='TPL binary.') - parser.add_argument('-f', dest='tpl_tests_file', - help='File containing lines.') - parser.add_argument('-t', dest='tpl_folder', help='TPL tests folder.') - parser.add_argument('-d', dest='build_dir', help='Build Directory.') + parser.add_argument("-b", dest="tpl_bin", help="TPL binary.") + parser.add_argument("-f", dest="tpl_tests_file", + help="File containing lines.") + parser.add_argument("-t", dest="tpl_folder", help="TPL tests folder.") + parser.add_argument("-d", dest="build_dir", help="Build Directory.") args = parser.parse_args() - check(args.tpl_bin, args.tpl_folder, args.tpl_tests_file, args.build_dir) - + + return check(args.tpl_bin, args.tpl_folder, args.tpl_tests_file, args.build_dir) -if __name__ == '__main__': - main() +if __name__ == "__main__": + sys.exit(main()) diff --git a/sample_tpl/agg-lambda.tpl b/sample_tpl/agg-lambda.tpl new file mode 100644 index 0000000000..e1c5a6df30 --- /dev/null +++ b/sample_tpl/agg-lambda.tpl @@ -0,0 +1,81 @@ +// Expected output: 10 +// SQL: SELECT col_b, count(col_a) FROM test_1 GROUP BY col_b + +struct State { + table: AggregationHashTable + count: int32 +} + +struct OutputStruct { + out0: Integer +} + +struct Agg { + key: Integer + count: CountStarAggregate +} + +fun setUpState(execCtx: *ExecutionContext, state: *State) -> nil { + state.count = 0 + @aggHTInit(&state.table, execCtx, @sizeOf(Agg)) +} + +fun tearDownState(state: *State) -> nil { + @aggHTFree(&state.table) +} + +fun keyCheck(agg: *Agg, vpi: *VectorProjectionIterator) -> bool { + var key = @vpiGetInt(vpi, 1) + return @sqlToBool(key == agg.key) +} + +fun constructAgg(agg: *Agg, vpi: *VectorProjectionIterator) -> nil { + agg.key = @vpiGetInt(vpi, 1) + @aggInit(&agg.count) +} + +fun updateAgg(agg: *Agg, vpi: *VectorProjectionIterator) -> nil { + var input = @vpiGetInt(vpi, 0) + @aggAdvance(&agg.count, &input) +} + +fun pipeline_1(execCtx: *ExecutionContext, state: *State, lam : lambda [(Integer)->nil] ) -> nil { + var ht = &state.table + var tvi: TableVectorIterator + var table_oid = @testCatalogLookup(execCtx, "test_1", "") + var col_oids: [2]uint32 + col_oids[0] = @testCatalogLookup(execCtx, "test_1", "cola") + col_oids[1] = @testCatalogLookup(execCtx, "test_1", "colb") + for (@tableIterInit(&tvi, execCtx, table_oid, col_oids); @tableIterAdvance(&tvi); ) { + var vec = @tableIterGetVPI(&tvi) + for (; @vpiHasNext(vec); @vpiAdvance(vec)) { + var output_row: OutputStruct + output_row.out0 = @vpiGetIntNull(vec, 0) + lam(output_row.out0) + } + } + @tableIterClose(&tvi) +} + +fun execQuery(execCtx: *ExecutionContext, qs: *State, lam : lambda [(Integer)->nil] ) -> nil { + pipeline_1(execCtx, qs, lam) +} + +fun main(execCtx: *ExecutionContext) -> int32 { + var count : Integer + count = @intToSql(0) + var lam = lambda [count] (x : Integer) -> nil { + count = count + 1 + } + var state: State + + setUpState(execCtx, &state) + execQuery(execCtx, &state, lam) + tearDownState(&state) + + var ret = state.count + if(count > 0) { + return 1 + } + return 0 +} diff --git a/sample_tpl/call-lambda.tpl b/sample_tpl/call-lambda.tpl new file mode 100644 index 0000000000..76f088d12a --- /dev/null +++ b/sample_tpl/call-lambda.tpl @@ -0,0 +1,43 @@ +// Expected output: 70 + +fun f(z : Date ) -> Date { return z } + +fun main(exec : *ExecutionContext) -> int32 { + var y = 11 + var lam = lambda [y] (z: Integer ) -> nil { + y = y + z + } + lam(10) + + + var d = @dateToSql(1999, 2, 11) + //f(lam, d) + var k : Date + //var h = &k + //*h = d + //k = f(d) + lam(d) + if(@datePart(y, @intToSql(21)) == @intToSql(1999)){ + // good + return 1 + } + return 0 +} + +fun pipeline1(QueryState *q) { + TableIterator tvi; + for(@tableIteratorAdvance(&tvi)){ + @hashTableInsert(q.join_ht, @getTupleValue(&tvi, 3))) + } +} + +fun pipeline2(QueryState *q) { + TableIterator tvi; + for(@tableIteratorAdvance(&tvi)){ + var o_custkey = @getTupleValue(&tvi, 1) + if(@hashTableKeyExists(q.join_ht, o_custkey)){ + var out = @outputBufferAlloc(q.output_buff) + out.col1 = o_custkey + 1 + } + } +} \ No newline at end of file diff --git a/sample_tpl/param-lambda.tpl b/sample_tpl/param-lambda.tpl new file mode 100644 index 0000000000..417e05b26b --- /dev/null +++ b/sample_tpl/param-lambda.tpl @@ -0,0 +1,11 @@ +// Expected output: 10 + +fun check(x: int32) -> int32 { + var ret = x + return ret +} + +fun main() -> int32 { + var fn = lambda (x: int32) -> nil { return x + 1; } + return fn(2) +} diff --git a/sample_tpl/struct-lambda.tpl b/sample_tpl/struct-lambda.tpl new file mode 100644 index 0000000000..d8195af028 --- /dev/null +++ b/sample_tpl/struct-lambda.tpl @@ -0,0 +1,21 @@ +// Expected output: 10 + +struct S { + a: int + b: int + c: (int32) -> int32 +} +struct SDup { + d: int + e: int + f: int +} + +fun sss(x : int32) -> int32 { + return x +} + +fun main() -> int { + var p: S + p.c = sss +} diff --git a/sample_tpl/tpl_tests.txt b/sample_tpl/tpl_tests.txt index b7e15901bf..1ffd0a07e0 100644 --- a/sample_tpl/tpl_tests.txt +++ b/sample_tpl/tpl_tests.txt @@ -10,6 +10,7 @@ array.tpl,false,44 array-iterate.tpl,false,110 array-iterate-2.tpl,false,110 call.tpl,false,70 +#call-lambda.tpl,false,70 TODO(Kyle): Requires lambdas comments.tpl,false,46 compare.tpl,false,200 date-functions.tpl,false,0 @@ -28,6 +29,7 @@ loop4.tpl,false,166167000 nil.tpl,false,0 offsetof.tpl,false,54 param.tpl,false,10 +#param-lambda.tpl,false,10 TODO(Kyle): Requires lambdas point.tpl,false,-20 pointer.tpl,false,10 return-expr.tpl,false,15 @@ -40,6 +42,7 @@ short-circuit.tpl,false,1 #sql-conversions.tpl,false,0 TODO(WAN): wtf Mac CI? sql-date.tpl,false,0 struct.tpl,false,10 +#struct-lambda.tpl,false,10 TODO(Kyle): Requires lambdas struct-debug.tpl,false,100000 struct-empty.tpl,false,0 struct-field-use.tpl,false,30 @@ -63,6 +66,7 @@ types/timestamps.tpl,false,0 ################################################################################ agg.tpl,true,10 +#agg-lambda.tpl,true,10 TODO(Kyle): Requires lambdas #agg-vec.tpl,true,10 doesn't work on prashanth's branch #agg-vec-filter.tpl,true,10 doesn't work on prashanth's branch delete.tpl,true,0 From f8383eed8578d069664fb7daf044138692ae5f06 Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Fri, 2 Apr 2021 14:17:05 -0400 Subject: [PATCH 004/139] port network layer --- src/include/network/network_defs.h | 1 + 1 file changed, 1 insertion(+) diff --git a/src/include/network/network_defs.h b/src/include/network/network_defs.h index 33f76a1b87..75de4d1a78 100644 --- a/src/include/network/network_defs.h +++ b/src/include/network/network_defs.h @@ -109,6 +109,7 @@ enum class QueryType : uint8_t { QUERY_CREATE_INDEX, QUERY_CREATE_TRIGGER, QUERY_CREATE_SCHEMA, + QUERY_CREATE_FUNCTION, QUERY_CREATE_VIEW, QUERY_DROP_TABLE, QUERY_DROP_DB, From 9ae19b222f50384f233c009ad72743f8e0b30885 Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Fri, 2 Apr 2021 15:01:19 -0400 Subject: [PATCH 005/139] port traffic cop, but dependencies on executors and refactor of execution context not causing problems --- src/traffic_cop/traffic_cop.cpp | 18 +++++++++++++++++- src/traffic_cop/traffic_cop_util.cpp | 3 +++ 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/src/traffic_cop/traffic_cop.cpp b/src/traffic_cop/traffic_cop.cpp index 3a4976bb36..4b7d699833 100644 --- a/src/traffic_cop/traffic_cop.cpp +++ b/src/traffic_cop/traffic_cop.cpp @@ -197,7 +197,7 @@ TrafficCopResult TrafficCop::ExecuteCreateStatement( NOISEPAGE_ASSERT( query_type == network::QueryType::QUERY_CREATE_TABLE || query_type == network::QueryType::QUERY_CREATE_SCHEMA || query_type == network::QueryType::QUERY_CREATE_INDEX || query_type == network::QueryType::QUERY_CREATE_DB || - query_type == network::QueryType::QUERY_CREATE_VIEW || query_type == network::QueryType::QUERY_CREATE_TRIGGER, + query_type == network::QueryType::QUERY_CREATE_VIEW || query_type == network::QueryType::QUERY_CREATE_TRIGGER || query_type == network::QueryType::QUERY_CREATE_FUNCTION, "ExecuteCreateStatement called with invalid QueryType."); switch (query_type) { case network::QueryType::QUERY_CREATE_TABLE: { @@ -229,6 +229,15 @@ TrafficCopResult TrafficCop::ExecuteCreateStatement( } break; } + case network::QueryType::QUERY_CREATE_FUNCTION: { + // TODO(Kyle): Port executor + // if (execution::sql::DDLExecutors::CreateFunctionExecutor( + // physical_plan.CastManagedPointerTo(), connection_ctx->Accessor())) { + // return {ResultType::COMPLETE, 0}; + // } + throw NOT_IMPLEMENTED_EXCEPTION("CREATE FUNCTION not implemented"); + break; + } default: { return {ResultType::ERROR, common::ErrorData(common::ErrorSeverity::ERROR, "unsupported CREATE statement type", common::ErrorCode::ERRCODE_FEATURE_NOT_SUPPORTED)}; @@ -437,6 +446,13 @@ TrafficCopResult TrafficCop::RunExecutableQuery(const common::ManagedPointerSetParams(portal->Parameters()); + // TODO(Kyle): Refactor to algorithm + // std::vector> params{}; + // for (auto &cve : *(portal->Parameters())){ + // params.push_back(common::ManagedPointer(cve.PeekPtr())); + // } + // exec_ctx->SetParams(common::ManagedPointer(¶ms)); + const auto exec_query = portal->GetStatement()->GetExecutableQuery(); try { diff --git a/src/traffic_cop/traffic_cop_util.cpp b/src/traffic_cop/traffic_cop_util.cpp index e75c6ccbc1..b83dd95b9b 100644 --- a/src/traffic_cop/traffic_cop_util.cpp +++ b/src/traffic_cop/traffic_cop_util.cpp @@ -129,6 +129,9 @@ network::QueryType TrafficCopUtil::QueryTypeForStatement(const common::ManagedPo return network::QueryType::QUERY_CREATE_VIEW; } } + case parser::StatementType::CREATE_FUNC: { + return network::QueryType::QUERY_CREATE_FUNCTION; + } case parser::StatementType::DROP: { const auto drop_type = statement.CastManagedPointerTo()->GetDropType(); switch (drop_type) { From f18d61706a6215be3df08b65a67abd399a69476c Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Fri, 2 Apr 2021 16:53:00 -0400 Subject: [PATCH 006/139] port parser code, and run formatter --- src/catalog/catalog_accessor.cpp | 4 +- src/catalog/database_catalog.cpp | 7 ++- src/catalog/postgres/pg_proc_impl.cpp | 14 ++--- .../parser/create_function_statement.h | 38 ++++++++++- src/include/parser/postgresparser.h | 16 +++-- src/parser/postgresparser.cpp | 63 +++++++++++-------- src/traffic_cop/traffic_cop.cpp | 4 +- test/catalog/catalog_test.cpp | 4 +- 8 files changed, 102 insertions(+), 48 deletions(-) diff --git a/src/catalog/catalog_accessor.cpp b/src/catalog/catalog_accessor.cpp index 9d8cccf8b6..805616ce44 100644 --- a/src/catalog/catalog_accessor.cpp +++ b/src/catalog/catalog_accessor.cpp @@ -177,8 +177,8 @@ proc_oid_t CatalogAccessor::CreateProcedure(const std::string &procname, languag namespace_oid_t procns, const std::vector &args, const std::vector &arg_types, const std::vector &all_arg_types, - const std::vector &arg_modes, - type_oid_t rettype, const std::string &src, bool is_aggregate) { + const std::vector &arg_modes, type_oid_t rettype, + const std::string &src, bool is_aggregate) { return dbc_->CreateProcedure(txn_, procname, language_oid, procns, args, arg_types, all_arg_types, arg_modes, rettype, src, is_aggregate); } diff --git a/src/catalog/database_catalog.cpp b/src/catalog/database_catalog.cpp index 6aef17f6ee..5106d550b1 100644 --- a/src/catalog/database_catalog.cpp +++ b/src/catalog/database_catalog.cpp @@ -428,12 +428,13 @@ proc_oid_t DatabaseCatalog::CreateProcedure(common::ManagedPointer &args, const std::vector &arg_types, const std::vector &all_arg_types, - const std::vector &arg_modes, - type_oid_t rettype, const std::string &src, bool is_aggregate) { + const std::vector &arg_modes, type_oid_t rettype, + const std::string &src, bool is_aggregate) { if (!TryLock(txn)) return INVALID_PROC_OID; proc_oid_t oid = proc_oid_t{next_oid_++}; // TODO(Kyle): Why did Tanuj have his own implementation here? - const auto result = pg_proc_.CreateProcedure(txn, oid, procname, language_oid, procns, args, arg_types, all_arg_types, arg_modes, rettype, src, is_aggregate); + const auto result = pg_proc_.CreateProcedure(txn, oid, procname, language_oid, procns, args, arg_types, all_arg_types, + arg_modes, rettype, src, is_aggregate); return result ? oid : INVALID_PROC_OID; } diff --git a/src/catalog/postgres/pg_proc_impl.cpp b/src/catalog/postgres/pg_proc_impl.cpp index d65b626649..2946e05f24 100644 --- a/src/catalog/postgres/pg_proc_impl.cpp +++ b/src/catalog/postgres/pg_proc_impl.cpp @@ -109,7 +109,7 @@ bool PgProcImpl::CreateProcedure(const common::ManagedPointernext_oid_++}, "nprunnersemitint", PgLanguage::INTERNAL_LANGUAGE_OID, - PgNamespace::NAMESPACE_DEFAULT_NAMESPACE_OID, {"num_tuples", "num_cols", "num_int_cols", "num_real_cols"}, - {INT, INT, INT, INT}, {INT, INT, INT, INT}, - {PgProc::ArgMode::IN, PgProc::ArgMode::IN, PgProc::ArgMode::IN, PgProc::ArgMode::IN}, INT, "", false); + CreateProcedure(txn, proc_oid_t{dbc->next_oid_++}, "nprunnersemitint", PgLanguage::INTERNAL_LANGUAGE_OID, + PgNamespace::NAMESPACE_DEFAULT_NAMESPACE_OID, + {"num_tuples", "num_cols", "num_int_cols", "num_real_cols"}, {INT, INT, INT, INT}, + {INT, INT, INT, INT}, + {PgProc::ArgMode::IN, PgProc::ArgMode::IN, PgProc::ArgMode::IN, PgProc::ArgMode::IN}, INT, "", false); CreateProcedure( txn, proc_oid_t{dbc->next_oid_++}, "nprunnersemitreal", PgLanguage::INTERNAL_LANGUAGE_OID, PgNamespace::NAMESPACE_DEFAULT_NAMESPACE_OID, {"num_tuples", "num_cols", "num_int_cols", "num_real_cols"}, diff --git a/src/include/parser/create_function_statement.h b/src/include/parser/create_function_statement.h index e20b74b5cc..64221d5790 100644 --- a/src/include/parser/create_function_statement.h +++ b/src/include/parser/create_function_statement.h @@ -30,7 +30,8 @@ struct BaseFunctionParameter { VARCHAR, TEXT, BOOL, - BOOLEAN + BOOLEAN, + DATE }; /** @param datatype data type of the parameter */ @@ -41,6 +42,41 @@ struct BaseFunctionParameter { /** @return data type of the parameter */ DataType GetDataType() { return datatype_; } + /** @return internal type id of the parameter */ + static type::TypeId DataTypeToTypeId(DataType datatype) { + switch (datatype) { + case DataType::INT: + return type::TypeId::INTEGER; + case DataType::INTEGER: + return type::TypeId::INTEGER; + case DataType::TINYINT: + return type::TypeId::TINYINT; + case DataType::SMALLINT: + return type::TypeId::SMALLINT; + case DataType::BIGINT: + return type::TypeId::BIGINT; + case DataType::CHAR: + return type::TypeId::INVALID; + case DataType::DOUBLE: + return type::TypeId::DECIMAL; + case DataType::FLOAT: + return type::TypeId::DECIMAL; + case DataType::DECIMAL: + return type::TypeId::DECIMAL; + case DataType::VARCHAR: + return type::TypeId::VARCHAR; + case DataType::TEXT: + return type::TypeId::VARCHAR; + case DataType::BOOL: + return type::TypeId::BOOLEAN; + case DataType::BOOLEAN: + return type::TypeId::BOOLEAN; + case DataType::DATE: + return type::TypeId::DATE; + } + return type::TypeId::INVALID; + } + private: const DataType datatype_; }; diff --git a/src/include/parser/postgresparser.h b/src/include/parser/postgresparser.h index e4407abf0b..c25cd27635 100644 --- a/src/include/parser/postgresparser.h +++ b/src/include/parser/postgresparser.h @@ -85,16 +85,19 @@ class PostgresParser { * Transforms the entire parsed nodes list into a corresponding SQLStatementList. * @param[in,out] parse_result the current parse result, which will be updated * @param root list of parsed nodes + * @param query_string the query string */ - static void ListTransform(ParseResult *parse_result, List *root); + static void ListTransform(ParseResult *parse_result, List *root, const std::string &query_string); /** * Transforms a single node in the parse list into a noisepage SQLStatement object. * @param[in,out] parse_result the current parse result, which will be updated * @param node parsed node + * @param query_string the query string * @return SQLStatement corresponding to the parsed node */ - static std::unique_ptr NodeTransform(ParseResult *parse_result, Node *node); + static std::unique_ptr NodeTransform(ParseResult *parse_result, Node *node, + const std::string &query_string); static std::unique_ptr ExprTransform(ParseResult *parse_result, Node *node, char *alias); static ExpressionType StringToExpressionType(const std::string &parser_str); @@ -133,7 +136,8 @@ class PostgresParser { // CREATE statements static std::unique_ptr CreateTransform(ParseResult *parse_result, CreateStmt *root); static std::unique_ptr CreateDatabaseTransform(ParseResult *parse_result, CreateDatabaseStmt *root); - static std::unique_ptr CreateFunctionTransform(ParseResult *parse_result, CreateFunctionStmt *root); + static std::unique_ptr CreateFunctionTransform(ParseResult *parse_result, CreateFunctionStmt *root, + const std::string &query_string); static std::unique_ptr CreateIndexTransform(ParseResult *parse_result, IndexStmt *root); static std::unique_ptr CreateSchemaTransform(ParseResult *parse_result, CreateSchemaStmt *root); static std::unique_ptr CreateTriggerTransform(ParseResult *parse_result, CreateTrigStmt *root); @@ -172,7 +176,8 @@ class PostgresParser { List *root); // EXPLAIN statements - static std::unique_ptr ExplainTransform(ParseResult *parse_result, ExplainStmt *root); + static std::unique_ptr ExplainTransform(ParseResult *parse_result, ExplainStmt *root, + const std::string &query_string); // INSERT statements static std::unique_ptr InsertTransform(ParseResult *parse_result, InsertStmt *root); @@ -183,7 +188,8 @@ class PostgresParser { ParseResult *parse_result, List *root); // PREPARE statements - static std::unique_ptr PrepareTransform(ParseResult *parse_result, PrepareStmt *root); + static std::unique_ptr PrepareTransform(ParseResult *parse_result, PrepareStmt *root, + const std::string &query_string); static std::unique_ptr TruncateTransform(ParseResult *parse_result, TruncateStmt *truncate_stmt); diff --git a/src/parser/postgresparser.cpp b/src/parser/postgresparser.cpp index f240fa9202..0e4ff3e2fd 100644 --- a/src/parser/postgresparser.cpp +++ b/src/parser/postgresparser.cpp @@ -61,7 +61,7 @@ std::unique_ptr PostgresParser::BuildParseTree(const std::s // Transform the Postgres parse tree to a Terrier representation. auto parse_result = std::make_unique(); try { - ListTransform(parse_result.get(), result.tree); + ListTransform(parse_result.get(), result.tree, query_string); } catch (const Exception &e) { pg_query_parse_finish(ctx); pg_query_free_parse_result(result); @@ -74,16 +74,17 @@ std::unique_ptr PostgresParser::BuildParseTree(const std::s return parse_result; } -void PostgresParser::ListTransform(ParseResult *parse_result, List *root) { +void PostgresParser::ListTransform(ParseResult *parse_result, List *root, const std::string &query_string) { if (root != nullptr) { for (auto cell = root->head; cell != nullptr; cell = cell->next) { auto node = static_cast(cell->data.ptr_value); - parse_result->AddStatement(NodeTransform(parse_result, node)); + parse_result->AddStatement(NodeTransform(parse_result, node, query_string)); } } } -std::unique_ptr PostgresParser::NodeTransform(ParseResult *parse_result, Node *node) { +std::unique_ptr PostgresParser::NodeTransform(ParseResult *parse_result, Node *node, + const std::string &query_string) { // TODO(WAN): Document what input is parsed to nullptr if (node == nullptr) { return nullptr; @@ -104,7 +105,7 @@ std::unique_ptr PostgresParser::NodeTransform(ParseResult *parse_r break; } case T_CreateFunctionStmt: { - result = CreateFunctionTransform(parse_result, reinterpret_cast(node)); + result = CreateFunctionTransform(parse_result, reinterpret_cast(node), query_string); break; } case T_CreateSchemaStmt: { @@ -128,7 +129,7 @@ std::unique_ptr PostgresParser::NodeTransform(ParseResult *parse_r break; } case T_ExplainStmt: { - result = ExplainTransform(parse_result, reinterpret_cast(node)); + result = ExplainTransform(parse_result, reinterpret_cast(node), query_string); break; } case T_IndexStmt: { @@ -140,7 +141,7 @@ std::unique_ptr PostgresParser::NodeTransform(ParseResult *parse_r break; } case T_PrepareStmt: { - result = PrepareTransform(parse_result, reinterpret_cast(node)); + result = PrepareTransform(parse_result, reinterpret_cast(node), query_string); break; } case T_SelectStmt: { @@ -1275,21 +1276,24 @@ std::unique_ptr PostgresParser::CreateDatabaseTransform(Pa // Postgres.CreateFunctionStmt -> noisepage.CreateFunctionStatement std::unique_ptr PostgresParser::CreateFunctionTransform(ParseResult *parse_result, - CreateFunctionStmt *root) { + CreateFunctionStmt *root, + const std::string &query_string) { bool replace = root->replace_; - std::vector> func_parameters; - - for (auto cell = root->parameters_->head; cell != nullptr; cell = cell->next) { - auto node = reinterpret_cast(cell->data.ptr_value); - switch (node->type) { - case T_FunctionParameter: { - func_parameters.emplace_back( - FunctionParameterTransform(parse_result, reinterpret_cast(node))); - break; - } - default: { - // TODO(WAN): previous code just ignored it, is this right? - break; + std::vector> func_parameters{}; + if (root->parameters_ != nullptr) { + for (auto cell = root->parameters_->head; cell != nullptr; cell = cell->next) { + auto node = reinterpret_cast(cell->data.ptr_value); + switch (node->type) { + case T_FunctionParameter: { + func_parameters.emplace_back( + FunctionParameterTransform(parse_result, reinterpret_cast(node))); + break; + } + default: { + // TODO(WAN): previous code just ignored it, is this right? + // TODO(Kyle): Good question^ + break; + } } } } @@ -1299,7 +1303,8 @@ std::unique_ptr PostgresParser::CreateFunctionTransform(ParseResul // TODO(WAN): assumption from old code, can only pass one function name for now std::string func_name = (reinterpret_cast(root->funcname_->tail->data.ptr_value)->val_.str_); - std::vector func_body; + std::vector func_body{}; + func_body.push_back(std::string(query_string.c_str())); AsType as_type = AsType::INVALID; PLType pl_type = PLType::INVALID; @@ -1313,7 +1318,7 @@ std::unique_ptr PostgresParser::CreateFunctionTransform(ParseResul func_body.push_back(query_string); } - if (func_body.size() > 1) { + if (func_body.size() > 2) { as_type = AsType::EXECUTABLE; } else { as_type = AsType::QUERY_STRING; @@ -1669,6 +1674,8 @@ std::unique_ptr PostgresParser::ReturnTypeTransform(ParseResult *par data_type = BaseFunctionParameter::DataType::TINYINT; } else if (strcmp(name, "bool") == 0) { data_type = BaseFunctionParameter::DataType::BOOL; + } else if (strcmp(name, "date") == 0) { + data_type = BaseFunctionParameter::DataType::DATE; } else { PARSER_LOG_AND_THROW("ReturnTypeTransform", "ReturnType", name); } @@ -1873,9 +1880,10 @@ std::vector> PostgresParser::ParamLis return result; } -std::unique_ptr PostgresParser::ExplainTransform(ParseResult *parse_result, ExplainStmt *root) { +std::unique_ptr PostgresParser::ExplainTransform(ParseResult *parse_result, ExplainStmt *root, + const std::string &query_string) { std::unique_ptr result; - auto query = NodeTransform(parse_result, root->query_); + auto query = NodeTransform(parse_result, root->query_, query_string); result = std::make_unique(std::move(query)); return result; } @@ -2005,9 +2013,10 @@ std::vector> PostgresParser::UpdateTargetTransform } // Postgres.PrepareStmt -> noisepage.PrepareStatement -std::unique_ptr PostgresParser::PrepareTransform(ParseResult *parse_result, PrepareStmt *root) { +std::unique_ptr PostgresParser::PrepareTransform(ParseResult *parse_result, PrepareStmt *root, + const std::string &query_string) { auto name = root->name_; - auto query = NodeTransform(parse_result, root->query_); + auto query = NodeTransform(parse_result, root->query_, query_string); // TODO(WAN): This should probably be populated? std::vector> placeholders; diff --git a/src/traffic_cop/traffic_cop.cpp b/src/traffic_cop/traffic_cop.cpp index 4b7d699833..c06d61968c 100644 --- a/src/traffic_cop/traffic_cop.cpp +++ b/src/traffic_cop/traffic_cop.cpp @@ -197,7 +197,9 @@ TrafficCopResult TrafficCop::ExecuteCreateStatement( NOISEPAGE_ASSERT( query_type == network::QueryType::QUERY_CREATE_TABLE || query_type == network::QueryType::QUERY_CREATE_SCHEMA || query_type == network::QueryType::QUERY_CREATE_INDEX || query_type == network::QueryType::QUERY_CREATE_DB || - query_type == network::QueryType::QUERY_CREATE_VIEW || query_type == network::QueryType::QUERY_CREATE_TRIGGER || query_type == network::QueryType::QUERY_CREATE_FUNCTION, + query_type == network::QueryType::QUERY_CREATE_VIEW || + query_type == network::QueryType::QUERY_CREATE_TRIGGER || + query_type == network::QueryType::QUERY_CREATE_FUNCTION, "ExecuteCreateStatement called with invalid QueryType."); switch (query_type) { case network::QueryType::QUERY_CREATE_TABLE: { diff --git a/test/catalog/catalog_test.cpp b/test/catalog/catalog_test.cpp index 8ec8e3729b..9aa65b8948 100644 --- a/test/catalog/catalog_test.cpp +++ b/test/catalog/catalog_test.cpp @@ -139,8 +139,8 @@ TEST_F(CatalogTests, ProcTest) { accessor->GetTypeOidFromTypeId(type::TypeId::BOOLEAN), accessor->GetTypeOidFromTypeId(type::TypeId::SMALLINT)}; std::vector arg_modes = {catalog::postgres::PgProc::ArgMode::IN, - catalog::postgres::PgProc::ArgMode::IN, - catalog::postgres::PgProc::ArgMode::IN}; + catalog::postgres::PgProc::ArgMode::IN, + catalog::postgres::PgProc::ArgMode::IN}; auto src = "int sample(arg1, arg2, arg3){return 2;}"; auto proc_oid = From 731f2b9b11e5ebfbdf18e0b860eda099481fada9 Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Sat, 3 Apr 2021 10:33:18 -0400 Subject: [PATCH 007/139] pull in udf specific files in parser, ast, and code generation --- src/execution/compiler/udf/udf_codegen.cpp | 653 ++++++++++++++++++ .../execution/ast/udf/udf_ast_context.h | 53 ++ .../execution/ast/udf/udf_ast_node_visitor.h | 57 ++ src/include/execution/ast/udf/udf_ast_nodes.h | 237 +++++++ .../execution/compiler/udf/udf_codegen.h | 101 +++ src/include/parser/udf/udf_parser.h | 51 ++ src/parser/udf/udf_parser.cpp | 347 ++++++++++ 7 files changed, 1499 insertions(+) create mode 100644 src/execution/compiler/udf/udf_codegen.cpp create mode 100644 src/include/execution/ast/udf/udf_ast_context.h create mode 100644 src/include/execution/ast/udf/udf_ast_node_visitor.h create mode 100644 src/include/execution/ast/udf/udf_ast_nodes.h create mode 100644 src/include/execution/compiler/udf/udf_codegen.h create mode 100644 src/include/parser/udf/udf_parser.h create mode 100644 src/parser/udf/udf_parser.cpp diff --git a/src/execution/compiler/udf/udf_codegen.cpp b/src/execution/compiler/udf/udf_codegen.cpp new file mode 100644 index 0000000000..35838df5dc --- /dev/null +++ b/src/execution/compiler/udf/udf_codegen.cpp @@ -0,0 +1,653 @@ +#include "common/error/exception.h" + +#include "binder/bind_node_visitor.h" + +#include "execution/ast/ast.h" + +// TODO(Kyle): Not Ported Yet +// #include "execution/ast/ast_clone.h" + +#include "execution/compiler/compilation_context.h" +#include "execution/compiler/executable_query.h" +#include "execution/compiler/if.h" +#include "execution/compiler/loop.h" +#include "execution/exec/execution_settings.h" + +#include "catalog/catalog_accessor.h" +#include "optimizer/cost_model/trivial_cost_model.h" +#include "optimizer/statistics/stats_storage.h" + +#include "traffic_cop/traffic_cop_util.h" + +#include "parser/expression/constant_value_expression.h" +#include "parser/postgresparser.h" + +#include "execution/ast/udf/udf_ast_nodes.h" +#include "execution/compiler/udf/udf_codegen.h" + +#include "planner/plannodes/abstract_plan_node.h" + +// TODO(Kyle): Documentation. + +namespace noisepage { +namespace execution { +namespace compiler { +namespace udf { + +UDFCodegen::UDFCodegen(catalog::CatalogAccessor *accessor, FunctionBuilder *fb, + ast::udf::UDFASTContext *udf_ast_context, CodeGen *codegen, catalog::db_oid_t db_oid) + : accessor_{accessor}, + fb_{fb}, + udf_ast_context_{udf_ast_context}, + codegen_{codegen}, + aux_decls_(codegen->GetAstContext()->GetRegion()), + db_oid_{db_oid} { + for (size_t i = 0; fb->GetParameterByPosition(i) != nullptr; i++) { + auto param = fb->GetParameterByPosition(i); + const auto &name = param->As()->Name(); + str_to_ident_.emplace(name.GetString(), name); + } +} + +const char *UDFCodegen::GetReturnParamString() { return "return_val"; } + +void UDFCodegen::GenerateUDF(AbstractAST *ast) { ast->Accept(this); } + +void UDFCodegen::Visit(AbstractAST *ast) { UNREACHABLE("Not implemented"); } + +void UDFCodegen::Visit(DynamicSQLStmtAST *ast) { UNREACHABLE("Not implemented"); } + +catalog::type_oid_t UDFCodegen::GetCatalogTypeOidFromSQLType(execution::ast::BuiltinType::Kind type) { + switch (type) { + case execution::ast::BuiltinType::Kind::Integer: { + return accessor_->GetTypeOidFromTypeId(type::TypeId::INTEGER); + } + case execution::ast::BuiltinType::Kind::Boolean: { + return accessor_->GetTypeOidFromTypeId(type::TypeId::BOOLEAN); + } + default: + return accessor_->GetTypeOidFromTypeId(type::TypeId::INVALID); + NOISEPAGE_ASSERT(false, "Unsupported param type"); + } +} + +void UDFCodegen::Visit(CallExprAST *ast) { + // UNREACHABLE("Not implemented"); + auto &args = ast->args; + std::vector args_ast; + std::vector args_ast_region_vec; + std::vector arg_types; + + for (auto &arg : args) { + arg->Accept(this); + args_ast.push_back(dst_); + args_ast_region_vec.push_back(dst_); + auto *builtin = dst_->GetType()->SafeAs(); + NOISEPAGE_ASSERT(builtin != nullptr, "Not builtin parameter"); + NOISEPAGE_ASSERT(builtin->IsSqlValueType(), "Param is not SQL Value Type"); + arg_types.push_back(GetCatalogTypeOidFromSQLType(builtin->GetKind())); + } + auto proc_oid = accessor_->GetProcOid(ast->callee, arg_types); + NOISEPAGE_ASSERT(proc_oid != catalog::INVALID_PROC_OID, "Invalid call"); + auto context = accessor_->GetProcCtxPtr(proc_oid); + if (context->IsBuiltin()) { + fb_->Append(codegen_->MakeStmt(codegen_->CallBuiltin(context->GetBuiltin(), std::move(args_ast)))); + } else { + auto it = str_to_ident_.find(ast->callee); + execution::ast::Identifier ident_expr; + if (it != str_to_ident_.end()) { + ident_expr = it->second; + } else { + auto file = reinterpret_cast( + execution::ast::AstClone::Clone(context->GetFile(), codegen_->GetAstContext()->GetNodeFactory(), "", + context->GetASTContext(), codegen_->GetAstContext().Get())); + for (auto decl : file->Declarations()) { + aux_decls_.push_back(decl); + } + ident_expr = codegen_->MakeFreshIdentifier(file->Declarations().back()->Name().GetString()); + str_to_ident_[file->Declarations().back()->Name().GetString()] = ident_expr; + } + fb_->Append(codegen_->MakeStmt(codegen_->Call(ident_expr, args_ast_region_vec))); + } + // fb_->Append(codegen_->Call) +} + +void UDFCodegen::Visit(StmtAST *ast) { UNREACHABLE("Not implemented"); } + +void UDFCodegen::Visit(ExprAST *ast) { UNREACHABLE("Not implemented"); } + +void UDFCodegen::Visit(DeclStmtAST *ast) { + if (ast->name == "*internal*") { + return; + } + execution::ast::Identifier ident = codegen_->MakeFreshIdentifier(ast->name); + str_to_ident_.emplace(ast->name, ident); + auto prev_type = current_type_; + execution::ast::Expr *tpl_type = nullptr; + if (ast->type == type::TypeId::INVALID) { + // record type + execution::util::RegionVector fields(codegen_->GetAstContext()->GetRegion()); + for (auto p : udf_ast_context_->GetRecordType(ast->name)) { + fields.push_back(codegen_->MakeField(codegen_->MakeIdentifier(p.first), + codegen_->TplType(execution::sql::GetTypeId(p.second)))); + } + auto record_decl = codegen_->DeclareStruct(codegen_->MakeFreshIdentifier("rectype"), std::move(fields)); + aux_decls_.push_back(record_decl); + tpl_type = record_decl->TypeRepr(); + } else { + tpl_type = codegen_->TplType(execution::sql::GetTypeId(ast->type)); + } + current_type_ = ast->type; + if (ast->initial != nullptr) { + // Visit(ast->initial.get()); + ast->initial->Accept(this); + fb_->Append(codegen_->DeclareVar(ident, tpl_type, dst_)); + } else { + fb_->Append(codegen_->DeclareVarNoInit(ident, tpl_type)); + } + current_type_ = prev_type; +} + +void UDFCodegen::Visit(FunctionAST *ast) { + for (size_t i = 0; i < ast->param_types_.size(); i++) { + // auto param_type = codegen_->TplType(ast->param_types_[i]); + str_to_ident_.emplace(ast->param_names_[i], codegen_->MakeFreshIdentifier("udf")); + } + ast->body.get()->Accept(this); +} + +void UDFCodegen::Visit(VariableExprAST *ast) { + auto it = str_to_ident_.find(ast->name); + NOISEPAGE_ASSERT(it != str_to_ident_.end(), "variable not declared"); + dst_ = codegen_->MakeExpr(it->second); +} + +void UDFCodegen::Visit(ValueExprAST *ast) { + auto val = common::ManagedPointer(ast->value_).CastManagedPointerTo(); + if (val->IsNull()) { + dst_ = codegen_->ConstNull(current_type_); + return; + } + auto type_id = execution::sql::GetTypeId(val->GetReturnValueType()); + switch (type_id) { + case execution::sql::TypeId::Boolean: + dst_ = codegen_->BoolToSql(val->GetBoolVal().val_); + break; + case execution::sql::TypeId::TinyInt: + case execution::sql::TypeId::SmallInt: + case execution::sql::TypeId::Integer: + case execution::sql::TypeId::BigInt: + dst_ = codegen_->IntToSql(val->GetInteger().val_); + break; + case execution::sql::TypeId::Float: + case execution::sql::TypeId::Double: + dst_ = codegen_->FloatToSql(val->GetReal().val_); + case execution::sql::TypeId::Date: + dst_ = codegen_->DateToSql(val->GetDateVal().val_); + break; + case execution::sql::TypeId::Timestamp: + dst_ = codegen_->TimestampToSql(val->GetTimestampVal().val_); + break; + case execution::sql::TypeId::Varchar: + dst_ = codegen_->StringToSql(val->GetStringVal().StringView()); + break; + default: + throw NOT_IMPLEMENTED_EXCEPTION("Unsupported type in UDF codegen"); + } +} + +void UDFCodegen::Visit(AssignStmtAST *ast) { + type::TypeId left_type = type::TypeId::INVALID; + udf_ast_context_->GetVariableType(ast->lhs->name, &left_type); + current_type_ = left_type; + + reinterpret_cast(ast->rhs.get())->Accept(this); + auto rhs_expr = dst_; + + auto it = str_to_ident_.find(ast->lhs->name); + NOISEPAGE_ASSERT(it != str_to_ident_.end(), "Variable not found"); + auto left_codegen_ident = it->second; + + auto *left_expr = codegen_->MakeExpr(left_codegen_ident); + + // auto right_type = rhs_expr->GetType()->GetTypeId(); + + // if (left_type == type::TypeId::VARCHAR) { + fb_->Append(codegen_->Assign(left_expr, rhs_expr)); + // } +} + +void UDFCodegen::Visit(BinaryExprAST *ast) { + execution::parsing::Token::Type op_token; + bool compare = false; + switch (ast->op) { + case noisepage::parser::ExpressionType::OPERATOR_DIVIDE: + op_token = execution::parsing::Token::Type::SLASH; + break; + case noisepage::parser::ExpressionType::OPERATOR_PLUS: + op_token = execution::parsing::Token::Type::PLUS; + break; + case noisepage::parser::ExpressionType::OPERATOR_MINUS: + op_token = execution::parsing::Token::Type::MINUS; + break; + case noisepage::parser::ExpressionType::OPERATOR_MULTIPLY: + op_token = execution::parsing::Token::Type::STAR; + break; + case noisepage::parser::ExpressionType::OPERATOR_MOD: + op_token = execution::parsing::Token::Type::PERCENT; + break; + case noisepage::parser::ExpressionType::CONJUNCTION_OR: + op_token = execution::parsing::Token::Type::OR; + break; + case noisepage::parser::ExpressionType::CONJUNCTION_AND: + op_token = execution::parsing::Token::Type::AND; + break; + case noisepage::parser::ExpressionType::COMPARE_GREATER_THAN: + compare = true; + op_token = execution::parsing::Token::Type::GREATER; + break; + case noisepage::parser::ExpressionType::COMPARE_GREATER_THAN_OR_EQUAL_TO: + compare = true; + op_token = execution::parsing::Token::Type::GREATER_EQUAL; + break; + case noisepage::parser::ExpressionType::COMPARE_LESS_THAN_OR_EQUAL_TO: + compare = true; + op_token = execution::parsing::Token::Type::LESS_EQUAL; + break; + case noisepage::parser::ExpressionType::COMPARE_LESS_THAN: + compare = true; + op_token = execution::parsing::Token::Type::LESS; + break; + case noisepage::parser::ExpressionType::COMPARE_EQUAL: + compare = true; + op_token = execution::parsing::Token::Type::EQUAL_EQUAL; + break; + default: + // TODO(tanujnay112): figure out concatenation operation from expressions? + UNREACHABLE("Unsupported expression"); + } + ast->lhs->Accept(this); + auto lhs_expr = dst_; + + ast->rhs->Accept(this); + auto rhs_expr = dst_; + if (compare) { + dst_ = codegen_->Compare(op_token, lhs_expr, rhs_expr); + } else { + dst_ = codegen_->BinaryOp(op_token, lhs_expr, rhs_expr); + } +} + +void UDFCodegen::Visit(IfStmtAST *ast) { + ast->cond_expr->Accept(this); + auto cond = dst_; + + If branch(fb_, cond); + ast->then_stmt->Accept(this); + if (ast->else_stmt != nullptr) { + branch.Else(); + ast->else_stmt->Accept(this); + } + branch.EndIf(); +} + +void UDFCodegen::Visit(IsNullExprAST *ast) { + ast->child_->Accept(this); + auto chld = dst_; + dst_ = codegen_->CallBuiltin(execution::ast::Builtin::IsValNull, {chld}); + if (!ast->is_null_check_) { + dst_ = codegen_->UnaryOp(execution::parsing::Token::Type::BANG, dst_); + } +} + +void UDFCodegen::Visit(SeqStmtAST *ast) { + for (auto &stmt : ast->stmts) { + stmt->Accept(this); + } +} + +void UDFCodegen::Visit(WhileStmtAST *ast) { + ast->cond_expr->Accept(this); + auto cond = dst_; + // cond = codegen_->Compare(execution::parsing::Token::Type::EQUAL_EQUAL, cond, ) + // cond = codegen_->CallBuiltin(execution::ast::Builtin::SqlToBool, {cond}); + Loop loop(fb_, cond); + ast->body_stmt->Accept(this); + loop.EndLoop(); +} + +void UDFCodegen::Visit(ForStmtAST *ast) { + needs_exec_ctx_ = true; + const auto query = common::ManagedPointer(ast->query_); + auto exec_ctx = fb_->GetParameterByPosition(0); + + // TODO(Matt): I don't think the binder should need the database name. It's already bound in the ConnectionContext + binder::BindNodeVisitor visitor(common::ManagedPointer(accessor_), db_oid_); + auto query_params = visitor.BindAndGetUDFParams(query, common::ManagedPointer(udf_ast_context_)); + + auto stats = optimizer::StatsStorage(); + + std::unique_ptr plan = trafficcop::TrafficCopUtil::Optimize( + accessor_->GetTxn(), common::ManagedPointer(accessor_), query, db_oid_, common::ManagedPointer(&stats), + std::make_unique(), 1000000); + // make lambda that just writes into this + std::vector var_idents; + auto lam_var = codegen_->MakeFreshIdentifier("looplamb"); + execution::util::RegionVector params(codegen_->GetAstContext()->GetRegion()); + params.push_back(codegen_->MakeField( + exec_ctx->As()->Name(), + codegen_->PointerType(codegen_->BuiltinType(execution::ast::BuiltinType::Kind::ExecutionContext)))); + size_t i = 0; + for (auto var : ast->vars_) { + var_idents.push_back(str_to_ident_.find(var)->second); + auto var_ident = var_idents.back(); + // NOISEPAGE_ASSERT(plan->GetOutputSchema()->GetColumns().size() == 1, "Can't support non scalars yet!"); + auto type = codegen_->TplType(execution::sql::GetTypeId(plan->GetOutputSchema()->GetColumn(i).GetType())); + + fb_->Append(codegen_->Assign(codegen_->MakeExpr(var_ident), + codegen_->ConstNull(plan->GetOutputSchema()->GetColumn(i).GetType()))); + auto input = codegen_->MakeFreshIdentifier(var); + params.push_back(codegen_->MakeField(input, type)); + i++; + } + execution::ast::LambdaExpr *lambda_expr; + FunctionBuilder fn(codegen_, std::move(params), codegen_->BuiltinType(execution::ast::BuiltinType::Nil)); + { + size_t j = 1; + for (auto var : var_idents) { + fn.Append(codegen_->Assign(codegen_->MakeExpr(var), fn.GetParameterByPosition(j))); + j++; + } + auto prev_fb = fb_; + fb_ = &fn; + ast->body_stmt_->Accept(this); + fb_ = prev_fb; + } + + execution::util::RegionVector captures(codegen_->GetAstContext()->GetRegion()); + for (auto it : str_to_ident_) { + if (it.first == "executionCtx") { + continue; + } + captures.push_back(codegen_->MakeExpr(it.second)); + } + + lambda_expr = fn.FinishLambda(std::move(captures)); + lambda_expr->SetName(lam_var); + + // want to pass something down that will materialize the lambda function for me into lambda_expr and will + // also feed in a lambda_expr to the compiler + execution::exec::ExecutionSettings exec_settings{}; + const std::string dummy_query = ""; + auto exec_query = execution::compiler::CompilationContext::Compile( + *plan, exec_settings, accessor_, execution::compiler::CompilationMode::OneShot, + common::ManagedPointer(&dummy_query), lambda_expr, codegen_->GetAstContext()); + auto fns = exec_query->GetFunctions(); + auto decls = exec_query->GetDecls(); + + aux_decls_.insert(aux_decls_.end(), decls.begin(), decls.end()); + + fb_->Append( + codegen_->DeclareVar(lam_var, codegen_->LambdaType(lambda_expr->GetFunctionLitExpr()->TypeRepr()), lambda_expr)); + + // make query state + auto query_state = codegen_->MakeFreshIdentifier("query_state"); + fb_->Append(codegen_->DeclareVarNoInit(query_state, codegen_->MakeExpr(exec_query->GetQueryStateType()->Name()))); + // set its execution context to whatever exec context was passed in here + fb_->Append(codegen_->CallBuiltin(execution::ast::Builtin::StartNewParams, {exec_ctx})); + std::vector>::iterator> sorted_vec; + for (auto it = query_params.begin(); it != query_params.end(); it++) { + sorted_vec.push_back(it); + } + + std::sort(sorted_vec.begin(), sorted_vec.end(), [](auto x, auto y) { return x->second < y->second; }); + for (auto entry : sorted_vec) { + // TODO(order these dudes) + type::TypeId type = type::TypeId::INVALID; + udf_ast_context_->GetVariableType(entry->first, &type); + // NOISEPAGE_ASSERT(ret, "didn't find param in udf ast context"); + + execution::ast::Builtin builtin; + switch (type) { + case type::TypeId::BOOLEAN: + builtin = execution::ast::Builtin::AddParamBool; + break; + case type::TypeId::TINYINT: + builtin = execution::ast::Builtin::AddParamTinyInt; + break; + case type::TypeId::SMALLINT: + builtin = execution::ast::Builtin::AddParamSmallInt; + break; + case type::TypeId::INTEGER: + builtin = execution::ast::Builtin::AddParamInt; + break; + case type::TypeId::BIGINT: + builtin = execution::ast::Builtin::AddParamBigInt; + break; + case type::TypeId::DECIMAL: + builtin = execution::ast::Builtin::AddParamDouble; + break; + case type::TypeId::DATE: + builtin = execution::ast::Builtin::AddParamDate; + break; + case type::TypeId::TIMESTAMP: + builtin = execution::ast::Builtin::AddParamTimestamp; + break; + case type::TypeId::VARCHAR: + builtin = execution::ast::Builtin::AddParamString; + break; + default: + UNREACHABLE("Unsupported parameter type"); + } + fb_->Append(codegen_->CallBuiltin(builtin, {exec_ctx, codegen_->MakeExpr(str_to_ident_[entry->first])})); + } + // set param 1 + // set param 2 + // etc etc + fb_->Append(codegen_->Assign( + codegen_->AccessStructMember(codegen_->MakeExpr(query_state), codegen_->MakeIdentifier("execCtx")), exec_ctx)); + // set its execution context to whatever exec context was passed in here + + for (auto &sub_fn : fns) { + // aux_decls_.push_back(c) + if (sub_fn.find("Run") != std::string::npos) { + fb_->Append(codegen_->Call(codegen_->GetAstContext()->GetIdentifier(sub_fn), + {codegen_->AddressOf(query_state), codegen_->MakeExpr(lam_var)})); + } else { + fb_->Append(codegen_->Call(codegen_->GetAstContext()->GetIdentifier(sub_fn), {codegen_->AddressOf(query_state)})); + } + } + + fb_->Append(codegen_->CallBuiltin(execution::ast::Builtin::FinishNewParams, {exec_ctx})); + + return; +} + +void UDFCodegen::Visit(RetStmtAST *ast) { + ast->expr->Accept(reinterpret_cast(this)); + auto ret_expr = dst_; + fb_->Append(codegen_->Return(ret_expr)); +} + +void UDFCodegen::Visit(SQLStmtAST *ast) { + needs_exec_ctx_ = true; + auto exec_ctx = fb_->GetParameterByPosition(0); + const auto query = common::ManagedPointer(ast->query); + + // TODO(Matt): I don't think the binder should need the database name. It's already bound in the ConnectionContext + binder::BindNodeVisitor visitor(common::ManagedPointer(accessor_), db_oid_); + // auto query_params = visitor.BindAndGetUDFParams(query, common::ManagedPointer(udf_ast_context_)); + auto query_params = ast->udf_params; + auto stats = optimizer::StatsStorage(); + + std::unique_ptr plan = trafficcop::TrafficCopUtil::Optimize( + accessor_->GetTxn(), common::ManagedPointer(accessor_), query, db_oid_, common::ManagedPointer(&stats), + std::make_unique(), 1000000); + // make lambda that just writes into this + + auto lam_var = codegen_->MakeFreshIdentifier("lamb"); + // NOISEPAGE_ASSERT(plan->GetOutputSchema()->GetColumns().size() == 1, "Can't support non scalars yet!"); + auto &cols = plan->GetOutputSchema()->GetColumns(); + // auto &col = cols[0]; + execution::util::RegionVector params(codegen_->GetAstContext()->GetRegion()); + std::vector assignees; + execution::util::RegionVector captures(codegen_->GetAstContext()->GetRegion()); + size_t i = 0; + params.push_back(codegen_->MakeField( + exec_ctx->As()->Name(), + codegen_->PointerType(codegen_->BuiltinType(execution::ast::BuiltinType::Kind::ExecutionContext)))); + for (auto &col : cols) { + execution::ast::Expr *capture_var = codegen_->MakeExpr(str_to_ident_.find(ast->var_name)->second); + type::TypeId udf_type; + udf_ast_context_->GetVariableType(ast->var_name, &udf_type); + if (udf_type == type::TypeId::INVALID) { + // record type + auto &struct_vars = udf_ast_context_->GetRecordType(ast->var_name); + if (captures.empty()) { + captures.push_back(capture_var); + } + capture_var = codegen_->AccessStructMember(capture_var, codegen_->MakeIdentifier(struct_vars[i].first)); + assignees.push_back(capture_var); + } else { + assignees.push_back(capture_var); + captures.push_back(capture_var); + } + // auto capture_var = str_to_ident_.find(ast->var_name)->second; + auto type = codegen_->TplType(execution::sql::GetTypeId(col.GetType())); + + auto input_param = codegen_->MakeFreshIdentifier("input"); + params.push_back(codegen_->MakeField(input_param, type)); + i++; + } + + execution::ast::LambdaExpr *lambda_expr; + FunctionBuilder fn(codegen_, std::move(params), codegen_->BuiltinType(execution::ast::BuiltinType::Nil)); + { + for (size_t j = 0; j < assignees.size(); j++) { + auto capture_var = assignees[j]; + auto input_param = fn.GetParameterByPosition(j + 1); + fn.Append(codegen_->Assign(capture_var, input_param)); + } + } + + lambda_expr = fn.FinishLambda(std::move(captures)); + lambda_expr->SetName(lam_var); + + // want to pass something down that will materialize the lambda function for me into lambda_expr and will + // also feed in a lambda_expr to the compiler + execution::exec::ExecutionSettings exec_settings{}; + const std::string dummy_query = ""; + auto exec_query = execution::compiler::CompilationContext::Compile( + *plan, exec_settings, accessor_, execution::compiler::CompilationMode::OneShot, + common::ManagedPointer(&dummy_query), lambda_expr, codegen_->GetAstContext()); + auto fns = exec_query->GetFunctions(); + auto decls = exec_query->GetDecls(); + + aux_decls_.insert(aux_decls_.end(), decls.begin(), decls.end()); + + fb_->Append( + codegen_->DeclareVar(lam_var, codegen_->LambdaType(lambda_expr->GetFunctionLitExpr()->TypeRepr()), lambda_expr)); + + // make query state + auto query_state = codegen_->MakeFreshIdentifier("query_state"); + fb_->Append(codegen_->DeclareVarNoInit(query_state, codegen_->MakeExpr(exec_query->GetQueryStateType()->Name()))); + // set its execution context to whatever exec context was passed in here + fb_->Append(codegen_->CallBuiltin(execution::ast::Builtin::StartNewParams, {exec_ctx})); + std::vector>::iterator> sorted_vec; + for (auto it = query_params.begin(); it != query_params.end(); it++) { + sorted_vec.push_back(it); + } + + std::sort(sorted_vec.begin(), sorted_vec.end(), [](auto x, auto y) { return x->second.second < y->second.second; }); + for (auto entry : sorted_vec) { + // TODO(order these dudes) + type::TypeId type = type::TypeId::INVALID; + execution::ast::Expr *expr = nullptr; + if (entry->second.first.length() > 0) { + auto &fields = udf_ast_context_->GetRecordType(entry->second.first); + auto it = std::find_if(fields.begin(), fields.end(), [=](auto p) { return p.first == entry->first; }); + type = it->second; + expr = codegen_->AccessStructMember(codegen_->MakeExpr(str_to_ident_[entry->second.first]), + codegen_->MakeIdentifier(entry->first)); + } else { + udf_ast_context_->GetVariableType(entry->first, &type); + expr = codegen_->MakeExpr(str_to_ident_[entry->first]); + } + + // NOISEPAGE_ASSERT(ret, "didn't find param in udf ast context"); + execution::ast::Builtin builtin; + switch (type) { + case type::TypeId::BOOLEAN: + builtin = execution::ast::Builtin::AddParamBool; + break; + case type::TypeId::TINYINT: + builtin = execution::ast::Builtin::AddParamTinyInt; + break; + case type::TypeId::SMALLINT: + builtin = execution::ast::Builtin::AddParamSmallInt; + break; + case type::TypeId::INTEGER: + builtin = execution::ast::Builtin::AddParamInt; + break; + case type::TypeId::BIGINT: + builtin = execution::ast::Builtin::AddParamBigInt; + break; + case type::TypeId::DECIMAL: + builtin = execution::ast::Builtin::AddParamDouble; + break; + case type::TypeId::DATE: + builtin = execution::ast::Builtin::AddParamDate; + break; + case type::TypeId::TIMESTAMP: + builtin = execution::ast::Builtin::AddParamTimestamp; + break; + case type::TypeId::VARCHAR: + builtin = execution::ast::Builtin::AddParamString; + break; + default: + UNREACHABLE("Unsupported parameter type"); + } + fb_->Append(codegen_->CallBuiltin(builtin, {exec_ctx, expr})); + } + // set param 1 + // set param 2 + // etc etc + fb_->Append(codegen_->Assign( + codegen_->AccessStructMember(codegen_->MakeExpr(query_state), codegen_->MakeIdentifier("execCtx")), exec_ctx)); + + for (auto &col : cols) { + execution::ast::Expr *capture_var = codegen_->MakeExpr(str_to_ident_.find(ast->var_name)->second); + auto lhs = capture_var; + if (cols.size() > 1) { + // record struct type + lhs = codegen_->AccessStructMember(capture_var, codegen_->MakeIdentifier(col.GetName())); + } + fb_->Append(codegen_->Assign(lhs, codegen_->ConstNull(col.GetType()))); + } + // set its execution context to whatever exec context was passed in here + + for (auto &sub_fn : fns) { + // aux_decls_.push_back(c) + if (sub_fn.find("Run") != std::string::npos) { + fb_->Append(codegen_->Call(codegen_->GetAstContext()->GetIdentifier(sub_fn), + {codegen_->AddressOf(query_state), codegen_->MakeExpr(lam_var)})); + } else { + fb_->Append(codegen_->Call(codegen_->GetAstContext()->GetIdentifier(sub_fn), {codegen_->AddressOf(query_state)})); + } + } + + fb_->Append(codegen_->CallBuiltin(execution::ast::Builtin::FinishNewParams, {exec_ctx})); + + return; +} + +void UDFCodegen::Visit(MemberExprAST *ast) { + ast->object->Accept(reinterpret_cast(this)); + auto object = dst_; + dst_ = codegen_->AccessStructMember(object, codegen_->MakeIdentifier(ast->field)); +} + +} // namespace udf +} // namespace compiler +} // namespace execution +} // namespace noisepage diff --git a/src/include/execution/ast/udf/udf_ast_context.h b/src/include/execution/ast/udf/udf_ast_context.h new file mode 100644 index 0000000000..6d9105b174 --- /dev/null +++ b/src/include/execution/ast/udf/udf_ast_context.h @@ -0,0 +1,53 @@ +#pragma once + +#include "type/type_id.h" + +// TODO(Kyle): Documentation. + +namespace noisepage { +namespace execution { +namespace ast { +namespace udf { + +class UDFASTContext { + public: + UDFASTContext() {} + + void SetVariableType(std::string &var, type::TypeId type) { symbol_table_[var] = type; } + + bool GetVariableType(const std::string &var, type::TypeId *type) { + auto it = symbol_table_.find(var); + if (it == symbol_table_.end()) { + return false; + } + if (type != nullptr) { + *type = it->second; + } + return true; + } + + void AddVariable(std::string name) { local_variables_.push_back(name); } + + const std::string &GetVariableAtIndex(int index) { + NOISEPAGE_ASSERT(local_variables_.size() >= index, "Bad var"); + return local_variables_[index - 1]; + } + + void SetRecordType(std::string var, std::vector> &&elems) { + record_types_[var] = std::move(elems); + } + + const std::vector> &GetRecordType(const std::string &var) { + return record_types_.find(var)->second; + } + + private: + std::unordered_map symbol_table_; + std::vector local_variables_; + std::unordered_map>> record_types_; +}; + +} // namespace udf +} // namespace ast +} // namespace execution +} // namespace noisepage diff --git a/src/include/execution/ast/udf/udf_ast_node_visitor.h b/src/include/execution/ast/udf/udf_ast_node_visitor.h new file mode 100644 index 0000000000..9966d39c61 --- /dev/null +++ b/src/include/execution/ast/udf/udf_ast_node_visitor.h @@ -0,0 +1,57 @@ +#pragma once + +// TODO(Kyle): This whole file needs documentation. + +namespace noisepage { +namespace execution { +namespace ast { +namespace udf { + +class AbstractAST; +class StmtAST; +class ExprAST; +class ValueExprAST; +class IsNullExprAST; +class VariableExprAST; +class BinaryExprAST; +class CallExprAST; +class MemberExprAST; +class SeqStmtAST; +class DeclStmtAST; +class IfStmtAST; +class WhileStmtAST; +class RetStmtAST; +class AssignStmtAST; +class SQLStmtAST; +class DynamicSQLStmtAST; +class ForStmtAST; +class FunctionAST; + +class ASTNodeVisitor { + public: + virtual ~ASTNodeVisitor(){}; + + virtual void Visit(AbstractAST *){}; + virtual void Visit(StmtAST *){}; + virtual void Visit(ExprAST *){}; + virtual void Visit(FunctionAST *){}; + virtual void Visit(ValueExprAST *){}; + virtual void Visit(VariableExprAST *){}; + virtual void Visit(BinaryExprAST *){}; + virtual void Visit(IsNullExprAST *){}; + virtual void Visit(CallExprAST *){}; + virtual void Visit(MemberExprAST *){}; + virtual void Visit(SeqStmtAST *){}; + virtual void Visit(DeclStmtAST *){}; + virtual void Visit(IfStmtAST *){}; + virtual void Visit(WhileStmtAST *){}; + virtual void Visit(RetStmtAST *){}; + virtual void Visit(AssignStmtAST *){}; + virtual void Visit(ForStmtAST *){}; + virtual void Visit(SQLStmtAST *){}; + virtual void Visit(DynamicSQLStmtAST *){}; +}; +} // namespace udf +} // namespace ast +} // namespace execution +} // namespace noisepage diff --git a/src/include/execution/ast/udf/udf_ast_nodes.h b/src/include/execution/ast/udf/udf_ast_nodes.h new file mode 100644 index 0000000000..52fd93e87e --- /dev/null +++ b/src/include/execution/ast/udf/udf_ast_nodes.h @@ -0,0 +1,237 @@ +#pragma once + +#include "parser/expression/constant_value_expression.h" +#include "parser/expression_defs.h" +#include "type/type_id.h" + +#include "execution/ast/udf/udf_ast_node_visitor.h" +#include "execution/sql/value.h" + +namespace noisepage { +namespace execution { +namespace ast { +namespace udf { + +// AbstractAST - Base class for all AST nodes. +class AbstractAST { + public: + virtual ~AbstractAST() = default; + + virtual void Accept(ASTNodeVisitor *visitor) { visitor->Visit(this); }; +}; + +// StmtAST - Base class for all statement nodes. +class StmtAST : public AbstractAST { + public: + virtual ~StmtAST() = default; + + void Accept(ASTNodeVisitor *visitor) override { visitor->Visit(this); }; +}; + +// ExprAST - Base class for all expression nodes. +class ExprAST : public StmtAST { + public: + virtual ~ExprAST() = default; + + void Accept(ASTNodeVisitor *visitor) override { visitor->Visit(this); }; +}; + +// DoubleExprAST - Expression class for numeric literals like "1.1". +class ValueExprAST : public ExprAST { + public: + std::unique_ptr value_; + + ValueExprAST(std::unique_ptr value) : value_(std::move(value)) {} + + void Accept(ASTNodeVisitor *visitor) override { visitor->Visit(this); }; +}; + +class IsNullExprAST : public ExprAST { + public: + bool is_null_check_; + std::unique_ptr child_; + + IsNullExprAST(bool is_null_check, std::unique_ptr child) + : is_null_check_(is_null_check), child_(std::move(child)) {} + + void Accept(ASTNodeVisitor *visitor) override { visitor->Visit(this); }; +}; + +// VariableExprAST - Expression class for referencing a variable, like "a". +class VariableExprAST : public ExprAST { + public: + std::string name; + + VariableExprAST(const std::string &name) : name(name) {} + + void Accept(ASTNodeVisitor *visitor) override { visitor->Visit(this); }; +}; + +// VariableExprAST - Expression class for referencing a variable, like "a". +class MemberExprAST : public ExprAST { + public: + std::unique_ptr object; + std::string field; + + MemberExprAST(std::unique_ptr &&object, std::string field) + : object(std::move(object)), field(field) {} + + void Accept(ASTNodeVisitor *visitor) override { visitor->Visit(this); }; +}; + +// BinaryExprAST - Expression class for a binary operator. +class BinaryExprAST : public ExprAST { + public: + parser::ExpressionType op; + std::unique_ptr lhs, rhs; + + BinaryExprAST(parser::ExpressionType op, std::unique_ptr lhs, std::unique_ptr rhs) + : op(op), lhs(std::move(lhs)), rhs(std::move(rhs)) {} + + void Accept(ASTNodeVisitor *visitor) override { visitor->Visit(this); }; +}; + +// CallExprAST - Expression class for function calls. +class CallExprAST : public ExprAST { + public: + std::string callee; + std::vector> args; + + CallExprAST(const std::string &callee, std::vector> args) + : callee(callee), args(std::move(args)) {} + + void Accept(ASTNodeVisitor *visitor) override { visitor->Visit(this); }; +}; + +// SeqStmtAST - Statement class for sequence of statements +class SeqStmtAST : public StmtAST { + public: + std::vector> stmts; + + SeqStmtAST(std::vector> stmts) : stmts(std::move(stmts)) {} + + void Accept(ASTNodeVisitor *visitor) override { visitor->Visit(this); }; +}; + +// DeclStmtAST - Statement class for sequence of statements +class DeclStmtAST : public StmtAST { + public: + std::string name; + type::TypeId type; + std::unique_ptr initial; + + DeclStmtAST(std::string name, type::TypeId type, std::unique_ptr initial) + : name(std::move(name)), type(std::move(type)), initial(std::move(initial)) {} + + void Accept(ASTNodeVisitor *visitor) override { visitor->Visit(this); }; +}; + +// IfStmtAST - Statement class for if/then/else. +class IfStmtAST : public StmtAST { + public: + std::unique_ptr cond_expr; + std::unique_ptr then_stmt, else_stmt; + + void Accept(ASTNodeVisitor *visitor) override { visitor->Visit(this); }; + + IfStmtAST(std::unique_ptr cond_expr, std::unique_ptr then_stmt, std::unique_ptr else_stmt) + : cond_expr(std::move(cond_expr)), then_stmt(std::move(then_stmt)), else_stmt(std::move(else_stmt)) {} +}; + +class ForStmtAST : public StmtAST { + public: + std::vector vars_; + std::unique_ptr query_; + std::unique_ptr body_stmt_; + + void Accept(ASTNodeVisitor *visitor) override { visitor->Visit(this); }; + + ForStmtAST(std::vector &&vars_vec, std::unique_ptr query, + std::unique_ptr body_stmt) + : vars_(std::move(vars_vec)), query_(std::move(query)), body_stmt_(std::move(body_stmt)) {} +}; + +// WhileAST - Statement class for while loop +class WhileStmtAST : public StmtAST { + public: + std::unique_ptr cond_expr; + std::unique_ptr body_stmt; + + void Accept(ASTNodeVisitor *visitor) override { visitor->Visit(this); }; + + WhileStmtAST(std::unique_ptr cond_expr, std::unique_ptr body_stmt) + : cond_expr(std::move(cond_expr)), body_stmt(std::move(body_stmt)) {} +}; + +// RetStmtAST - Statement class for sequence of statements +class RetStmtAST : public StmtAST { + public: + std::unique_ptr expr; + + RetStmtAST(std::unique_ptr expr) : expr(std::move(expr)) {} + + void Accept(ASTNodeVisitor *visitor) override { visitor->Visit(this); }; +}; + +// AssignStmtAST - Expression class for a binary operator. +class AssignStmtAST : public ExprAST { + public: + std::unique_ptr lhs; + std::unique_ptr rhs; + + AssignStmtAST(std::unique_ptr lhs, std::unique_ptr rhs) + : lhs(std::move(lhs)), rhs(std::move(rhs)) {} + + void Accept(ASTNodeVisitor *visitor) override { visitor->Visit(this); }; +}; + +// SQLStmtAST - Expression class for a SQL Statement. +class SQLStmtAST : public StmtAST { + public: + std::unique_ptr query; + std::string var_name; + std::unordered_map> udf_params; + + SQLStmtAST(std::unique_ptr query, std::string var_name, + std::unordered_map> &&udf_params) + : query(std::move(query)), var_name(std::move(var_name)), udf_params(std::move(udf_params)) {} + + void Accept(ASTNodeVisitor *visitor) override { visitor->Visit(this); }; +}; + +// DynamicSQLStmtAST - Expression class for a SQL Statement. +class DynamicSQLStmtAST : public StmtAST { + public: + std::unique_ptr query; + std::string var_name; + + DynamicSQLStmtAST(std::unique_ptr query, std::string var_name) + : query(std::move(query)), var_name(std::move(var_name)) {} + + void Accept(ASTNodeVisitor *visitor) override { visitor->Visit(this); }; +}; + +// FunctionAST - This class represents a function definition itself. +class FunctionAST : public AbstractAST { + public: + std::unique_ptr body; + std::vector param_names_; + std::vector param_types_; + + FunctionAST(std::unique_ptr body, std::vector &¶m_names, + std::vector &¶m_types) + : body(std::move(body)), param_names_(std::move(param_names)), param_types_(std::move(param_types)) {} + + void Accept(ASTNodeVisitor *visitor) override { visitor->Visit(this); }; +}; + +/*---------------------------------------------------------------- +/// Error* - These are little helper functions for error handling. +-----------------------------------------------------------------*/ + +std::unique_ptr LogError(const char *str); + +} // namespace udf +} // namespace ast +} // namespace execution +} // namespace noisepage diff --git a/src/include/execution/compiler/udf/udf_codegen.h b/src/include/execution/compiler/udf/udf_codegen.h new file mode 100644 index 0000000000..3ec7888a41 --- /dev/null +++ b/src/include/execution/compiler/udf/udf_codegen.h @@ -0,0 +1,101 @@ +#pragma once + +#include "execution/ast/udf/udf_ast_context.h" +#include "execution/ast/udf/udf_ast_node_visitor.h" +#include "execution/compiler/codegen.h" +#include "execution/compiler/function_builder.h" +#include "execution/functions/function_context.h" + +// TODO(Kyle): Documentation. + +namespace noisepage::catalog { +class CatalogAccessor; +} + +namespace noisepage { +namespace execution { +namespace compiler { +namespace udf { + +// TODO(Kyle): Is distinguishing the standard codegen +// namespace stuff from the UDF stuff here going to be +// an issue (i.e. disambiguation)? + +class AbstractAST; +class StmtAST; +class ExprAST; +class ValueExprAST; +class VariableExprAST; +class BinaryExprAST; +class CallExprAST; +class MemberExprAST; +class SeqStmtAST; +class DeclStmtAST; +class IfStmtAST; +class WhileStmtAST; +class RetStmtAST; +class AssignStmtAST; +class SQLStmtAST; +class DynamicSQLStmtAST; +class ForStmtAST; + +class UDFCodegen : ASTNodeVisitor { + public: + UDFCodegen(catalog::CatalogAccessor *accessor, FunctionBuilder *fb, parser::udf::UDFASTContext *udf_ast_context, + CodeGen *codegen, catalog::db_oid_t db_oid); + ~UDFCodegen(){}; + + catalog::type_oid_t GetCatalogTypeOidFromSQLType(execution::ast::BuiltinType::Kind type); + + void GenerateUDF(AbstractAST *); + void Visit(AbstractAST *) override; + void Visit(FunctionAST *) override; + void Visit(StmtAST *) override; + void Visit(ExprAST *) override; + void Visit(ValueExprAST *) override; + void Visit(VariableExprAST *) override; + void Visit(BinaryExprAST *) override; + void Visit(CallExprAST *) override; + void Visit(IsNullExprAST *) override; + void Visit(SeqStmtAST *) override; + void Visit(DeclStmtAST *) override; + void Visit(IfStmtAST *) override; + void Visit(WhileStmtAST *) override; + void Visit(RetStmtAST *) override; + void Visit(AssignStmtAST *) override; + void Visit(SQLStmtAST *) override; + void Visit(DynamicSQLStmtAST *) override; + void Visit(ForStmtAST *) override; + void Visit(MemberExprAST *) override; + + execution::ast::File *Finish() { + auto fn = fb_->Finish(); + //// util::RegionVector decls_reg_vec{decls->begin(), decls->end(), codegen.Region()}; + execution::util::RegionVector decls({fn}, codegen_->GetAstContext()->GetRegion()); + // for(auto decl : aux_decls_){ + // decls.push_back(decl); + // } + decls.insert(decls.begin(), aux_decls_.begin(), aux_decls_.end()); + auto file = codegen_->GetAstContext()->GetNodeFactory()->NewFile({0, 0}, std::move(decls)); + return file; + } + + static const char *GetReturnParamString(); + + private: + catalog::CatalogAccessor *accessor_; + FunctionBuilder *fb_; + UDFASTContext *udf_ast_context_; + CodeGen *codegen_; + type::TypeId current_type_{type::TypeId::INVALID}; + execution::ast::Expr *dst_; + std::unordered_map str_to_ident_; + execution::util::RegionVector aux_decls_; + catalog::db_oid_t db_oid_; + bool needs_exec_ctx_{false}; +}; + +} // namespace udf +} // namespace compiler +} // namespace execution +} // namespace noisepage diff --git a/src/include/parser/udf/udf_parser.h b/src/include/parser/udf/udf_parser.h new file mode 100644 index 0000000000..1d441b6909 --- /dev/null +++ b/src/include/parser/udf/udf_parser.h @@ -0,0 +1,51 @@ +#pragma once + +#include +#include + +#include "ast_nodes.h" +#include "catalog/catalog_accessor.h" + +#include "parser/expression_util.h" +#include "parser/postgresparser.h" +#include "parser/udf/udf_ast_context.h" + +// TODO(Kyle): Do we want to place UDF parsing in its own namespace? +namespace noisepage { +namespace parser { +namespace udf { + +class FunctionAST; +class PLpgSQLParser { + public: + PLpgSQLParser(common::ManagedPointer udf_ast_context, + const common::ManagedPointer accessor, catalog::db_oid_t db_oid) + : udf_ast_context_(udf_ast_context), accessor_(accessor), db_oid_(db_oid) {} + std::unique_ptr ParsePLpgSQL(std::vector &¶m_names, + std::vector &¶m_types, const std::string &func_body, + common::ManagedPointer ast_context); + + private: + std::unique_ptr ParseBlock(const nlohmann::json &block); + std::unique_ptr ParseFunction(const nlohmann::json &block); + std::unique_ptr ParseDecl(const nlohmann::json &decl); + std::unique_ptr ParseIf(const nlohmann::json &branch); + std::unique_ptr ParseWhile(const nlohmann::json &loop); + std::unique_ptr ParseFor(const nlohmann::json &loop); + std::unique_ptr ParseSQL(const nlohmann::json &sql_stmt); + std::unique_ptr ParseDynamicSQL(const nlohmann::json &sql_stmt); + // Feed the expression (as a sql string) to our parser then transform the + // noisepage expression into ast node + std::unique_ptr ParseExprSQL(const std::string expr_sql_str); + std::unique_ptr ParseExpr(common::ManagedPointer); + + common::ManagedPointer udf_ast_context_; + const common::ManagedPointer accessor_; + catalog::db_oid_t db_oid_; + // common::ManagedPointer sql_parser_; + std::unordered_map symbol_table_; +}; + +} // namespace udf +} // namespace parser +} // namespace noisepage diff --git a/src/parser/udf/udf_parser.cpp b/src/parser/udf/udf_parser.cpp new file mode 100644 index 0000000000..e904d0c6d3 --- /dev/null +++ b/src/parser/udf/udf_parser.cpp @@ -0,0 +1,347 @@ +#include + +#include "binder/bind_node_visitor.h" +#include "loggers/parser_logger.h" +#include "parser/udf/udf_parser.h" + +#include "libpg_query/pg_query.h" +#include "nlohmann/json.hpp" + +// TODO(Kyle): This whole file needs documentation... + +// TODO(Kyle): Do we want to put UDF parsing in its own namespace? +namespace noisepage { +namespace parser { +namespace udf { +using namespace nlohmann; + +// TODO(Kyle): constexpr +// TODO(Kyle): Define elsewhere? +const std::string kFunctionList = "FunctionList"; +const std::string kDatums = "datums"; +const std::string kPLpgSQL_var = "PLpgSQL_var"; +const std::string kRefname = "refname"; +const std::string kDatatype = "datatype"; +const std::string kDefaultVal = "default_val"; +const std::string kPLpgSQL_type = "PLpgSQL_type"; +const std::string kTypname = "typname"; +const std::string kAction = "action"; +const std::string kPLpgSQL_function = "PLpgSQL_function"; +const std::string kBody = "body"; +const std::string kPLpgSQL_stmt_block = "PLpgSQL_stmt_block"; +const std::string kPLpgSQL_stmt_return = "PLpgSQL_stmt_return"; +const std::string kPLpgSQL_stmt_if = "PLpgSQL_stmt_if"; +const std::string kPLpgSQL_stmt_while = "PLpgSQL_stmt_while"; +const std::string kPLpgSQL_stmt_fors = "PLpgSQL_stmt_fors"; +const std::string kCond = "cond"; +const std::string kThenBody = "then_body"; +const std::string kElseBody = "else_body"; +const std::string kExpr = "expr"; +const std::string kQuery = "query"; +const std::string kPLpgSQL_expr = "PLpgSQL_expr"; +const std::string kPLpgSQL_stmt_assign = "PLpgSQL_stmt_assign"; +const std::string kVarno = "varno"; +const std::string kPLpgSQL_stmt_execsql = "PLpgSQL_stmt_execsql"; +const std::string kSqlstmt = "sqlstmt"; +const std::string kRow = "row"; +const std::string kFields = "fields"; +const std::string kName = "name"; +const std::string kPLpgSQL_row = "PLpgSQL_row"; +const std::string kPLpgSQL_stmt_dynexecute = "PLpgSQL_stmt_dynexecute"; + +std::unique_ptr PLpgSQLParser::ParsePLpgSQL(std::vector &¶m_names, + std::vector &¶m_types, + const std::string &func_body, + common::ManagedPointer ast_context) { + auto result = pg_query_parse_plpgsql(func_body.c_str()); + if (result.error) { + PARSER_LOG_INFO("PL/pgSQL parse error : {}", result.error->message); + pg_query_free_plpgsql_parse_result(result); + throw PARSER_EXCEPTION("PL/pgSQL parsing error"); + } + // The result is a list, we need to wrap it + std::string ast_json_str = "{ \"" + kFunctionList + "\" : " + std::string(result.plpgsql_funcs) + " }"; + // LOG_DEBUG("Compiling JSON formatted function %s", ast_json_str.c_str()); + pg_query_free_plpgsql_parse_result(result); + + std::cout << ast_json_str << "\n"; + + std::istringstream ss(ast_json_str); + json ast_json; + ss >> ast_json; + const auto function_list = ast_json[kFunctionList]; + NOISEPAGE_ASSERT(function_list.is_array(), "Function list is not an array"); + if (function_list.size() != 1) { + PARSER_LOG_DEBUG("PL/pgSQL error : Function list size %u", function_list.size()); + throw PARSER_EXCEPTION("Function list has size other than 1"); + } + + size_t i = 0; + for (auto udf_name : param_names) { + // udf_ast_context_->AddVariable(udf_name); + udf_ast_context_->SetVariableType(udf_name, param_types[i++]); + } + const auto function = function_list[0][kPLpgSQL_function]; + std::unique_ptr function_ast( + new FunctionAST(ParseFunction(function), std::move(param_names), std::move(param_types))); + return function_ast; +} + +std::unique_ptr PLpgSQLParser::ParseFunction(const nlohmann::json &block) { + const auto decl_list = block[kDatums]; + const auto function_body = block[kAction][kPLpgSQL_stmt_block][kBody]; + + std::vector> stmts; + + PARSER_LOG_DEBUG("Parsing Declarations"); + NOISEPAGE_ASSERT(decl_list.is_array(), "Declaration list is not an array"); + for (uint32_t i = 1; i < decl_list.size(); i++) { + stmts.push_back(ParseDecl(decl_list[i])); + } + + stmts.push_back(ParseBlock(function_body)); + + std::unique_ptr seq_stmt_ast(new SeqStmtAST(std::move(stmts))); + return std::move(seq_stmt_ast); +} + +std::unique_ptr PLpgSQLParser::ParseBlock(const nlohmann::json &block) { + // TODO(boweic): Support statements size other than 1 + NOISEPAGE_ASSERT(block.is_array(), "Block isn't array"); + if (block.size() == 0) { + throw PARSER_EXCEPTION("PL/pgSQL parser : Empty block is not supported"); + } + + std::vector> stmts; + + for (uint32_t i = 0; i < block.size(); i++) { + const auto stmt = block[i]; + const auto stmt_names = stmt.items().begin(); + // NOISEPAGE_ASSERT(stmt_names->size() == 1, "Bad statement size"); + PARSER_LOG_DEBUG("Statement : {}", stmt_names.key()); + + if (stmt_names.key() == kPLpgSQL_stmt_return) { + auto expr = ParseExprSQL(stmt[kPLpgSQL_stmt_return][kExpr][kPLpgSQL_expr][kQuery].get()); + // TODO(boweic): Handle return stmt w/o expression + std::unique_ptr ret_stmt_ast(new RetStmtAST(std::move(expr))); + stmts.push_back(std::move(ret_stmt_ast)); + } else if (stmt_names.key() == kPLpgSQL_stmt_if) { + stmts.push_back(ParseIf(stmt[kPLpgSQL_stmt_if])); + } else if (stmt_names.key() == kPLpgSQL_stmt_assign) { + // TODO[Siva]: Need to fix Assignment expression / statement + const std::string &var_name = + udf_ast_context_->GetVariableAtIndex(stmt[kPLpgSQL_stmt_assign][kVarno].get()); + std::unique_ptr lhs(new VariableExprAST(var_name)); + auto rhs = ParseExprSQL(stmt[kPLpgSQL_stmt_assign][kExpr][kPLpgSQL_expr][kQuery].get()); + std::unique_ptr ass_expr_ast(new AssignStmtAST(std::move(lhs), std::move(rhs))); + stmts.push_back(std::move(ass_expr_ast)); + } else if (stmt_names.key() == kPLpgSQL_stmt_while) { + stmts.push_back(ParseWhile(stmt[kPLpgSQL_stmt_while])); + } else if (stmt_names.key() == kPLpgSQL_stmt_fors) { + stmts.push_back(ParseFor(stmt[kPLpgSQL_stmt_fors])); + } else if (stmt_names.key() == kPLpgSQL_stmt_execsql) { + stmts.push_back(ParseSQL(stmt[kPLpgSQL_stmt_execsql])); + } else if (stmt_names.key() == kPLpgSQL_stmt_dynexecute) { + stmts.push_back(ParseDynamicSQL(stmt[kPLpgSQL_stmt_dynexecute])); + } else { + throw PARSER_EXCEPTION("Statement type not supported"); + } + NOISEPAGE_ASSERT(stmts.back() != nullptr, "It broke"); + } + + std::unique_ptr seq_stmt_ast(new SeqStmtAST(std::move(stmts))); + return std::move(seq_stmt_ast); +} + +std::unique_ptr PLpgSQLParser::ParseDecl(const nlohmann::json &decl) { + const auto &decl_names = decl.items().begin(); + for (auto &it : decl.items()) { + std::cout << it.key() << " : " << it.value() << "\n"; + } + // NOISEPAGE_ASSERT(decl_names->size() >= 1, "Bad declaration names membership size"); + PARSER_LOG_DEBUG("Declaration : {}", decl_names.key()); + + if (decl_names.key() == kPLpgSQL_var) { + auto var_name = decl[kPLpgSQL_var][kRefname].get(); + udf_ast_context_->AddVariable(var_name); + auto type = decl[kPLpgSQL_var][kDatatype][kPLpgSQL_type][kTypname].get(); + std::unique_ptr initial = nullptr; + if (decl[kPLpgSQL_var].find(kDefaultVal) != decl[kPLpgSQL_var].end()) { + initial = ParseExprSQL(decl[kPLpgSQL_var][kDefaultVal][kPLpgSQL_expr][kQuery].get()); + } + + PARSER_LOG_INFO("Registering type {0}: {1}", var_name.c_str(), type.c_str()); + + type::TypeId temp_type; + if (udf_ast_context_->GetVariableType(var_name, &temp_type)) { + return std::unique_ptr(new DeclStmtAST(var_name, temp_type, std::move(initial))); + } + + if ((type.find("integer") != std::string::npos) || type.find("INTEGER") != std::string::npos) { + udf_ast_context_->SetVariableType(var_name, type::TypeId::INTEGER); + return std::unique_ptr(new DeclStmtAST(var_name, type::TypeId::INTEGER, std::move(initial))); + } else if (type == "double" || type.rfind("numeric") == 0) { + udf_ast_context_->SetVariableType(var_name, type::TypeId::DECIMAL); + return std::unique_ptr(new DeclStmtAST(var_name, type::TypeId::DECIMAL, std::move(initial))); + } else if (type == "varchar") { + udf_ast_context_->SetVariableType(var_name, type::TypeId::VARCHAR); + return std::unique_ptr(new DeclStmtAST(var_name, type::TypeId::VARCHAR, std::move(initial))); + } else if (type.find("date") != std::string::npos) { + udf_ast_context_->SetVariableType(var_name, type::TypeId::DATE); + return std::unique_ptr(new DeclStmtAST(var_name, type::TypeId::DATE, std::move(initial))); + } else if (type == "record") { + udf_ast_context_->SetVariableType(var_name, type::TypeId::INVALID); + return std::unique_ptr(new DeclStmtAST(var_name, type::TypeId::INVALID, std::move(initial))); + } else { + NOISEPAGE_ASSERT(false, "Unsupported "); + // udf_ast_context_->SetVariableType(var_name, type::TypeId::INVALID); + // return std::unique_ptr( + // new DeclStmtAST(var_name, type::TypeId::INVALID)); + } + } else if (decl_names.key() == kPLpgSQL_row) { + auto var_name = decl[kPLpgSQL_row][kRefname].get(); + NOISEPAGE_ASSERT(var_name == "*internal*", "Unexpected refname"); + // TODO[Siva]: Support row types later + udf_ast_context_->SetVariableType(var_name, type::TypeId::INVALID); + return std::unique_ptr(new DeclStmtAST(var_name, type::TypeId::INVALID, nullptr)); + } + // TODO[Siva]: need to handle other types like row, table etc; + throw PARSER_EXCEPTION("Declaration type not supported"); +} + +std::unique_ptr PLpgSQLParser::ParseIf(const nlohmann::json &branch) { + PARSER_LOG_DEBUG("ParseIf"); + auto cond_expr = ParseExprSQL(branch[kCond][kPLpgSQL_expr][kQuery].get()); + auto then_stmt = ParseBlock(branch[kThenBody]); + std::unique_ptr else_stmt = nullptr; + if (branch.find(kElseBody) != branch.end()) { + else_stmt = ParseBlock(branch[kElseBody]); + } + return std::unique_ptr(new IfStmtAST(std::move(cond_expr), std::move(then_stmt), std::move(else_stmt))); +} + +std::unique_ptr PLpgSQLParser::ParseWhile(const nlohmann::json &loop) { + PARSER_LOG_DEBUG("ParseWhile"); + auto cond_expr = ParseExprSQL(loop[kCond][kPLpgSQL_expr][kQuery].get()); + auto body_stmt = ParseBlock(loop[kBody]); + return std::unique_ptr(new WhileStmtAST(std::move(cond_expr), std::move(body_stmt))); +} + +std::unique_ptr PLpgSQLParser::ParseFor(const nlohmann::json &loop) { + PARSER_LOG_DEBUG("ParseFor"); + auto sql_query = loop[kQuery][kPLpgSQL_expr][kQuery].get(); + ; + auto parse_result = PostgresParser::BuildParseTree(sql_query.c_str()); + if (parse_result == nullptr) { + PARSER_LOG_DEBUG("Bad SQL statement"); + return nullptr; + } + auto body_stmt = ParseBlock(loop[kBody]); + auto var_array = loop[kRow][kPLpgSQL_row][kFields]; + std::vector var_vec; + for (auto var : var_array) { + var_vec.push_back(var[kName].get()); + } + return std::unique_ptr(new ForStmtAST(std::move(var_vec), std::move(parse_result), std::move(body_stmt))); +} + +std::unique_ptr PLpgSQLParser::ParseSQL(const nlohmann::json &sql_stmt) { + PARSER_LOG_DEBUG("ParseSQL"); + auto sql_query = sql_stmt[kSqlstmt][kPLpgSQL_expr][kQuery].get(); + auto var_name = sql_stmt[kRow][kPLpgSQL_row][kFields][0][kName].get(); + auto parse_result = PostgresParser::BuildParseTree(sql_query.c_str()); + if (parse_result == nullptr) { + PARSER_LOG_DEBUG("Bad SQL statement"); + return nullptr; + } + binder::BindNodeVisitor visitor(accessor_, db_oid_); + + std::unordered_map> query_params; + + try { + // TODO(Matt): I don't think the binder should need the database name. It's already bound in the ConnectionContext + // binder::BindNodeVisitor visitor(accessor_, db_oid_); + query_params = visitor.BindAndGetUDFParams(common::ManagedPointer(parse_result), udf_ast_context_); + } catch (BinderException &b) { + PARSER_LOG_DEBUG("Bad SQL statement"); + return nullptr; + } + + // check to see if a record type can be bound to this + type::TypeId type; + auto ret = udf_ast_context_->GetVariableType(var_name, &type); + if (!ret) { + throw PARSER_EXCEPTION("PL/pgSQL parser : Didn't declare variable"); + } + if (type == type::TypeId::INVALID) { + std::vector> elems; + auto sel = parse_result->GetStatement(0).CastManagedPointerTo()->GetSelectColumns(); + for (auto col : sel) { + elems.emplace_back(col->GetAliasName(), col->GetReturnValueType()); + } + udf_ast_context_->SetRecordType(var_name, std::move(elems)); + } + + return std::unique_ptr( + new SQLStmtAST(std::move(parse_result), std::move(var_name), std::move(query_params))); + // return nullptr; +} + +std::unique_ptr PLpgSQLParser::ParseDynamicSQL(const nlohmann::json &sql_stmt) { + PARSER_LOG_DEBUG("ParseDynamicSQL"); + auto sql_expr = ParseExprSQL(sql_stmt[kQuery][kPLpgSQL_expr][kQuery].get()); + auto var_name = sql_stmt[kRow][kPLpgSQL_row][kFields][0][kName].get(); + return std::unique_ptr(new DynamicSQLStmtAST(std::move(sql_expr), std::move(var_name))); +} + +std::unique_ptr PLpgSQLParser::ParseExprSQL(const std::string expr_sql_str) { + PARSER_LOG_DEBUG("Parsing Expr SQL : {}", expr_sql_str.c_str()); + auto stmt_list = PostgresParser::BuildParseTree(expr_sql_str.c_str()); + if (stmt_list == nullptr) { + return nullptr; + } + NOISEPAGE_ASSERT(stmt_list->GetStatements().size() == 1, "Bad number of statements"); + auto stmt = stmt_list->GetStatement(0); + NOISEPAGE_ASSERT(stmt->GetType() == parser::StatementType::SELECT, "Unsupported statement type"); + NOISEPAGE_ASSERT(stmt.CastManagedPointerTo()->GetSelectTable() == nullptr, + "Unsupported SQL Expr in UDF"); + auto &select_list = stmt.CastManagedPointerTo()->GetSelectColumns(); + NOISEPAGE_ASSERT(select_list.size() == 1, "Unsupported number of select columns in udf"); + return PLpgSQLParser::ParseExpr(select_list[0]); +} + +std::unique_ptr PLpgSQLParser::ParseExpr(common::ManagedPointer expr) { + if (expr->GetExpressionType() == parser::ExpressionType::COLUMN_VALUE) { + auto cve = expr.CastManagedPointerTo(); + if (cve->GetTableName().empty()) { + return std::unique_ptr(new VariableExprAST(cve->GetColumnName())); + } else { + auto vexpr = std::unique_ptr(new VariableExprAST(cve->GetTableName())); + return std::unique_ptr(new MemberExprAST(std::move(vexpr), cve->GetColumnName())); + } + } else if ((parser::ExpressionUtil::IsOperatorExpression(expr->GetExpressionType()) && + expr->GetChildrenSize() == 2) || + (parser::ExpressionUtil::IsComparisonExpression(expr->GetExpressionType()))) { + return std::unique_ptr( + new BinaryExprAST(expr->GetExpressionType(), ParseExpr(expr->GetChild(0)), ParseExpr(expr->GetChild(1)))); + } else if (expr->GetExpressionType() == parser::ExpressionType::FUNCTION) { + auto func_expr = expr.CastManagedPointerTo(); + std::vector> args; + auto num_args = func_expr->GetChildrenSize(); + for (size_t idx = 0; idx < num_args; ++idx) { + args.push_back(ParseExpr(func_expr->GetChild(idx))); + } + return std::unique_ptr(new CallExprAST(func_expr->GetFuncName(), std::move(args))); + } else if (expr->GetExpressionType() == parser::ExpressionType::VALUE_CONSTANT) { + return std::unique_ptr(new ValueExprAST(expr->Copy())); + } else if (expr->GetExpressionType() == parser::ExpressionType::OPERATOR_IS_NOT_NULL) { + return std::unique_ptr(new IsNullExprAST(false, ParseExpr(expr->GetChild(0)))); + } else if (expr->GetExpressionType() == parser::ExpressionType::OPERATOR_IS_NULL) { + return std::unique_ptr(new IsNullExprAST(true, ParseExpr(expr->GetChild(0)))); + } + throw PARSER_EXCEPTION("PL/pgSQL parser : Expression type not supported"); + return nullptr; +} +} // namespace udf +} // namespace parser +} // namespace noisepage \ No newline at end of file From ff6be66d7423644a6e22cc8548bd20c601388955 Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Sat, 3 Apr 2021 11:49:13 -0400 Subject: [PATCH 008/139] reorder --- src/execution/compiler/udf/udf_codegen.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/execution/compiler/udf/udf_codegen.cpp b/src/execution/compiler/udf/udf_codegen.cpp index 35838df5dc..64891201da 100644 --- a/src/execution/compiler/udf/udf_codegen.cpp +++ b/src/execution/compiler/udf/udf_codegen.cpp @@ -14,8 +14,8 @@ #include "execution/exec/execution_settings.h" #include "catalog/catalog_accessor.h" -#include "optimizer/cost_model/trivial_cost_model.h" #include "optimizer/statistics/stats_storage.h" +#include "optimizer/cost_model/trivial_cost_model.h" #include "traffic_cop/traffic_cop_util.h" From 75839c5cbd8b9e3c56450037c5c6923211a31de9 Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Sat, 3 Apr 2021 17:16:34 -0400 Subject: [PATCH 009/139] pull in ast for udfs, messy and unimplemented in some places because of missing deps, but everything compiles and links thus far --- CMakeLists.txt | 2 +- src/catalog/catalog_accessor.cpp | 5 + src/catalog/database_catalog.cpp | 5 + src/execution/ast/ast.cpp | 8 +- src/execution/ast/ast_clone.cpp | 265 +++++++ src/execution/ast/ast_dump.cpp | 17 + src/execution/ast/ast_pretty_print.cpp | 13 + src/execution/ast/context.cpp | 34 +- src/execution/ast/type.cpp | 32 +- src/execution/ast/type_printer.cpp | 11 + src/execution/compiler/udf/udf_codegen.cpp | 707 +++++++++--------- src/execution/sema/sema_expr.cpp | 96 +++ src/execution/sema/sema_stmt.cpp | 17 + src/execution/sema/sema_type.cpp | 14 +- src/execution/vm/bytecode_generator.cpp | 74 ++ src/execution/vm/llvm_engine.cpp | 13 + src/include/catalog/catalog_accessor.h | 5 + src/include/catalog/database_catalog.h | 3 + src/include/execution/ast/ast.h | 109 ++- src/include/execution/ast/ast_clone.h | 25 + src/include/execution/ast/ast_fwd.h | 1 + src/include/execution/ast/ast_node_factory.h | 26 + src/include/execution/ast/builtins.h | 12 + src/include/execution/ast/type.h | 82 +- .../execution/ast/udf/udf_ast_context.h | 7 +- .../execution/ast/udf/udf_ast_node_visitor.h | 40 +- .../execution/compiler/udf/udf_codegen.h | 79 +- .../execution/functions/function_context.h | 46 ++ src/include/execution/vm/bytecode_generator.h | 3 + src/include/parser/udf/udf_parser.h | 46 +- src/parser/udf/udf_parser.cpp | 88 +-- 31 files changed, 1400 insertions(+), 485 deletions(-) create mode 100644 src/execution/ast/ast_clone.cpp create mode 100644 src/include/execution/ast/ast_clone.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 8ccfe77068..b02ceb5edf 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -234,7 +234,7 @@ set(NOISEPAGE_INCLUDE_DIRECTORIES "") # Add compilation flags to NOISEPAGE_COMPILE_OPTIONS based on the current CMAKE_BUILD_TYPE. string(TOUPPER ${CMAKE_BUILD_TYPE} CMAKE_BUILD_TYPE) if ("${CMAKE_BUILD_TYPE}" STREQUAL "DEBUG") - list(APPEND NOISEPAGE_COMPILE_OPTIONS "-ggdb" "-O0" "-fno-omit-frame-pointer" "-fno-optimize-sibling-calls") + list(APPEND NOISEPAGE_COMPILE_OPTIONS "-ggdb" "-O0" "-fno-omit-frame-pointer" "-fno-optimize-sibling-calls" "-Wfatal-errors") elseif ("${CMAKE_BUILD_TYPE}" STREQUAL "FASTDEBUG") list(APPEND NOISEPAGE_COMPILE_OPTIONS "-ggdb" "-O1" "-fno-omit-frame-pointer" "-fno-optimize-sibling-calls") elseif ("${CMAKE_BUILD_TYPE}" STREQUAL "RELEASE") diff --git a/src/catalog/catalog_accessor.cpp b/src/catalog/catalog_accessor.cpp index 805616ce44..c6638e0532 100644 --- a/src/catalog/catalog_accessor.cpp +++ b/src/catalog/catalog_accessor.cpp @@ -196,6 +196,11 @@ proc_oid_t CatalogAccessor::GetProcOid(const std::string &procname, const std::v return catalog::INVALID_PROC_OID; } +common::ManagedPointer CatalogAccessor::GetProcCtxPtr( + const proc_oid_t proc_oid) { + return dbc_->GetProcCtxPtr(txn_, proc_oid); +} + bool CatalogAccessor::SetFunctionContextPointer(proc_oid_t proc_oid, const execution::functions::FunctionContext *func_context) { return dbc_->SetFunctionContextPointer(txn_, proc_oid, func_context); diff --git a/src/catalog/database_catalog.cpp b/src/catalog/database_catalog.cpp index 5106d550b1..5b785e16c2 100644 --- a/src/catalog/database_catalog.cpp +++ b/src/catalog/database_catalog.cpp @@ -450,6 +450,11 @@ proc_oid_t DatabaseCatalog::GetProcOid(common::ManagedPointer DatabaseCatalog::GetProcCtxPtr( + common::ManagedPointer txn, proc_oid_t proc_oid) { + return pg_proc_.GetProcCtxPtr(txn, proc_oid); +} + template bool DatabaseCatalog::SetClassPointer(const common::ManagedPointer txn, const ClassOid oid, const Ptr *const pointer, const col_oid_t class_col) { diff --git a/src/execution/ast/ast.cpp b/src/execution/ast/ast.cpp index 5078214bf5..96c5abf0fb 100644 --- a/src/execution/ast/ast.cpp +++ b/src/execution/ast/ast.cpp @@ -8,8 +8,8 @@ namespace noisepage::execution::ast { // Function Declaration // --------------------------------------------------------- -FunctionDecl::FunctionDecl(const SourcePosition &pos, Identifier name, FunctionLitExpr *func) - : Decl(Kind::FunctionDecl, pos, name, func->TypeRepr()), func_(func) {} +FunctionDecl::FunctionDecl(const SourcePosition &pos, Identifier name, FunctionLitExpr *func, bool is_lambda) + : Decl(Kind::FunctionDecl, pos, name, func->TypeRepr()), func_(func), is_lambda_(is_lambda) {} // --------------------------------------------------------- // Structure Declaration @@ -90,8 +90,8 @@ bool ComparisonOpExpr::IsLiteralCompareNil(Expr **result) const { // Function Literal Expressions // --------------------------------------------------------- -FunctionLitExpr::FunctionLitExpr(FunctionTypeRepr *type_repr, BlockStmt *body) - : Expr(Kind::FunctionLitExpr, type_repr->Position()), type_repr_(type_repr), body_(body) {} +FunctionLitExpr::FunctionLitExpr(FunctionTypeRepr *type_repr, BlockStmt *body, bool is_lambda) + : Expr(Kind::FunctionLitExpr, type_repr->Position()), type_repr_(type_repr), body_(body), is_lambda_(is_lambda) {} // --------------------------------------------------------- // Call Expression diff --git a/src/execution/ast/ast_clone.cpp b/src/execution/ast/ast_clone.cpp new file mode 100644 index 0000000000..48579d2b82 --- /dev/null +++ b/src/execution/ast/ast_clone.cpp @@ -0,0 +1,265 @@ +#include +#include + +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/SmallString.h" +#include "llvm/Support/raw_os_ostream.h" +#include "llvm/Support/raw_ostream.h" + +#include "execution/ast/ast.h" +#include "execution/ast/ast_clone.h" +#include "execution/ast/ast_visitor.h" +#include "execution/ast/context.h" +#include "execution/ast/type.h" + +namespace noisepage::execution::ast { + +/** + * TODO(Kyle): Document. + */ +class AstCloneImpl : public AstVisitor { + public: + explicit AstCloneImpl(AstNode *root, AstNodeFactory *factory, Context *old_context, Context *new_context, + std::string prefix) + : root_(root), factory_{factory}, old_context_{old_context}, new_context_{new_context}, prefix_{prefix} {} + + AstNode *Run() { return Visit(root_); } + + // Declare all node visit methods here +#define DECLARE_VISIT_METHOD(type) AstNode *Visit##type(type *node); + AST_NODES(DECLARE_VISIT_METHOD) +#undef DECLARE_VISIT_METHOD + + Identifier CloneIdentifier(Identifier &ident) { return new_context_->GetIdentifier(ident.GetData()); } + + Identifier CloneIdentifier(Identifier &&ident) { + (void)old_context_; + return new_context_->GetIdentifier(ident.GetData()); + } + + private: + AstNode *root_; + AstNodeFactory *factory_; + + Context *old_context_; + Context *new_context_; + std::string prefix_; + + llvm::DenseMap allocated_strings_; +}; + +AstNode *AstCloneImpl::VisitFile(File *node) { + util::RegionVector decls(new_context_->GetRegion()); + for (auto *decl : node->Declarations()) { + decls.push_back(reinterpret_cast(Visit(decl))); + } + return factory_->NewFile(node->Position(), std::move(decls)); +} + +AstNode *AstCloneImpl::VisitFieldDecl(FieldDecl *node) { + return factory_->NewFieldDecl(node->Position(), CloneIdentifier(node->Name()), + reinterpret_cast(Visit(node->TypeRepr()))); +} + +AstNode *AstCloneImpl::VisitFunctionDecl(FunctionDecl *node) { + return factory_->NewFunctionDecl(node->Position(), CloneIdentifier(node->Name()), + reinterpret_cast(VisitFunctionLitExpr(node->Function()))); +} + +AstNode *AstCloneImpl::VisitVariableDecl(VariableDecl *node) { + return factory_->NewVariableDecl( + node->Position(), CloneIdentifier(node->Name()), + node->TypeRepr() == nullptr ? nullptr : reinterpret_cast(Visit(node->TypeRepr())), + node->Initial() == nullptr ? nullptr : reinterpret_cast(Visit(node->Initial()))); +} + +AstNode *AstCloneImpl::VisitStructDecl(StructDecl *node) { + return factory_->NewStructDecl( + node->Position(), CloneIdentifier(node->Name()), + reinterpret_cast(VisitStructTypeRepr(reinterpret_cast(node->TypeRepr())))); +} + +AstNode *AstCloneImpl::VisitAssignmentStmt(AssignmentStmt *node) { + return factory_->NewAssignmentStmt(node->Position(), reinterpret_cast(Visit(node->Destination())), + reinterpret_cast(Visit(node->Source()))); +} + +AstNode *AstCloneImpl::VisitBlockStmt(BlockStmt *node) { + util::RegionVector stmts(new_context_->GetRegion()); + for (auto *stmt : node->Statements()) { + stmts.push_back(reinterpret_cast(Visit(stmt))); + } + return factory_->NewBlockStmt(node->Position(), node->RightBracePosition(), std::move(stmts)); +} + +AstNode *AstCloneImpl::VisitDeclStmt(DeclStmt *node) { + return factory_->NewDeclStmt(reinterpret_cast(Visit(node->Declaration()))); +} + +AstNode *AstCloneImpl::VisitExpressionStmt(ExpressionStmt *node) { + return factory_->NewExpressionStmt(reinterpret_cast(Visit(node->Expression()))); +} + +AstNode *AstCloneImpl::VisitForStmt(ForStmt *node) { + auto init = node->Init() == nullptr ? nullptr : reinterpret_cast(Visit(node->Init())); + auto next = node->Next() == nullptr ? nullptr : reinterpret_cast(Visit(node->Next())); + return factory_->NewForStmt(node->Position(), init, reinterpret_cast(Visit(node->Condition())), next, + reinterpret_cast(VisitBlockStmt(node->Body()))); +} + +AstNode *AstCloneImpl::VisitForInStmt(ForInStmt *node) { + return factory_->NewForInStmt(node->Position(), reinterpret_cast(Visit(node->Target())), + reinterpret_cast(Visit(node->Iterable())), + reinterpret_cast(VisitBlockStmt(node->Body()))); +} + +AstNode *AstCloneImpl::VisitIfStmt(IfStmt *node) { + auto *else_stmt = node->ElseStmt() == nullptr ? nullptr : reinterpret_cast(Visit((node->ElseStmt()))); + return factory_->NewIfStmt(node->Position(), reinterpret_cast(Visit(node->Condition())), + reinterpret_cast(VisitBlockStmt(node->ThenStmt())), else_stmt); +} + +AstNode *AstCloneImpl::VisitReturnStmt(ReturnStmt *node) { + if (node->Ret() == nullptr) { + return factory_->NewReturnStmt(node->Position(), nullptr); + } else { + return factory_->NewReturnStmt(node->Position(), reinterpret_cast(Visit(node->Ret()))); + } +} + +AstNode *AstCloneImpl::VisitCallExpr(CallExpr *node) { + util::RegionVector args(new_context_->GetRegion()); + + for (auto *arg : node->Arguments()) { + args.push_back(reinterpret_cast(Visit(arg))); + } + if (node->GetCallKind() == CallExpr::CallKind::Builtin) { + return factory_->NewBuiltinCallExpr(reinterpret_cast(Visit(node->Function())), std::move(args)); + } + return factory_->NewCallExpr(reinterpret_cast(Visit(node->Function())), std::move(args)); +} + +AstNode *AstCloneImpl::VisitBinaryOpExpr(BinaryOpExpr *node) { + return factory_->NewBinaryOpExpr(node->Position(), node->Op(), reinterpret_cast(Visit(node->Left())), + reinterpret_cast(Visit(node->Right()))); +} + +AstNode *AstCloneImpl::VisitComparisonOpExpr(ComparisonOpExpr *node) { + return factory_->NewComparisonOpExpr(node->Position(), node->Op(), reinterpret_cast(Visit(node->Left())), + reinterpret_cast(Visit(node->Right()))); +} + +AstNode *AstCloneImpl::VisitFunctionLitExpr(FunctionLitExpr *node) { + return factory_->NewFunctionLitExpr(reinterpret_cast(VisitFunctionTypeRepr(node->TypeRepr())), + reinterpret_cast(VisitBlockStmt(node->Body()))); +} + +AstNode *AstCloneImpl::VisitIdentifierExpr(IdentifierExpr *node) { + return factory_->NewIdentifierExpr(node->Position(), CloneIdentifier(node->Name())); +} + +AstNode *AstCloneImpl::VisitImplicitCastExpr(ImplicitCastExpr *node) { + // TODO(Kyle): The type might have to be cloned + return Visit(node->Input()); +} + +AstNode *AstCloneImpl::VisitIndexExpr(IndexExpr *node) { + return factory_->NewIndexExpr(node->Position(), reinterpret_cast(Visit(node->Object())), + reinterpret_cast(Visit(node->Index()))); +} + +AstNode *AstCloneImpl::VisitLambdaExpr(LambdaExpr *node) { + util::RegionVector capture_idents(new_context_->GetRegion()); + for (auto ident : node->GetCaptureIdents()) { + capture_idents.push_back(reinterpret_cast(Visit(ident))); + } + return factory_->NewLambdaExpr(node->Position(), + reinterpret_cast(Visit(node->GetFunctionLitExpr())), + std::move(capture_idents)); +} + +AstNode *AstCloneImpl::VisitLitExpr(LitExpr *node) { + AstNode *literal = nullptr; + switch (node->GetLiteralKind()) { + case LitExpr::LitKind::Nil: { + literal = factory_->NewNilLiteral(node->Position()); + break; + } + case LitExpr::LitKind::Boolean: { + literal = factory_->NewBoolLiteral(node->Position(), node->BoolVal()); + break; + } + case LitExpr::LitKind::Int: { + literal = factory_->NewIntLiteral(node->Position(), node->Int64Val()); + break; + } + case LitExpr::LitKind::Float: { + literal = factory_->NewFloatLiteral(node->Position(), node->Float64Val()); + break; + } + case LitExpr::LitKind::String: { + literal = factory_->NewStringLiteral(node->Position(), CloneIdentifier(node->StringVal())); + break; + } + } + NOISEPAGE_ASSERT(literal != nullptr, "Unknown literal kind"); + return literal; +} + +AstNode *AstCloneImpl::VisitBreakStmt(BreakStmt *node) { return factory_->NewBreakStmt(node->Position()); } + +AstNode *AstCloneImpl::VisitMemberExpr(MemberExpr *node) { + return factory_->NewMemberExpr(node->Position(), reinterpret_cast(Visit(node->Object())), + reinterpret_cast(Visit(node->Member()))); +} + +AstNode *AstCloneImpl::VisitUnaryOpExpr(UnaryOpExpr *node) { + return factory_->NewUnaryOpExpr(node->Position(), node->Op(), reinterpret_cast(Visit(node->Input()))); +} + +AstNode *AstCloneImpl::VisitBadExpr(BadExpr *node) { return factory_->NewBadExpr(node->Position()); } + +AstNode *AstCloneImpl::VisitStructTypeRepr(StructTypeRepr *node) { + util::RegionVector field_decls(new_context_->GetRegion()); + field_decls.reserve(node->Fields().size()); + for (auto field : node->Fields()) { + field_decls.push_back(reinterpret_cast((VisitFieldDecl(field)))); + } + return factory_->NewStructType(node->Position(), std::move(field_decls)); +} + +AstNode *AstCloneImpl::VisitPointerTypeRepr(PointerTypeRepr *node) { + return factory_->NewPointerType(node->Position(), reinterpret_cast(Visit(node->Base()))); +} + +AstNode *AstCloneImpl::VisitFunctionTypeRepr(FunctionTypeRepr *node) { + util::RegionVector params(new_context_->GetRegion()); + for (auto *param : node->Parameters()) { + params.push_back(reinterpret_cast(VisitFieldDecl(param))); + } + + return factory_->NewFunctionType(node->Position(), std::move(params), + reinterpret_cast(Visit(node->ReturnType()))); +} + +AstNode *AstCloneImpl::VisitArrayTypeRepr(ArrayTypeRepr *node) { + return factory_->NewArrayType(node->Position(), reinterpret_cast(Visit(node->Length())), + reinterpret_cast(Visit(node->ElementType()))); +} + +AstNode *AstCloneImpl::VisitMapTypeRepr(MapTypeRepr *node) { + return factory_->NewMapType(node->Position(), reinterpret_cast(Visit(node->KeyType())), + reinterpret_cast(Visit(node->ValType()))); +} + +AstNode *AstCloneImpl::VisitLambdaTypeRepr(LambdaTypeRepr *node) { + return factory_->NewLambdaType(node->Position(), reinterpret_cast(Visit(node->FunctionType()))); +} + +AstNode *AstClone::Clone(AstNode *node, AstNodeFactory *factory, std::string prefix, Context *old_context, + Context *new_context) { + AstCloneImpl cloner(node, factory, old_context, new_context, prefix); + return cloner.Run(); +} + +} // namespace noisepage::execution::ast diff --git a/src/execution/ast/ast_dump.cpp b/src/execution/ast/ast_dump.cpp index 38a48e23cc..ececaef0c4 100644 --- a/src/execution/ast/ast_dump.cpp +++ b/src/execution/ast/ast_dump.cpp @@ -171,6 +171,11 @@ void AstDumperImpl::VisitFunctionDecl(FunctionDecl *node) { DumpExpr(node->Function()); } +void AstDumperImpl::VisitLambdaExpr(LambdaExpr *node) { + DumpNodeCommon(node); + DumpExpr(node->GetFunctionLitExpr()); +} + void AstDumperImpl::VisitVariableDecl(VariableDecl *node) { DumpNodeCommon(node); DumpIdentifier(node->Name()); @@ -203,6 +208,8 @@ void AstDumperImpl::VisitBlockStmt(BlockStmt *node) { } } +void AstDumperImpl::VisitBreakStmt(BreakStmt *node) { DumpNodeCommon(node); } + void AstDumperImpl::VisitDeclStmt(DeclStmt *node) { AstVisitor::Visit(node->Declaration()); } void AstDumperImpl::VisitExpressionStmt(ExpressionStmt *node) { AstVisitor::Visit(node->Expression()); } @@ -257,6 +264,11 @@ void AstDumperImpl::VisitCallExpr(CallExpr *node) { } case CallExpr::CallKind::Regular: { out_ << "Regular"; + break; + } + case CallExpr::CallKind::Lambda: { + out_ << "Lambda"; + break; } } } @@ -375,6 +387,11 @@ void AstDumperImpl::VisitMapTypeRepr(MapTypeRepr *node) { DumpExpr(node->ValType()); } +void AstDumperImpl::VisitLambdaTypeRepr(LambdaTypeRepr *node) { + DumpNodeCommon(node); + DumpExpr(node->FunctionType()); +} + std::string AstDump::Dump(AstNode *node) { llvm::SmallString<256> buffer; llvm::raw_svector_ostream stream(buffer); diff --git a/src/execution/ast/ast_pretty_print.cpp b/src/execution/ast/ast_pretty_print.cpp index 330c119237..fa8bfa4fb8 100644 --- a/src/execution/ast/ast_pretty_print.cpp +++ b/src/execution/ast/ast_pretty_print.cpp @@ -204,6 +204,12 @@ void AstPrettyPrintImpl::VisitMapTypeRepr(MapTypeRepr *node) { Visit(node->ValType()); } +void AstPrettyPrintImpl::VisitLambdaTypeRepr(LambdaTypeRepr *node) { + os_ << "lambda["; + Visit(node->FunctionType()); + os_ << "]"; +} + void AstPrettyPrintImpl::VisitLitExpr(LitExpr *node) { switch (node->GetLiteralKind()) { case LitExpr::LitKind::Nil: @@ -224,6 +230,8 @@ void AstPrettyPrintImpl::VisitLitExpr(LitExpr *node) { } } +void AstPrettyPrintImpl::VisitBreakStmt(BreakStmt *node) { os_ << "break;\n"; } + void AstPrettyPrintImpl::VisitStructTypeRepr(StructTypeRepr *node) { // We want to ensure all types are aligned. Pre-process the fields to // find longest field names, then align as appropriate. @@ -283,6 +291,11 @@ void AstPrettyPrintImpl::VisitIndexExpr(IndexExpr *node) { os_ << "]"; } +void AstPrettyPrintImpl::VisitLambdaExpr(LambdaExpr *node) { + os_ << "lambda "; + VisitFunctionLitExpr(node->GetFunctionLitExpr()); +} + void AstPrettyPrintImpl::VisitFunctionTypeRepr(FunctionTypeRepr *node) { os_ << "("; bool first = true; diff --git a/src/execution/ast/context.cpp b/src/execution/ast/context.cpp index 30d42ee2a1..86dab369db 100644 --- a/src/execution/ast/context.cpp +++ b/src/execution/ast/context.cpp @@ -150,8 +150,10 @@ struct Context::Implementation { llvm::DenseMap builtin_types_; llvm::DenseMap builtin_funcs_; llvm::DenseMap pointer_types_; + llvm::DenseMap reference_types_; llvm::DenseMap, ArrayType *> array_types_; llvm::DenseMap, MapType *> map_types_; + llvm::DenseMap lambda_types_; llvm::DenseSet struct_types_; llvm::DenseSet func_types_; @@ -232,6 +234,8 @@ Identifier Context::GetBuiltinType(BuiltinType::Kind kind) { PointerType *Type::PointerTo() { return PointerType::Get(this); } +ReferenceType *Type::ReferenceTo() { return ReferenceType::Get(this); } + // static BuiltinType *BuiltinType::Get(Context *ctx, BuiltinType::Kind kind) { return ctx->Impl()->builtin_types_list_[kind]; } @@ -251,6 +255,19 @@ PointerType *PointerType::Get(Type *base) { return pointer_type; } +// static +ReferenceType *ReferenceType::Get(Type *base) { + Context *ctx = base->GetContext(); + + ReferenceType *&reference_type = ctx->Impl()->reference_types_[base]; + + if (reference_type == nullptr) { + reference_type = new (ctx->GetRegion()) ReferenceType(base); + } + + return reference_type; +} + // static ArrayType *ArrayType::Get(uint64_t length, Type *elem_type) { Context *ctx = elem_type->GetContext(); @@ -287,6 +304,19 @@ Field CreatePaddingElement(uint32_t id, uint32_t size, Context *ctx) { }; // namespace +// static +LambdaType *LambdaType::Get(FunctionType *fn_type) { + Context *ctx = fn_type->GetContext(); + + LambdaType *&lambda_type = ctx->Impl()->lambda_types_[fn_type]; + + if (lambda_type == nullptr) { + lambda_type = new (ctx->GetRegion()) LambdaType(fn_type); + } + + return lambda_type; +} + // static StructType *StructType::Get(Context *ctx, util::RegionVector &&fields) { // Empty structs get an artificial element @@ -366,7 +396,7 @@ StructType *StructType::Get(util::RegionVector &&fields) { } // static -FunctionType *FunctionType::Get(util::RegionVector &¶ms, Type *ret) { +FunctionType *FunctionType::Get(util::RegionVector &¶ms, Type *ret, bool is_lambda = false) { Context *ctx = ret->GetContext(); const FunctionTypeKeyInfo::KeyTy key(ret, params); @@ -380,7 +410,7 @@ FunctionType *FunctionType::Get(util::RegionVector &¶ms, Type *ret) { if (inserted) { // The function type was not in the cache, create the type now and insert it // into the cache - func_type = new (ctx->GetRegion()) FunctionType(std::move(params), ret); + func_type = new (ctx->GetRegion()) FunctionType(std::move(params), ret, is_lambda); *iter = func_type; } else { func_type = *iter; diff --git a/src/execution/ast/type.cpp b/src/execution/ast/type.cpp index 2ea4e771fc..4b88f11a88 100644 --- a/src/execution/ast/type.cpp +++ b/src/execution/ast/type.cpp @@ -3,6 +3,7 @@ #include #include +#include "execution/ast/context.h" #include "execution/exec/execution_context.h" #include "execution/sql/aggregation_hash_table.h" #include "execution/sql/aggregators.h" @@ -85,10 +86,30 @@ const bool BuiltinType::SIGNED_FLAGS[] = { // Function Type // --------------------------------------------------------- -FunctionType::FunctionType(util::RegionVector &¶ms, Type *ret) +FunctionType::FunctionType(util::RegionVector &¶ms, Type *ret, bool is_lambda) : Type(ret->GetContext(), sizeof(void *), alignof(void *), TypeId::FunctionType), params_(std::move(params)), - ret_(ret) {} + ret_(ret), + is_lambda_(is_lambda) {} + +bool FunctionType::IsEqual(const FunctionType *other) { + if (other->params_.size() != params_.size()) { + return false; + } + + for (size_t i = 0; i < params_.size(); i++) { + if (params_[i].type_ != other->params_[i].type_) { + return false; + } + } + + return true; +} + +void FunctionType::RegisterCapture() { + NOISEPAGE_ASSERT(captures_ != nullptr, "no capture given?"); + params_.emplace_back(GetContext()->GetIdentifier("captures"), captures_); +} // --------------------------------------------------------- // Map Type @@ -100,6 +121,13 @@ MapType::MapType(Type *key_type, Type *val_type) key_type_(key_type), val_type_(val_type) {} +// --------------------------------------------------------- +// Lambda Type +// --------------------------------------------------------- + +LambdaType::LambdaType(FunctionType *fn_type) + : Type(fn_type->GetContext(), fn_type->GetSize(), fn_type->GetAlignment(), TypeId::LambdaType), fn_type_(fn_type) {} + // --------------------------------------------------------- // Struct Type // --------------------------------------------------------- diff --git a/src/execution/ast/type_printer.cpp b/src/execution/ast/type_printer.cpp index 1662f4fc9f..e262944dc0 100644 --- a/src/execution/ast/type_printer.cpp +++ b/src/execution/ast/type_printer.cpp @@ -54,6 +54,11 @@ void TypePrinter::VisitPointerType(const PointerType *type) { Visit(type->GetBase()); } +void TypePrinter::VisitReferenceType(const ReferenceType *type) { + Os() << "&"; + Visit(type->GetBase()); +} + void TypePrinter::VisitStructType(const StructType *type) { Os() << "struct{"; bool first = true; @@ -85,6 +90,12 @@ void execution::ast::TypePrinter::VisitMapType(const MapType *type) { Visit(type->GetValueType()); } +void execution::ast::TypePrinter::VisitLambdaType(const LambdaType *type) { + Os() << "lambda["; + Visit(type->GetFunctionType()); + Os() << "]"; +} + } // namespace // static diff --git a/src/execution/compiler/udf/udf_codegen.cpp b/src/execution/compiler/udf/udf_codegen.cpp index 64891201da..c553bb46fa 100644 --- a/src/execution/compiler/udf/udf_codegen.cpp +++ b/src/execution/compiler/udf/udf_codegen.cpp @@ -3,9 +3,8 @@ #include "binder/bind_node_visitor.h" #include "execution/ast/ast.h" - -// TODO(Kyle): Not Ported Yet -// #include "execution/ast/ast_clone.h" +#include "execution/ast/ast_clone.h" +#include "execution/ast/context.h" #include "execution/compiler/compilation_context.h" #include "execution/compiler/executable_query.h" @@ -14,8 +13,8 @@ #include "execution/exec/execution_settings.h" #include "catalog/catalog_accessor.h" -#include "optimizer/statistics/stats_storage.h" #include "optimizer/cost_model/trivial_cost_model.h" +#include "optimizer/statistics/stats_storage.h" #include "traffic_cop/traffic_cop_util.h" @@ -51,11 +50,11 @@ UDFCodegen::UDFCodegen(catalog::CatalogAccessor *accessor, FunctionBuilder *fb, const char *UDFCodegen::GetReturnParamString() { return "return_val"; } -void UDFCodegen::GenerateUDF(AbstractAST *ast) { ast->Accept(this); } +void UDFCodegen::GenerateUDF(ast::udf::AbstractAST *ast) { ast->Accept(this); } -void UDFCodegen::Visit(AbstractAST *ast) { UNREACHABLE("Not implemented"); } +void UDFCodegen::Visit(ast::udf::AbstractAST *ast) { UNREACHABLE("Not implemented"); } -void UDFCodegen::Visit(DynamicSQLStmtAST *ast) { UNREACHABLE("Not implemented"); } +void UDFCodegen::Visit(ast::udf::DynamicSQLStmtAST *ast) { UNREACHABLE("Not implemented"); } catalog::type_oid_t UDFCodegen::GetCatalogTypeOidFromSQLType(execution::ast::BuiltinType::Kind type) { switch (type) { @@ -71,8 +70,19 @@ catalog::type_oid_t UDFCodegen::GetCatalogTypeOidFromSQLType(execution::ast::Bui } } -void UDFCodegen::Visit(CallExprAST *ast) { - // UNREACHABLE("Not implemented"); +execution::ast::File *UDFCodegen::Finish() { + auto fn = fb_->Finish(); + // util::RegionVector decls_reg_vec{decls->begin(), decls->end(), codegen.Region()}; + execution::util::RegionVector decls({fn}, codegen_->GetAstContext()->GetRegion()); + // for(auto decl : aux_decls_){ + // decls.push_back(decl); + // } + decls.insert(decls.begin(), aux_decls_.begin(), aux_decls_.end()); + auto file = codegen_->GetAstContext()->GetNodeFactory()->NewFile({0, 0}, std::move(decls)); + return file; +} + +void UDFCodegen::Visit(ast::udf::CallExprAST *ast) { auto &args = ast->args; std::vector args_ast; std::vector args_ast_region_vec; @@ -89,6 +99,7 @@ void UDFCodegen::Visit(CallExprAST *ast) { } auto proc_oid = accessor_->GetProcOid(ast->callee, arg_types); NOISEPAGE_ASSERT(proc_oid != catalog::INVALID_PROC_OID, "Invalid call"); + auto context = accessor_->GetProcCtxPtr(proc_oid); if (context->IsBuiltin()) { fb_->Append(codegen_->MakeStmt(codegen_->CallBuiltin(context->GetBuiltin(), std::move(args_ast)))); @@ -109,14 +120,15 @@ void UDFCodegen::Visit(CallExprAST *ast) { } fb_->Append(codegen_->MakeStmt(codegen_->Call(ident_expr, args_ast_region_vec))); } - // fb_->Append(codegen_->Call) + + // fb_->Append(codegen_->Call) } -void UDFCodegen::Visit(StmtAST *ast) { UNREACHABLE("Not implemented"); } +void UDFCodegen::Visit(ast::udf::StmtAST *ast) { UNREACHABLE("Not implemented"); } -void UDFCodegen::Visit(ExprAST *ast) { UNREACHABLE("Not implemented"); } +void UDFCodegen::Visit(ast::udf::ExprAST *ast) { UNREACHABLE("Not implemented"); } -void UDFCodegen::Visit(DeclStmtAST *ast) { +void UDFCodegen::Visit(ast::udf::DeclStmtAST *ast) { if (ast->name == "*internal*") { return; } @@ -148,7 +160,7 @@ void UDFCodegen::Visit(DeclStmtAST *ast) { current_type_ = prev_type; } -void UDFCodegen::Visit(FunctionAST *ast) { +void UDFCodegen::Visit(ast::udf::FunctionAST *ast) { for (size_t i = 0; i < ast->param_types_.size(); i++) { // auto param_type = codegen_->TplType(ast->param_types_[i]); str_to_ident_.emplace(ast->param_names_[i], codegen_->MakeFreshIdentifier("udf")); @@ -156,13 +168,13 @@ void UDFCodegen::Visit(FunctionAST *ast) { ast->body.get()->Accept(this); } -void UDFCodegen::Visit(VariableExprAST *ast) { +void UDFCodegen::Visit(ast::udf::VariableExprAST *ast) { auto it = str_to_ident_.find(ast->name); NOISEPAGE_ASSERT(it != str_to_ident_.end(), "variable not declared"); dst_ = codegen_->MakeExpr(it->second); } -void UDFCodegen::Visit(ValueExprAST *ast) { +void UDFCodegen::Visit(ast::udf::ValueExprAST *ast) { auto val = common::ManagedPointer(ast->value_).CastManagedPointerTo(); if (val->IsNull()) { dst_ = codegen_->ConstNull(current_type_); @@ -196,12 +208,12 @@ void UDFCodegen::Visit(ValueExprAST *ast) { } } -void UDFCodegen::Visit(AssignStmtAST *ast) { +void UDFCodegen::Visit(ast::udf::AssignStmtAST *ast) { type::TypeId left_type = type::TypeId::INVALID; udf_ast_context_->GetVariableType(ast->lhs->name, &left_type); current_type_ = left_type; - reinterpret_cast(ast->rhs.get())->Accept(this); + reinterpret_cast(ast->rhs.get())->Accept(this); auto rhs_expr = dst_; auto it = str_to_ident_.find(ast->lhs->name); @@ -217,7 +229,7 @@ void UDFCodegen::Visit(AssignStmtAST *ast) { // } } -void UDFCodegen::Visit(BinaryExprAST *ast) { +void UDFCodegen::Visit(ast::udf::BinaryExprAST *ast) { execution::parsing::Token::Type op_token; bool compare = false; switch (ast->op) { @@ -278,7 +290,7 @@ void UDFCodegen::Visit(BinaryExprAST *ast) { } } -void UDFCodegen::Visit(IfStmtAST *ast) { +void UDFCodegen::Visit(ast::udf::IfStmtAST *ast) { ast->cond_expr->Accept(this); auto cond = dst_; @@ -291,7 +303,7 @@ void UDFCodegen::Visit(IfStmtAST *ast) { branch.EndIf(); } -void UDFCodegen::Visit(IsNullExprAST *ast) { +void UDFCodegen::Visit(ast::udf::IsNullExprAST *ast) { ast->child_->Accept(this); auto chld = dst_; dst_ = codegen_->CallBuiltin(execution::ast::Builtin::IsValNull, {chld}); @@ -300,13 +312,7 @@ void UDFCodegen::Visit(IsNullExprAST *ast) { } } -void UDFCodegen::Visit(SeqStmtAST *ast) { - for (auto &stmt : ast->stmts) { - stmt->Accept(this); - } -} - -void UDFCodegen::Visit(WhileStmtAST *ast) { +void UDFCodegen::Visit(ast::udf::WhileStmtAST *ast) { ast->cond_expr->Accept(this); auto cond = dst_; // cond = codegen_->Compare(execution::parsing::Token::Type::EQUAL_EQUAL, cond, ) @@ -316,332 +322,343 @@ void UDFCodegen::Visit(WhileStmtAST *ast) { loop.EndLoop(); } -void UDFCodegen::Visit(ForStmtAST *ast) { - needs_exec_ctx_ = true; - const auto query = common::ManagedPointer(ast->query_); - auto exec_ctx = fb_->GetParameterByPosition(0); - - // TODO(Matt): I don't think the binder should need the database name. It's already bound in the ConnectionContext - binder::BindNodeVisitor visitor(common::ManagedPointer(accessor_), db_oid_); - auto query_params = visitor.BindAndGetUDFParams(query, common::ManagedPointer(udf_ast_context_)); - - auto stats = optimizer::StatsStorage(); - - std::unique_ptr plan = trafficcop::TrafficCopUtil::Optimize( - accessor_->GetTxn(), common::ManagedPointer(accessor_), query, db_oid_, common::ManagedPointer(&stats), - std::make_unique(), 1000000); - // make lambda that just writes into this - std::vector var_idents; - auto lam_var = codegen_->MakeFreshIdentifier("looplamb"); - execution::util::RegionVector params(codegen_->GetAstContext()->GetRegion()); - params.push_back(codegen_->MakeField( - exec_ctx->As()->Name(), - codegen_->PointerType(codegen_->BuiltinType(execution::ast::BuiltinType::Kind::ExecutionContext)))); - size_t i = 0; - for (auto var : ast->vars_) { - var_idents.push_back(str_to_ident_.find(var)->second); - auto var_ident = var_idents.back(); - // NOISEPAGE_ASSERT(plan->GetOutputSchema()->GetColumns().size() == 1, "Can't support non scalars yet!"); - auto type = codegen_->TplType(execution::sql::GetTypeId(plan->GetOutputSchema()->GetColumn(i).GetType())); - - fb_->Append(codegen_->Assign(codegen_->MakeExpr(var_ident), - codegen_->ConstNull(plan->GetOutputSchema()->GetColumn(i).GetType()))); - auto input = codegen_->MakeFreshIdentifier(var); - params.push_back(codegen_->MakeField(input, type)); - i++; - } - execution::ast::LambdaExpr *lambda_expr; - FunctionBuilder fn(codegen_, std::move(params), codegen_->BuiltinType(execution::ast::BuiltinType::Nil)); - { - size_t j = 1; - for (auto var : var_idents) { - fn.Append(codegen_->Assign(codegen_->MakeExpr(var), fn.GetParameterByPosition(j))); - j++; - } - auto prev_fb = fb_; - fb_ = &fn; - ast->body_stmt_->Accept(this); - fb_ = prev_fb; - } - - execution::util::RegionVector captures(codegen_->GetAstContext()->GetRegion()); - for (auto it : str_to_ident_) { - if (it.first == "executionCtx") { - continue; - } - captures.push_back(codegen_->MakeExpr(it.second)); - } - - lambda_expr = fn.FinishLambda(std::move(captures)); - lambda_expr->SetName(lam_var); - - // want to pass something down that will materialize the lambda function for me into lambda_expr and will - // also feed in a lambda_expr to the compiler - execution::exec::ExecutionSettings exec_settings{}; - const std::string dummy_query = ""; - auto exec_query = execution::compiler::CompilationContext::Compile( - *plan, exec_settings, accessor_, execution::compiler::CompilationMode::OneShot, - common::ManagedPointer(&dummy_query), lambda_expr, codegen_->GetAstContext()); - auto fns = exec_query->GetFunctions(); - auto decls = exec_query->GetDecls(); - - aux_decls_.insert(aux_decls_.end(), decls.begin(), decls.end()); - - fb_->Append( - codegen_->DeclareVar(lam_var, codegen_->LambdaType(lambda_expr->GetFunctionLitExpr()->TypeRepr()), lambda_expr)); - - // make query state - auto query_state = codegen_->MakeFreshIdentifier("query_state"); - fb_->Append(codegen_->DeclareVarNoInit(query_state, codegen_->MakeExpr(exec_query->GetQueryStateType()->Name()))); - // set its execution context to whatever exec context was passed in here - fb_->Append(codegen_->CallBuiltin(execution::ast::Builtin::StartNewParams, {exec_ctx})); - std::vector>::iterator> sorted_vec; - for (auto it = query_params.begin(); it != query_params.end(); it++) { - sorted_vec.push_back(it); - } - - std::sort(sorted_vec.begin(), sorted_vec.end(), [](auto x, auto y) { return x->second < y->second; }); - for (auto entry : sorted_vec) { - // TODO(order these dudes) - type::TypeId type = type::TypeId::INVALID; - udf_ast_context_->GetVariableType(entry->first, &type); - // NOISEPAGE_ASSERT(ret, "didn't find param in udf ast context"); - - execution::ast::Builtin builtin; - switch (type) { - case type::TypeId::BOOLEAN: - builtin = execution::ast::Builtin::AddParamBool; - break; - case type::TypeId::TINYINT: - builtin = execution::ast::Builtin::AddParamTinyInt; - break; - case type::TypeId::SMALLINT: - builtin = execution::ast::Builtin::AddParamSmallInt; - break; - case type::TypeId::INTEGER: - builtin = execution::ast::Builtin::AddParamInt; - break; - case type::TypeId::BIGINT: - builtin = execution::ast::Builtin::AddParamBigInt; - break; - case type::TypeId::DECIMAL: - builtin = execution::ast::Builtin::AddParamDouble; - break; - case type::TypeId::DATE: - builtin = execution::ast::Builtin::AddParamDate; - break; - case type::TypeId::TIMESTAMP: - builtin = execution::ast::Builtin::AddParamTimestamp; - break; - case type::TypeId::VARCHAR: - builtin = execution::ast::Builtin::AddParamString; - break; - default: - UNREACHABLE("Unsupported parameter type"); - } - fb_->Append(codegen_->CallBuiltin(builtin, {exec_ctx, codegen_->MakeExpr(str_to_ident_[entry->first])})); - } - // set param 1 - // set param 2 - // etc etc - fb_->Append(codegen_->Assign( - codegen_->AccessStructMember(codegen_->MakeExpr(query_state), codegen_->MakeIdentifier("execCtx")), exec_ctx)); - // set its execution context to whatever exec context was passed in here - - for (auto &sub_fn : fns) { - // aux_decls_.push_back(c) - if (sub_fn.find("Run") != std::string::npos) { - fb_->Append(codegen_->Call(codegen_->GetAstContext()->GetIdentifier(sub_fn), - {codegen_->AddressOf(query_state), codegen_->MakeExpr(lam_var)})); - } else { - fb_->Append(codegen_->Call(codegen_->GetAstContext()->GetIdentifier(sub_fn), {codegen_->AddressOf(query_state)})); - } - } - - fb_->Append(codegen_->CallBuiltin(execution::ast::Builtin::FinishNewParams, {exec_ctx})); - - return; +// TODO(Kyle): Implement +void UDFCodegen::Visit(ast::udf::ForStmtAST *ast) { + throw NOT_IMPLEMENTED_EXCEPTION("Visit(ForStmtAst*) Not Implemented"); + // needs_exec_ctx_ = true; + // const auto query = common::ManagedPointer(ast->query_); + // auto exec_ctx = fb_->GetParameterByPosition(0); + + // // TODO(Matt): I don't think the binder should need the database name. It's already bound in the ConnectionContext + // binder::BindNodeVisitor visitor(common::ManagedPointer(accessor_), db_oid_); + + // auto query_params = visitor.BindAndGetUDFParams(query, common::ManagedPointer(udf_ast_context_)); + + // auto stats = optimizer::StatsStorage(); + + // std::unique_ptr plan = trafficcop::TrafficCopUtil::Optimize( + // accessor_->GetTxn(), common::ManagedPointer(accessor_), query, db_oid_, common::ManagedPointer(&stats), + // std::make_unique(), 1000000); + // // make lambda that just writes into this + // std::vector var_idents; + // auto lam_var = codegen_->MakeFreshIdentifier("looplamb"); + // execution::util::RegionVector params(codegen_->GetAstContext()->GetRegion()); + // params.push_back(codegen_->MakeField( + // exec_ctx->As()->Name(), + // codegen_->PointerType(codegen_->BuiltinType(execution::ast::BuiltinType::Kind::ExecutionContext)))); + // size_t i = 0; + // for (auto var : ast->vars_) { + // var_idents.push_back(str_to_ident_.find(var)->second); + // auto var_ident = var_idents.back(); + // // NOISEPAGE_ASSERT(plan->GetOutputSchema()->GetColumns().size() == 1, "Can't support non scalars yet!"); + // auto type = codegen_->TplType(execution::sql::GetTypeId(plan->GetOutputSchema()->GetColumn(i).GetType())); + + // fb_->Append(codegen_->Assign(codegen_->MakeExpr(var_ident), + // codegen_->ConstNull(plan->GetOutputSchema()->GetColumn(i).GetType()))); + // auto input = codegen_->MakeFreshIdentifier(var); + // params.push_back(codegen_->MakeField(input, type)); + // i++; + // } + // execution::ast::LambdaExpr *lambda_expr; + // FunctionBuilder fn(codegen_, std::move(params), codegen_->BuiltinType(execution::ast::BuiltinType::Nil)); + // { + // size_t j = 1; + // for (auto var : var_idents) { + // fn.Append(codegen_->Assign(codegen_->MakeExpr(var), fn.GetParameterByPosition(j))); + // j++; + // } + // auto prev_fb = fb_; + // fb_ = &fn; + // ast->body_stmt_->Accept(this); + // fb_ = prev_fb; + // } + + // execution::util::RegionVector captures(codegen_->GetAstContext()->GetRegion()); + // for (auto it : str_to_ident_) { + // if (it.first == "executionCtx") { + // continue; + // } + // captures.push_back(codegen_->MakeExpr(it.second)); + // } + + // lambda_expr = fn.FinishLambda(std::move(captures)); + // lambda_expr->SetName(lam_var); + + // // want to pass something down that will materialize the lambda function for me into lambda_expr and will + // // also feed in a lambda_expr to the compiler + // execution::exec::ExecutionSettings exec_settings{}; + // const std::string dummy_query = ""; + // auto exec_query = execution::compiler::CompilationContext::Compile( + // *plan, exec_settings, accessor_, execution::compiler::CompilationMode::OneShot, + // common::ManagedPointer(&dummy_query), lambda_expr, codegen_->GetAstContext()); + // auto fns = exec_query->GetFunctions(); + // auto decls = exec_query->GetDecls(); + + // aux_decls_.insert(aux_decls_.end(), decls.begin(), decls.end()); + + // fb_->Append( + // codegen_->DeclareVar(lam_var, codegen_->LambdaType(lambda_expr->GetFunctionLitExpr()->TypeRepr()), + // lambda_expr)); + + // // make query state + // auto query_state = codegen_->MakeFreshIdentifier("query_state"); + // fb_->Append(codegen_->DeclareVarNoInit(query_state, codegen_->MakeExpr(exec_query->GetQueryStateType()->Name()))); + // // set its execution context to whatever exec context was passed in here + // fb_->Append(codegen_->CallBuiltin(execution::ast::Builtin::StartNewParams, {exec_ctx})); + // std::vector>::iterator> sorted_vec; + // for (auto it = query_params.begin(); it != query_params.end(); it++) { + // sorted_vec.push_back(it); + // } + + // std::sort(sorted_vec.begin(), sorted_vec.end(), [](auto x, auto y) { return x->second < y->second; }); + // for (auto entry : sorted_vec) { + // // TODO(order these dudes) + // type::TypeId type = type::TypeId::INVALID; + // udf_ast_context_->GetVariableType(entry->first, &type); + // // NOISEPAGE_ASSERT(ret, "didn't find param in udf ast context"); + + // execution::ast::Builtin builtin; + // switch (type) { + // case type::TypeId::BOOLEAN: + // builtin = execution::ast::Builtin::AddParamBool; + // break; + // case type::TypeId::TINYINT: + // builtin = execution::ast::Builtin::AddParamTinyInt; + // break; + // case type::TypeId::SMALLINT: + // builtin = execution::ast::Builtin::AddParamSmallInt; + // break; + // case type::TypeId::INTEGER: + // builtin = execution::ast::Builtin::AddParamInt; + // break; + // case type::TypeId::BIGINT: + // builtin = execution::ast::Builtin::AddParamBigInt; + // break; + // case type::TypeId::DECIMAL: + // builtin = execution::ast::Builtin::AddParamDouble; + // break; + // case type::TypeId::DATE: + // builtin = execution::ast::Builtin::AddParamDate; + // break; + // case type::TypeId::TIMESTAMP: + // builtin = execution::ast::Builtin::AddParamTimestamp; + // break; + // case type::TypeId::VARCHAR: + // builtin = execution::ast::Builtin::AddParamString; + // break; + // default: + // UNREACHABLE("Unsupported parameter type"); + // } + // fb_->Append(codegen_->CallBuiltin(builtin, {exec_ctx, codegen_->MakeExpr(str_to_ident_[entry->first])})); + // } + // // set param 1 + // // set param 2 + // // etc etc + // fb_->Append(codegen_->Assign( + // codegen_->AccessStructMember(codegen_->MakeExpr(query_state), codegen_->MakeIdentifier("execCtx")), exec_ctx)); + // // set its execution context to whatever exec context was passed in here + + // for (auto &sub_fn : fns) { + // // aux_decls_.push_back(c) + // if (sub_fn.find("Run") != std::string::npos) { + // fb_->Append(codegen_->Call(codegen_->GetAstContext()->GetIdentifier(sub_fn), + // {codegen_->AddressOf(query_state), codegen_->MakeExpr(lam_var)})); + // } else { + // fb_->Append(codegen_->Call(codegen_->GetAstContext()->GetIdentifier(sub_fn), + // {codegen_->AddressOf(query_state)})); + // } + // } + + // fb_->Append(codegen_->CallBuiltin(execution::ast::Builtin::FinishNewParams, {exec_ctx})); + + // return; } -void UDFCodegen::Visit(RetStmtAST *ast) { +void UDFCodegen::Visit(ast::udf::RetStmtAST *ast) { ast->expr->Accept(reinterpret_cast(this)); auto ret_expr = dst_; fb_->Append(codegen_->Return(ret_expr)); } -void UDFCodegen::Visit(SQLStmtAST *ast) { - needs_exec_ctx_ = true; - auto exec_ctx = fb_->GetParameterByPosition(0); - const auto query = common::ManagedPointer(ast->query); - - // TODO(Matt): I don't think the binder should need the database name. It's already bound in the ConnectionContext - binder::BindNodeVisitor visitor(common::ManagedPointer(accessor_), db_oid_); - // auto query_params = visitor.BindAndGetUDFParams(query, common::ManagedPointer(udf_ast_context_)); - auto query_params = ast->udf_params; - auto stats = optimizer::StatsStorage(); - - std::unique_ptr plan = trafficcop::TrafficCopUtil::Optimize( - accessor_->GetTxn(), common::ManagedPointer(accessor_), query, db_oid_, common::ManagedPointer(&stats), - std::make_unique(), 1000000); - // make lambda that just writes into this - - auto lam_var = codegen_->MakeFreshIdentifier("lamb"); - // NOISEPAGE_ASSERT(plan->GetOutputSchema()->GetColumns().size() == 1, "Can't support non scalars yet!"); - auto &cols = plan->GetOutputSchema()->GetColumns(); - // auto &col = cols[0]; - execution::util::RegionVector params(codegen_->GetAstContext()->GetRegion()); - std::vector assignees; - execution::util::RegionVector captures(codegen_->GetAstContext()->GetRegion()); - size_t i = 0; - params.push_back(codegen_->MakeField( - exec_ctx->As()->Name(), - codegen_->PointerType(codegen_->BuiltinType(execution::ast::BuiltinType::Kind::ExecutionContext)))); - for (auto &col : cols) { - execution::ast::Expr *capture_var = codegen_->MakeExpr(str_to_ident_.find(ast->var_name)->second); - type::TypeId udf_type; - udf_ast_context_->GetVariableType(ast->var_name, &udf_type); - if (udf_type == type::TypeId::INVALID) { - // record type - auto &struct_vars = udf_ast_context_->GetRecordType(ast->var_name); - if (captures.empty()) { - captures.push_back(capture_var); - } - capture_var = codegen_->AccessStructMember(capture_var, codegen_->MakeIdentifier(struct_vars[i].first)); - assignees.push_back(capture_var); - } else { - assignees.push_back(capture_var); - captures.push_back(capture_var); - } - // auto capture_var = str_to_ident_.find(ast->var_name)->second; - auto type = codegen_->TplType(execution::sql::GetTypeId(col.GetType())); - - auto input_param = codegen_->MakeFreshIdentifier("input"); - params.push_back(codegen_->MakeField(input_param, type)); - i++; - } - - execution::ast::LambdaExpr *lambda_expr; - FunctionBuilder fn(codegen_, std::move(params), codegen_->BuiltinType(execution::ast::BuiltinType::Nil)); - { - for (size_t j = 0; j < assignees.size(); j++) { - auto capture_var = assignees[j]; - auto input_param = fn.GetParameterByPosition(j + 1); - fn.Append(codegen_->Assign(capture_var, input_param)); - } - } - - lambda_expr = fn.FinishLambda(std::move(captures)); - lambda_expr->SetName(lam_var); - - // want to pass something down that will materialize the lambda function for me into lambda_expr and will - // also feed in a lambda_expr to the compiler - execution::exec::ExecutionSettings exec_settings{}; - const std::string dummy_query = ""; - auto exec_query = execution::compiler::CompilationContext::Compile( - *plan, exec_settings, accessor_, execution::compiler::CompilationMode::OneShot, - common::ManagedPointer(&dummy_query), lambda_expr, codegen_->GetAstContext()); - auto fns = exec_query->GetFunctions(); - auto decls = exec_query->GetDecls(); - - aux_decls_.insert(aux_decls_.end(), decls.begin(), decls.end()); - - fb_->Append( - codegen_->DeclareVar(lam_var, codegen_->LambdaType(lambda_expr->GetFunctionLitExpr()->TypeRepr()), lambda_expr)); - - // make query state - auto query_state = codegen_->MakeFreshIdentifier("query_state"); - fb_->Append(codegen_->DeclareVarNoInit(query_state, codegen_->MakeExpr(exec_query->GetQueryStateType()->Name()))); - // set its execution context to whatever exec context was passed in here - fb_->Append(codegen_->CallBuiltin(execution::ast::Builtin::StartNewParams, {exec_ctx})); - std::vector>::iterator> sorted_vec; - for (auto it = query_params.begin(); it != query_params.end(); it++) { - sorted_vec.push_back(it); - } - - std::sort(sorted_vec.begin(), sorted_vec.end(), [](auto x, auto y) { return x->second.second < y->second.second; }); - for (auto entry : sorted_vec) { - // TODO(order these dudes) - type::TypeId type = type::TypeId::INVALID; - execution::ast::Expr *expr = nullptr; - if (entry->second.first.length() > 0) { - auto &fields = udf_ast_context_->GetRecordType(entry->second.first); - auto it = std::find_if(fields.begin(), fields.end(), [=](auto p) { return p.first == entry->first; }); - type = it->second; - expr = codegen_->AccessStructMember(codegen_->MakeExpr(str_to_ident_[entry->second.first]), - codegen_->MakeIdentifier(entry->first)); - } else { - udf_ast_context_->GetVariableType(entry->first, &type); - expr = codegen_->MakeExpr(str_to_ident_[entry->first]); - } - - // NOISEPAGE_ASSERT(ret, "didn't find param in udf ast context"); - execution::ast::Builtin builtin; - switch (type) { - case type::TypeId::BOOLEAN: - builtin = execution::ast::Builtin::AddParamBool; - break; - case type::TypeId::TINYINT: - builtin = execution::ast::Builtin::AddParamTinyInt; - break; - case type::TypeId::SMALLINT: - builtin = execution::ast::Builtin::AddParamSmallInt; - break; - case type::TypeId::INTEGER: - builtin = execution::ast::Builtin::AddParamInt; - break; - case type::TypeId::BIGINT: - builtin = execution::ast::Builtin::AddParamBigInt; - break; - case type::TypeId::DECIMAL: - builtin = execution::ast::Builtin::AddParamDouble; - break; - case type::TypeId::DATE: - builtin = execution::ast::Builtin::AddParamDate; - break; - case type::TypeId::TIMESTAMP: - builtin = execution::ast::Builtin::AddParamTimestamp; - break; - case type::TypeId::VARCHAR: - builtin = execution::ast::Builtin::AddParamString; - break; - default: - UNREACHABLE("Unsupported parameter type"); - } - fb_->Append(codegen_->CallBuiltin(builtin, {exec_ctx, expr})); - } - // set param 1 - // set param 2 - // etc etc - fb_->Append(codegen_->Assign( - codegen_->AccessStructMember(codegen_->MakeExpr(query_state), codegen_->MakeIdentifier("execCtx")), exec_ctx)); - - for (auto &col : cols) { - execution::ast::Expr *capture_var = codegen_->MakeExpr(str_to_ident_.find(ast->var_name)->second); - auto lhs = capture_var; - if (cols.size() > 1) { - // record struct type - lhs = codegen_->AccessStructMember(capture_var, codegen_->MakeIdentifier(col.GetName())); - } - fb_->Append(codegen_->Assign(lhs, codegen_->ConstNull(col.GetType()))); - } - // set its execution context to whatever exec context was passed in here - - for (auto &sub_fn : fns) { - // aux_decls_.push_back(c) - if (sub_fn.find("Run") != std::string::npos) { - fb_->Append(codegen_->Call(codegen_->GetAstContext()->GetIdentifier(sub_fn), - {codegen_->AddressOf(query_state), codegen_->MakeExpr(lam_var)})); - } else { - fb_->Append(codegen_->Call(codegen_->GetAstContext()->GetIdentifier(sub_fn), {codegen_->AddressOf(query_state)})); - } - } - - fb_->Append(codegen_->CallBuiltin(execution::ast::Builtin::FinishNewParams, {exec_ctx})); - - return; +// TODO(Kyle): Implement +void UDFCodegen::Visit(ast::udf::SQLStmtAST *ast) { + throw NOT_IMPLEMENTED_EXCEPTION("Visit(SQLStmtAST*) Not Implemented"); + // needs_exec_ctx_ = true; + // auto exec_ctx = fb_->GetParameterByPosition(0); + // const auto query = common::ManagedPointer(ast->query); + + // // TODO(Matt): I don't think the binder should need the database name. It's already bound in the ConnectionContext + // binder::BindNodeVisitor visitor(common::ManagedPointer(accessor_), db_oid_); + + // TODO(Kyle): Implement + // // auto query_params = visitor.BindAndGetUDFParams(query, common::ManagedPointer(udf_ast_context_)); + // auto query_params = ast->udf_params; + // auto stats = optimizer::StatsStorage(); + + // std::unique_ptr plan = trafficcop::TrafficCopUtil::Optimize( + // accessor_->GetTxn(), common::ManagedPointer(accessor_), query, db_oid_, common::ManagedPointer(&stats), + // std::make_unique(), 1000000); + // // make lambda that just writes into this + + // auto lam_var = codegen_->MakeFreshIdentifier("lamb"); + // // NOISEPAGE_ASSERT(plan->GetOutputSchema()->GetColumns().size() == 1, "Can't support non scalars yet!"); + // auto &cols = plan->GetOutputSchema()->GetColumns(); + // // auto &col = cols[0]; + // execution::util::RegionVector params(codegen_->GetAstContext()->GetRegion()); + // std::vector assignees; + // execution::util::RegionVector captures(codegen_->GetAstContext()->GetRegion()); + // size_t i = 0; + // params.push_back(codegen_->MakeField( + // exec_ctx->As()->Name(), + // codegen_->PointerType(codegen_->BuiltinType(execution::ast::BuiltinType::Kind::ExecutionContext)))); + // for (auto &col : cols) { + // execution::ast::Expr *capture_var = codegen_->MakeExpr(str_to_ident_.find(ast->var_name)->second); + // type::TypeId udf_type; + // udf_ast_context_->GetVariableType(ast->var_name, &udf_type); + // if (udf_type == type::TypeId::INVALID) { + // // record type + // auto &struct_vars = udf_ast_context_->GetRecordType(ast->var_name); + // if (captures.empty()) { + // captures.push_back(capture_var); + // } + // capture_var = codegen_->AccessStructMember(capture_var, codegen_->MakeIdentifier(struct_vars[i].first)); + // assignees.push_back(capture_var); + // } else { + // assignees.push_back(capture_var); + // captures.push_back(capture_var); + // } + // // auto capture_var = str_to_ident_.find(ast->var_name)->second; + // auto type = codegen_->TplType(execution::sql::GetTypeId(col.GetType())); + + // auto input_param = codegen_->MakeFreshIdentifier("input"); + // params.push_back(codegen_->MakeField(input_param, type)); + // i++; + // } + + // execution::ast::LambdaExpr *lambda_expr; + // FunctionBuilder fn(codegen_, std::move(params), codegen_->BuiltinType(execution::ast::BuiltinType::Nil)); + // { + // for (size_t j = 0; j < assignees.size(); j++) { + // auto capture_var = assignees[j]; + // auto input_param = fn.GetParameterByPosition(j + 1); + // fn.Append(codegen_->Assign(capture_var, input_param)); + // } + // } + + // lambda_expr = fn.FinishLambda(std::move(captures)); + // lambda_expr->SetName(lam_var); + + // // want to pass something down that will materialize the lambda function for me into lambda_expr and will + // // also feed in a lambda_expr to the compiler + // execution::exec::ExecutionSettings exec_settings{}; + // const std::string dummy_query = ""; + // auto exec_query = execution::compiler::CompilationContext::Compile( + // *plan, exec_settings, accessor_, execution::compiler::CompilationMode::OneShot, + // common::ManagedPointer(&dummy_query), lambda_expr, codegen_->GetAstContext()); + // auto fns = exec_query->GetFunctions(); + // auto decls = exec_query->GetDecls(); + + // aux_decls_.insert(aux_decls_.end(), decls.begin(), decls.end()); + + // fb_->Append( + // codegen_->DeclareVar(lam_var, codegen_->LambdaType(lambda_expr->GetFunctionLitExpr()->TypeRepr()), + // lambda_expr)); + + // // make query state + // auto query_state = codegen_->MakeFreshIdentifier("query_state"); + // fb_->Append(codegen_->DeclareVarNoInit(query_state, codegen_->MakeExpr(exec_query->GetQueryStateType()->Name()))); + // // set its execution context to whatever exec context was passed in here + // fb_->Append(codegen_->CallBuiltin(execution::ast::Builtin::StartNewParams, {exec_ctx})); + // std::vector>::iterator> sorted_vec; + // for (auto it = query_params.begin(); it != query_params.end(); it++) { + // sorted_vec.push_back(it); + // } + + // std::sort(sorted_vec.begin(), sorted_vec.end(), [](auto x, auto y) { return x->second.second < y->second.second; + // }); for (auto entry : sorted_vec) { + // // TODO(order these dudes) + // type::TypeId type = type::TypeId::INVALID; + // execution::ast::Expr *expr = nullptr; + // if (entry->second.first.length() > 0) { + // auto &fields = udf_ast_context_->GetRecordType(entry->second.first); + // auto it = std::find_if(fields.begin(), fields.end(), [=](auto p) { return p.first == entry->first; }); + // type = it->second; + // expr = codegen_->AccessStructMember(codegen_->MakeExpr(str_to_ident_[entry->second.first]), + // codegen_->MakeIdentifier(entry->first)); + // } else { + // udf_ast_context_->GetVariableType(entry->first, &type); + // expr = codegen_->MakeExpr(str_to_ident_[entry->first]); + // } + + // // NOISEPAGE_ASSERT(ret, "didn't find param in udf ast context"); + // execution::ast::Builtin builtin; + // switch (type) { + // case type::TypeId::BOOLEAN: + // builtin = execution::ast::Builtin::AddParamBool; + // break; + // case type::TypeId::TINYINT: + // builtin = execution::ast::Builtin::AddParamTinyInt; + // break; + // case type::TypeId::SMALLINT: + // builtin = execution::ast::Builtin::AddParamSmallInt; + // break; + // case type::TypeId::INTEGER: + // builtin = execution::ast::Builtin::AddParamInt; + // break; + // case type::TypeId::BIGINT: + // builtin = execution::ast::Builtin::AddParamBigInt; + // break; + // case type::TypeId::DECIMAL: + // builtin = execution::ast::Builtin::AddParamDouble; + // break; + // case type::TypeId::DATE: + // builtin = execution::ast::Builtin::AddParamDate; + // break; + // case type::TypeId::TIMESTAMP: + // builtin = execution::ast::Builtin::AddParamTimestamp; + // break; + // case type::TypeId::VARCHAR: + // builtin = execution::ast::Builtin::AddParamString; + // break; + // default: + // UNREACHABLE("Unsupported parameter type"); + // } + // fb_->Append(codegen_->CallBuiltin(builtin, {exec_ctx, expr})); + // } + // // set param 1 + // // set param 2 + // // etc etc + // fb_->Append(codegen_->Assign( + // codegen_->AccessStructMember(codegen_->MakeExpr(query_state), codegen_->MakeIdentifier("execCtx")), exec_ctx)); + + // for (auto &col : cols) { + // execution::ast::Expr *capture_var = codegen_->MakeExpr(str_to_ident_.find(ast->var_name)->second); + // auto lhs = capture_var; + // if (cols.size() > 1) { + // // record struct type + // lhs = codegen_->AccessStructMember(capture_var, codegen_->MakeIdentifier(col.GetName())); + // } + // fb_->Append(codegen_->Assign(lhs, codegen_->ConstNull(col.GetType()))); + // } + // // set its execution context to whatever exec context was passed in here + + // for (auto &sub_fn : fns) { + // // aux_decls_.push_back(c) + // if (sub_fn.find("Run") != std::string::npos) { + // fb_->Append(codegen_->Call(codegen_->GetAstContext()->GetIdentifier(sub_fn), + // {codegen_->AddressOf(query_state), codegen_->MakeExpr(lam_var)})); + // } else { + // fb_->Append(codegen_->Call(codegen_->GetAstContext()->GetIdentifier(sub_fn), + // {codegen_->AddressOf(query_state)})); + // } + // } + + // fb_->Append(codegen_->CallBuiltin(execution::ast::Builtin::FinishNewParams, {exec_ctx})); + + // return; } -void UDFCodegen::Visit(MemberExprAST *ast) { +void UDFCodegen::Visit(ast::udf::MemberExprAST *ast) { ast->object->Accept(reinterpret_cast(this)); auto object = dst_; dst_ = codegen_->AccessStructMember(object, codegen_->MakeIdentifier(ast->field)); diff --git a/src/execution/sema/sema_expr.cpp b/src/execution/sema/sema_expr.cpp index bb48f23bf2..332bc8f381 100644 --- a/src/execution/sema/sema_expr.cpp +++ b/src/execution/sema/sema_expr.cpp @@ -141,6 +141,102 @@ void Sema::VisitCallExpr(ast::CallExpr *node) { node->SetType(func_type->GetReturnType()); } +void Sema::VisitLambdaExpr(ast::LambdaExpr *node) { + // make struct type + // node->SetType(Resolve(node->GetFunctionLitExpr()->TypeRepr())); + // const auto &locals = GetCurrentScope()->GetLocals(); + auto factory = GetContext()->GetNodeFactory(); + util::RegionVector fields(GetContext()->GetRegion()); + // std::unordered_set used_idents; + // TODO support more than just assignment statements + // for(auto s : node->GetFunctionLitExpr()->Body()->Statements()){ + // if(s->IsAssignmentStmt()) { + // auto expr = s->As()->Destination()->As(); + // used_idents.insert(expr->Name()); + // auto s_expr = s->As()->Source()->SafeAs(); + // if(s_expr != nullptr){ + // used_idents.insert(s_expr->Name()); + // } + // } + // } + // for(auto local : used_idents){ + // auto name = local; + // auto iter = std::find_if(locals.begin(), locals.end(), [=](auto p){ return p.first == name; }); + // if(iter == locals.end()){ + // continue; + // } + // auto type = iter->second; + // ast::Expr *type_repr = nullptr; + // if(type->IsBuiltinType()) { + // type_repr = factory->NewPointerType(SourcePosition(), + // factory->NewIdentifierExpr(SourcePosition(), + // GetContext()->GetIdentifier(ast::BuiltinType::Get(GetContext(), + // type->As()->GetKind()) + // ->GetTplName()))); + // }else{ + // if(type->IsLambdaType()){ + // continue; + // } + // NOISEPAGE_ASSERT(false, "UNSUPPORTED CAPTURED TYPE"); + // } + // type_repr->SetType(type->PointerTo()); + // ast::FieldDecl *field = factory->NewFieldDecl(SourcePosition(), name, type_repr); + // fields.push_back(field); + // } + for (auto expr : node->GetCaptureIdents()) { + auto ident = expr->As(); + Resolve(ident); + if (ident->GetType()->SafeAs()) { + auto type_repr = factory->NewPointerType( + SourcePosition(), + factory->NewIdentifierExpr( + SourcePosition(), + GetContext()->GetIdentifier( + ast::BuiltinType::Get(GetContext(), ident->GetType()->As()->GetKind()) + ->GetTplName()))); + fields.push_back(factory->NewFieldDecl(SourcePosition(), ident->Name(), type_repr)); + } else { + util::RegionVector fields2(GetContext()->GetRegion()); + for (auto field : ident->GetType()->SafeAs()->GetFieldsWithoutPadding()) { + fields2.push_back(factory->NewFieldDecl( + SourcePosition(), field.name_, + factory->NewIdentifierExpr( + SourcePosition(), + GetContext()->GetIdentifier( + ast::BuiltinType::Get(GetContext(), field.type_->As()->GetKind()) + ->GetTplName())))); + } + + auto type_repr = + factory->NewPointerType(SourcePosition(), factory->NewStructType(SourcePosition(), std::move(fields2))); + fields.push_back(factory->NewFieldDecl(SourcePosition(), ident->Name(), type_repr)); + } + } + fields.push_back( + factory->NewFieldDecl(SourcePosition(), GetContext()->GetIdentifier("function"), + factory->NewPointerType(SourcePosition(), node->GetFunctionLitExpr()->TypeRepr()))); + + ast::StructTypeRepr *struct_type_repr = factory->NewStructType(SourcePosition(), std::move(fields)); + // TODO(tanujnay112) Find a better name + ast::StructDecl *struct_decl = factory->NewStructDecl( + SourcePosition(), GetContext()->GetIdentifier("lambda" + std::to_string(node->Position().line_)), + struct_type_repr); + VisitStructDecl(struct_decl); + node->capture_type_ = Resolve(struct_type_repr); + node->SetType(ast::LambdaType::Get(Resolve(node->GetFunctionLitExpr()->TypeRepr())->As())); + // GetCurrentScope()->Declare(struct_decl->Name(), node->capture_type_); + + // TODO(Kyle): Why do we need to modify internals? + // auto type = Resolve(node->GetFunctionLitExpr()->TypeRepr()); + // auto fn_type = type->As(); + // fn_type->GetParams().emplace_back(GetContext()->GetIdentifier("captures"), + // GetBuiltinType(ast::BuiltinType::Kind::Int32)->PointerTo()); + // fn_type->is_lambda_ = true; + // fn_type->captures_ = node->GetCaptureStructType()->As(); + + VisitFunctionLitExpr(node->GetFunctionLitExpr()); +} + void Sema::VisitFunctionLitExpr(ast::FunctionLitExpr *node) { // Resolve the type, if not resolved already if (auto *type = node->TypeRepr()->GetType(); type == nullptr) { diff --git a/src/execution/sema/sema_stmt.cpp b/src/execution/sema/sema_stmt.cpp index e0972962e9..9b004a064f 100644 --- a/src/execution/sema/sema_stmt.cpp +++ b/src/execution/sema/sema_stmt.cpp @@ -70,6 +70,23 @@ void Sema::VisitForStmt(ast::ForStmt *node) { Visit(node->Body()); } +// TODO(Kyle): Implement. +void Sema::VisitBreakStmt(ast::BreakStmt *node) { + // look for a loop in my scope stack + // auto scope = GetCurrentScope(); + // bool found_loop = false; + // while(scope != nullptr){ + // found_loop |= scope->GetKind() == Scope::Kind::Loop; + // if(found_loop){ + // break; + // } + // scope = scope->Outer(); + // } + // if(!found_loop){ + // error_reporter_->Report(node->Position(), ErrorMessages::kNoScopeToBreak); + // } +} + void Sema::VisitForInStmt(ast::ForInStmt *node) { NOISEPAGE_ASSERT(false, "Not supported"); } void Sema::VisitExpressionStmt(ast::ExpressionStmt *node) { Visit(node->Expression()); } diff --git a/src/execution/sema/sema_type.cpp b/src/execution/sema/sema_type.cpp index fb5087c715..9e744d2439 100644 --- a/src/execution/sema/sema_type.cpp +++ b/src/execution/sema/sema_type.cpp @@ -51,7 +51,8 @@ void Sema::VisitFunctionTypeRepr(ast::FunctionTypeRepr *node) { } // Create type - ast::FunctionType *func_type = ast::FunctionType::Get(std::move(param_types), ret); + // TODO(Kyle): this is a bad API + ast::FunctionType *func_type = ast::FunctionType::Get(std::move(param_types), ret, false); node->SetType(func_type); } @@ -90,4 +91,15 @@ void Sema::VisitMapTypeRepr(ast::MapTypeRepr *node) { node->SetType(ast::MapType::Get(key_type, value_type)); } +void Sema::VisitLambdaTypeRepr(ast::LambdaTypeRepr *node) { + ast::FunctionType *fn_type = Resolve(node->FunctionType())->SafeAs(); + if (fn_type == nullptr) { + return; + } + + fn_type->GetParams().emplace_back(GetContext()->GetIdentifier("captures"), + GetBuiltinType(ast::BuiltinType::Kind::Int32)->PointerTo()); + node->SetType(ast::LambdaType::Get(fn_type)); +} + } // namespace noisepage::execution::sema diff --git a/src/execution/vm/bytecode_generator.cpp b/src/execution/vm/bytecode_generator.cpp index 06019228be..44f4a189a3 100644 --- a/src/execution/vm/bytecode_generator.cpp +++ b/src/execution/vm/bytecode_generator.cpp @@ -186,6 +186,14 @@ void BytecodeGenerator::VisitForStmt(ast::ForStmt *node) { loop_builder.JumpToHeader(); } +void BytecodeGenerator::VisitBreakStmt(ast::BreakStmt *node) { + // TODO(Kyle): Implement. + throw NOT_IMPLEMENTED_EXCEPTION("VisitBreakStmt Not Implemented"); + // if(current_loop_ != nullptr && current_loop_->GetPrev() != nullptr) { + // current_loop_->GetPrev()->Break(); + // } +} + void BytecodeGenerator::VisitForInStmt(UNUSED_ATTRIBUTE ast::ForInStmt *node) { NOISEPAGE_ASSERT(false, "For-in statements not supported"); } @@ -209,6 +217,72 @@ void BytecodeGenerator::VisitFunctionDecl(ast::FunctionDecl *node) { } } +void BytecodeGenerator::VisitLambdaExpr(ast::LambdaExpr *node) { + // TODO(Kyle): Implement. + throw NOT_IMPLEMENTED_EXCEPTION("VisitLambdaExpr Not Implemented"); + // // The function's TPL type + // auto *func_type = node->GetFunctionLitExpr()->GetType()->As(); + + // // Allocate the function + // // func_type->RegisterCapture(); + // if(!GetExecutionResult()->HasDestination()){ + // return; + // } + // auto captures = GetCurrentFunction()->NewLocal(node->GetCaptureStructType(), node->GetName().GetString() + + // "captures"); auto fields = node->GetCaptureStructType()->As()->GetFieldsWithoutPadding(); + // // auto &locals = GetCurrentFunction()->GetLocals(); + // for(size_t i = 0;i < fields.size() - 1;i++){ + // auto field = fields[i]; + // ast::IdentifierExpr ident(node->Position(), field.name_); + // ident.SetType(field.type_->GetPointeeType()); + // auto local = VisitExpressionForLValue(&ident); + // // auto local_it = std::find_if(locals.begin(), locals.end(), [=](const auto &loc){ return loc.GetName() == + // field.name_.GetString();}); + // // bool is_capture = false; + // // LocalVar local; + // // if(local_it == locals.end()){ + // // // should be inside captures + // // NOISEPAGE_ASSERT(GetCurrentFunction()->IsLambda(), "not lambda and local to capture not found"); + // // is_capture = true; + // // auto caller_captures = GetCurrentFunction()->GetFuncType()->GetCapturesType()->GetFieldsWithoutPadding(); + // // + // // auto cap_it = std::find_if(caller_captures.begin(), caller_captures.end(), + // // [=](const auto &loc){ return loc.GetName() == field.name_.GetString();}); + // // NOISEPAGE_ASSERT(cap_it != caller_captures.end(), "local to capture straight up not found"); + // // GetEmitter()-> + // // } + // LocalVar fieldvar = GetCurrentFunction()->NewLocal( + // fields[i].type_->PointerTo(), ""); + // GetEmitter()->EmitLea(fieldvar, captures.AddressOf(), + // node->GetCaptureStructType() + // ->As()->GetOffsetOfFieldByName(fields[i].name_)); + // GetEmitter()->EmitAssign(Bytecode::Assign8, fieldvar.ValueOf(), local); + // } + + // GetEmitter()->EmitAssign(Bytecode::Assign8, GetExecutionResult()->GetDestination(), captures.AddressOf()); + // FunctionInfo *func_info = AllocateFunc(node->GetName().GetString(), func_type); + // GetCurrentFunction()->DeferAction([=](){ + // func_info->captures_ = captures; + // func_info->is_lambda_ = true; + // { + // // Visit the body of the function. We use this handy scope object to track + // // the start and end position of this function's bytecode in the module's + // // bytecode array. Upon destruction, the scoped class will set the bytecode + // // range in the function. + // EnterFunction(func_info->GetId()); + // BytecodePositionScope position_scope(this, func_info); + // Visit(node->GetFunctionLitExpr()->Body()); + // } + // for(auto f : func_info->actions_){ + // f(); + // } + // }); +} + +void BytecodeGenerator::VisitLambdaTypeRepr(ast::LambdaTypeRepr *node) { + UNREACHABLE("Should not visit type-representation nodes!"); +} + void BytecodeGenerator::VisitIdentifierExpr(ast::IdentifierExpr *node) { // Lookup the local in the current function. It must be there through a // previous variable declaration (or parameter declaration). What is returned diff --git a/src/execution/vm/llvm_engine.cpp b/src/execution/vm/llvm_engine.cpp index 5bf74ec236..e123d85d13 100644 --- a/src/execution/vm/llvm_engine.cpp +++ b/src/execution/vm/llvm_engine.cpp @@ -28,6 +28,7 @@ #include #include +#include "common/error/exception.h" #include "execution/ast/type.h" #include "execution/vm/bytecode_module.h" #include "execution/vm/bytecode_traits.h" @@ -182,6 +183,10 @@ llvm::Type *LLVMEngine::TypeMap::GetLLVMType(const ast::Type *type) { llvm_type = llvm::PointerType::getUnqual(GetLLVMType(ptr_type->GetBase())); break; } + case ast::Type::TypeId::ReferenceType: { + throw NOT_IMPLEMENTED_EXCEPTION("ReferenceType Not Implemented"); + break; + } case ast::Type::TypeId::ArrayType: { auto *arr_type = type->As(); llvm::Type *elem_type = GetLLVMType(arr_type->GetElementType()); @@ -204,6 +209,14 @@ llvm::Type *LLVMEngine::TypeMap::GetLLVMType(const ast::Type *type) { llvm_type = GetLLVMFunctionType(type->As()); break; } + case ast::Type::TypeId::LambdaType: { + throw NOT_IMPLEMENTED_EXCEPTION("LambdaType Not Implemented"); + break; + } + default: { + UNREACHABLE("Unknown Type"); + break; + } } // diff --git a/src/include/catalog/catalog_accessor.h b/src/include/catalog/catalog_accessor.h index d3cbc4b8dc..4c7927e7b5 100644 --- a/src/include/catalog/catalog_accessor.h +++ b/src/include/catalog/catalog_accessor.h @@ -340,6 +340,11 @@ class EXPORT CatalogAccessor { */ proc_oid_t GetProcOid(const std::string &procname, const std::vector &all_arg_types); + /** + * TODO(Kyle): Document. + */ + common::ManagedPointer GetProcCtxPtr(proc_oid_t proc_oid); + /** * Sets the proc context pointer column of proc_oid to func_context * @param proc_oid The proc_oid whose pointer column we are setting here diff --git a/src/include/catalog/database_catalog.h b/src/include/catalog/database_catalog.h index 5c51cc7714..4c137facb6 100644 --- a/src/include/catalog/database_catalog.h +++ b/src/include/catalog/database_catalog.h @@ -169,6 +169,9 @@ class DatabaseCatalog { /** @brief Get the OID of the specified procedure. @see PgProcImpl::GetProcOid */ proc_oid_t GetProcOid(common::ManagedPointer txn, namespace_oid_t procns, const std::string &procname, const std::vector &all_arg_types); + /** @brief Get the procedure context pointer column of the specified procedure */ + common::ManagedPointer GetProcCtxPtr( + common::ManagedPointer txn, proc_oid_t proc_oid); /** @brief Set the procedure context for the specified procedure. @see PgProcImpl::SetFunctionContextPointer */ bool SetFunctionContextPointer(common::ManagedPointer txn, proc_oid_t proc_oid, const execution::functions::FunctionContext *func_context); diff --git a/src/include/execution/ast/ast.h b/src/include/execution/ast/ast.h index 6c45d7d0a9..6a0ded3e21 100644 --- a/src/include/execution/ast/ast.h +++ b/src/include/execution/ast/ast.h @@ -48,6 +48,7 @@ namespace ast { T(DeclStmt) \ T(ExpressionStmt) \ T(ForStmt) \ + T(BreakStmt) \ T(ForInStmt) \ T(IfStmt) \ T(ReturnStmt) @@ -66,12 +67,14 @@ namespace ast { T(IdentifierExpr) \ T(ImplicitCastExpr) \ T(IndexExpr) \ + T(LambdaExpr) \ T(LitExpr) \ T(MemberExpr) \ T(UnaryOpExpr) \ /* Type Representation Expressions */ \ T(ArrayTypeRepr) \ T(FunctionTypeRepr) \ + T(LambdaTypeRepr) \ T(MapTypeRepr) \ T(PointerTypeRepr) \ T(StructTypeRepr) @@ -351,14 +354,20 @@ class FunctionDecl : public Decl { * @param pos source position * @param name identifier * @param func function literal (param types, return type, body) + * @param is_lambda `true` if this function is constructed from a lambda expresison */ - FunctionDecl(const SourcePosition &pos, Identifier name, FunctionLitExpr *func); + FunctionDecl(const SourcePosition &pos, Identifier name, FunctionLitExpr *func, bool is_lambda = false); /** * @return The function literal defining the body of the function declaration. */ FunctionLitExpr *Function() const { return func_; } + /** + * @return `true` if this function is a lambda, `false` otherwise. + */ + bool IsLambda() const noexcept { return is_lambda_; } + /** * Is the given node a function declaration? Needed as part of the custom AST RTTI infrastructure. * @param node The node to check. @@ -371,6 +380,8 @@ class FunctionDecl : public Decl { private: // The function definition (signature and body). FunctionLitExpr *func_; + // Is this function generated by a lambda expression. + const bool is_lambda_; }; /** @@ -691,6 +702,26 @@ class IterationStmt : public Stmt { BlockStmt *body_; }; +/** + * A break statement. + */ +class BreakStmt : public Stmt { + public: + /** + * Constructor + * @param pos source position + */ + BreakStmt(const SourcePosition &pos) : Stmt(Kind::BreakStmt, pos) {} + + /** + * Is the given node a return statement? + * Needed as part of the custom AST RTTI infrastructure. + * @param node The node to check. + * @return `true` if the node is a return statement, `false` otherwise. + */ + static bool classof(const AstNode *node) { return node->GetKind() == Kind::BreakStmt; } +}; + /** * A vanilla for-statement. */ @@ -731,6 +762,14 @@ class ForStmt : public IterationStmt { return node->GetKind() == Kind::ForStmt; } + private: + friend class sema::Sema; + + /** + * TODO(Kyle): Why? + */ + void SetCond(Expr *cond) { cond_ = cond; } + private: Stmt *init_; Expr *cond_; @@ -1039,6 +1078,37 @@ class BinaryOpExpr : public Expr { Expr *right_; }; +/** + * A lambda expression. + * TODO(Kyle): Document. + */ +class LambdaExpr : public Expr { + public: + LambdaExpr(const SourcePosition &pos, FunctionLitExpr *func, util::RegionVector &&captures) + : Expr(Kind::LambdaExpr, pos), captures_{nullptr}, func_lit_(func), capture_idents_{std::move(captures)} {} + + FunctionLitExpr *GetFunctionLitExpr() const { return func_lit_; } + + ast::StructTypeRepr *GetCaptureStruct() const { return captures_; } + + ast::Type *GetCaptureStructType() const { return capture_type_; } + + const Identifier &GetName() const { return name_; } + + const util::RegionVector &GetCaptureIdents() const { return capture_idents_; } + + void SetName(Identifier name) { name_ = name; } + + private: + friend class sema::Sema; + + Identifier name_; + ast::StructTypeRepr *captures_; + ast::Type *capture_type_; + FunctionLitExpr *func_lit_; + util::RegionVector capture_idents_; +}; + /** * A function call expression. */ @@ -1047,7 +1117,7 @@ class CallExpr : public Expr { /** * Type of call (builtin call or regular function call) */ - enum class CallKind : uint8_t { Regular, Builtin }; + enum class CallKind : uint8_t { Regular, Builtin, Lambda }; /** * Constructor for regular calls @@ -1085,6 +1155,11 @@ class CallExpr : public Expr { */ uint32_t NumArgs() const { return static_cast(args_.size()); } + /** + * TODO(Kyle): Document. + */ + void PushArgument(Expr *expr) { args_.push_back(expr); } + /** * @return The kind of call, either regular or a call to a builtin function. */ @@ -1201,8 +1276,9 @@ class FunctionLitExpr : public Expr { * Constructor * @param type_repr type representation (param types, return type) * @param body body of the function + * @param is_lambda `true` if the literal is a lambda, `false` otherwise */ - FunctionLitExpr(FunctionTypeRepr *type_repr, BlockStmt *body); + FunctionLitExpr(FunctionTypeRepr *type_repr, BlockStmt *body, bool is_lambda = false); /** * @return The function's signature. @@ -1215,10 +1291,15 @@ class FunctionLitExpr : public Expr { BlockStmt *Body() const { return body_; } /** - * @return True if the function has no statements; false otherwise. + * @return `true` if the function has no statements; `false` otherwise. */ bool IsEmpty() const { return Body()->IsEmpty(); } + /** + * @return `true` if the function is a lambda, `false` otherwise. + */ + bool IsLambda() const { return is_lambda_; } + /** * Is the given node a function literal? Needed as part of the custom AST RTTI infrastructure. * @param node The node to check. @@ -1233,6 +1314,8 @@ class FunctionLitExpr : public Expr { FunctionTypeRepr *type_repr_; // The body of the function. BlockStmt *body_; + // Is this function literal a lambda. + const bool is_lambda_; }; /** @@ -1797,6 +1880,24 @@ class MapTypeRepr : public Expr { Expr *val_; }; +/** + * Lambda type. + * TODO(Kyle): Document. + */ +class LambdaTypeRepr : public Expr { + public: + LambdaTypeRepr(const SourcePosition &pos, Expr *fn_type) : Expr(Kind::LambdaTypeRepr, pos), fn_type_(fn_type) {} + + Expr *FunctionType() const { return fn_type_; } + + static bool classof(const AstNode *node) { // NOLINT + return node->GetKind() == Kind::LambdaTypeRepr; + } + + private: + Expr *fn_type_; +}; + /** * Pointer type. */ diff --git a/src/include/execution/ast/ast_clone.h b/src/include/execution/ast/ast_clone.h new file mode 100644 index 0000000000..2d8b6396f0 --- /dev/null +++ b/src/include/execution/ast/ast_clone.h @@ -0,0 +1,25 @@ +#pragma once + +#include + +#include "execution/ast/ast_node_factory.h" +#include "execution/ast/context.h" + +namespace noisepage::execution::ast { + +class AstNode; + +/** + * TODO(Kyle): Document. + */ +class AstClone { + public: + /** + * Clones an ASTNode and its descendants. + * TODO(Kyle): Document. + */ + static AstNode *Clone(AstNode *node, AstNodeFactory *factory, std::string prefix, Context *old_context, + Context *new_context); +}; + +} // namespace noisepage::execution::ast diff --git a/src/include/execution/ast/ast_fwd.h b/src/include/execution/ast/ast_fwd.h index 77ee5be674..e1d7091cb8 100644 --- a/src/include/execution/ast/ast_fwd.h +++ b/src/include/execution/ast/ast_fwd.h @@ -12,6 +12,7 @@ class Decl; class FieldDecl; class File; // NOLINT it picks up madoka's File class FunctionDecl; +class LambdaExpr; class Stmt; class StructDecl; class VariableDecl; diff --git a/src/include/execution/ast/ast_node_factory.h b/src/include/execution/ast/ast_node_factory.h index cae09e4690..8c10eb11eb 100644 --- a/src/include/execution/ast/ast_node_factory.h +++ b/src/include/execution/ast/ast_node_factory.h @@ -45,6 +45,17 @@ class AstNodeFactory { return new (region_) FunctionDecl(pos, name, fun); } + /** + * @param pos source position + * @param fun function literal (params, return type, body) + * @param captures lambda captures + * @return created LambdaExpr node. + */ + LambdaExpr *NewLambdaExpr(const SourcePosition &pos, FunctionLitExpr *fun, + util::RegionVector &&captures) { + return new (region_) LambdaExpr(pos, fun, std::move(captures)); + } + /** * @param pos source position * @param name struct name @@ -133,6 +144,12 @@ class AstNodeFactory { return new (region_) IfStmt(pos, cond, then_stmt, else_stmt); } + /** + * @param pos source position + * @return created BreakStmt node + */ + BreakStmt *NewBreakStmt(const SourcePosition &pos) { return new (region_) BreakStmt(pos); } + /** * @param pos source position * @param ret returned expression @@ -337,6 +354,15 @@ class AstNodeFactory { return new (region_) MapTypeRepr(pos, key_type, val_type); } + /** + * @param pos source position + * @param fn_type the function type + * @return created LambdaTypeRepr + */ + LambdaTypeRepr *NewLambdaType(const SourcePosition &pos, Expr *fn_type) { + return new (region_) LambdaTypeRepr(pos, fn_type); + } + private: util::Region *region_; }; diff --git a/src/include/execution/ast/builtins.h b/src/include/execution/ast/builtins.h index 2f70005e60..5a63b9c2ae 100644 --- a/src/include/execution/ast/builtins.h +++ b/src/include/execution/ast/builtins.h @@ -329,6 +329,18 @@ namespace noisepage::execution::ast { F(GetParamDate, getParamDate) \ F(GetParamTimestamp, getParamTimestamp) \ F(GetParamString, getParamString) \ + F(StartNewParams, startNewParams) \ + F(FinishNewParams, finishNewParams) \ + F(AddParamBool, addParamBool) \ + F(AddParamTinyInt, addParamTinyInt) \ + F(AddParamSmallInt, addParamSmallInt) \ + F(AddParamInt, addParamInt) \ + F(AddParamBigInt, addParamBigInt) \ + F(AddParamReal, addParamReal) \ + F(AddParamDouble, addParamDouble) \ + F(AddParamDate, addParamDate) \ + F(AddParamTimestamp, addParamTimestamp) \ + F(AddParamString, addParamString) \ \ /* String functions */ \ F(Lower, lower) \ diff --git a/src/include/execution/ast/type.h b/src/include/execution/ast/type.h index 41a6f24071..03c80735fa 100644 --- a/src/include/execution/ast/type.h +++ b/src/include/execution/ast/type.h @@ -20,7 +20,9 @@ class Context; F(BuiltinType) \ F(StringType) \ F(PointerType) \ + F(ReferenceType) \ F(ArrayType) \ + F(LambdaType) \ F(MapType) \ F(StructType) \ F(FunctionType) @@ -307,6 +309,11 @@ class Type : public util::RegionObject { */ PointerType *PointerTo(); + /** + * @return A new type that is a reference to the current type. + */ + ReferenceType *ReferenceTo(); + /** * @return If this is a pointer type, the type of the element pointed to. Returns null otherwise. */ @@ -500,6 +507,37 @@ class PointerType : public Type { Type *base_; }; +/** + * Reference type. + */ +class ReferenceType : public Type { + public: + /** + * @return base type + */ + Type *GetBase() const { return base_; } + + /** + * Static Constructor + * @param base type + * @return reference to base type + */ + static ReferenceType *Get(Type *base); + + /** + * @param type checked type + * @return whether type is a reference type. + */ + static bool classof(const Type *type) { return type->GetTypeId() == TypeId::ReferenceType; } // NOLINT + + private: + explicit ReferenceType(Type *base) + : Type(base->GetContext(), sizeof(int8_t *), alignof(int8_t *), TypeId::ReferenceType), base_(base) {} + + private: + Type *base_; +}; + /** * Array type. */ @@ -588,10 +626,15 @@ struct Field { class FunctionType : public Type { public: /** - * @return A constant reference to the list of parameters to a function. + * @return An immutable reference to the list of parameters to a function. */ const util::RegionVector &GetParams() const { return params_; } + /** + * @return A mutable reference to the list of parameters to a function. + */ + util::RegionVector &GetParams() { return params_; } + /** * @return The number of parameters to the function. */ @@ -602,13 +645,25 @@ class FunctionType : public Type { */ Type *GetReturnType() const { return ret_; } + bool IsEqual(const FunctionType *other); + + bool IsLambda() const { return is_lambda_; } + + ast::StructType *GetCapturesType() const { + NOISEPAGE_ASSERT(is_lambda_, "Getting capture type from not lambda"); + return captures_; + } + + void RegisterCapture(); + /** * Create a function with parameters @em params and returning types of type @em ret. * @param params The parameters to the function. * @param ret The type of the object the function returns. + * @param is_lambda `true` if this function is a lambda, `false` otherwise. * @return The function type. */ - static FunctionType *Get(util::RegionVector &¶ms, Type *ret); + static FunctionType *Get(util::RegionVector &¶ms, Type *ret, bool is_lambda); /** * @param type type to compare with @@ -617,11 +672,13 @@ class FunctionType : public Type { static bool classof(const Type *type) { return type->GetTypeId() == TypeId::FunctionType; } // NOLINT private: - explicit FunctionType(util::RegionVector &¶ms, Type *ret); + explicit FunctionType(util::RegionVector &¶ms, Type *ret, bool is_lambda); private: util::RegionVector params_; Type *ret_; + const bool is_lambda_; + ast::StructType *captures_{}; }; /** @@ -661,6 +718,25 @@ class MapType : public Type { Type *val_type_; }; +/** + * Lambda type. + * TODO(Kyle): Document. + */ +class LambdaType : public Type { + public: + FunctionType *GetFunctionType() const { return fn_type_; } + + static LambdaType *Get(FunctionType *fn_type); + + static bool classof(const Type *type) { return type->GetTypeId() == TypeId::LambdaType; } // NOLINT + + private: + LambdaType(FunctionType *fn_type); + + private: + FunctionType *fn_type_; +}; + /** * Struct type. */ diff --git a/src/include/execution/ast/udf/udf_ast_context.h b/src/include/execution/ast/udf/udf_ast_context.h index 6d9105b174..05714d43e2 100644 --- a/src/include/execution/ast/udf/udf_ast_context.h +++ b/src/include/execution/ast/udf/udf_ast_context.h @@ -26,11 +26,12 @@ class UDFASTContext { return true; } - void AddVariable(std::string name) { local_variables_.push_back(name); } + void AddVariable(const std::string &name) { local_variables_.push_back(name); } - const std::string &GetVariableAtIndex(int index) { + const std::string &GetVariableAtIndex(const std::size_t index) { NOISEPAGE_ASSERT(local_variables_.size() >= index, "Bad var"); - return local_variables_[index - 1]; + // TODO(Kyle): Why did this originally have index - 1? + return local_variables_.at(index); } void SetRecordType(std::string var, std::vector> &&elems) { diff --git a/src/include/execution/ast/udf/udf_ast_node_visitor.h b/src/include/execution/ast/udf/udf_ast_node_visitor.h index 9966d39c61..178a7ae65b 100644 --- a/src/include/execution/ast/udf/udf_ast_node_visitor.h +++ b/src/include/execution/ast/udf/udf_ast_node_visitor.h @@ -30,27 +30,27 @@ class FunctionAST; class ASTNodeVisitor { public: virtual ~ASTNodeVisitor(){}; - - virtual void Visit(AbstractAST *){}; - virtual void Visit(StmtAST *){}; - virtual void Visit(ExprAST *){}; - virtual void Visit(FunctionAST *){}; - virtual void Visit(ValueExprAST *){}; - virtual void Visit(VariableExprAST *){}; - virtual void Visit(BinaryExprAST *){}; - virtual void Visit(IsNullExprAST *){}; - virtual void Visit(CallExprAST *){}; - virtual void Visit(MemberExprAST *){}; - virtual void Visit(SeqStmtAST *){}; - virtual void Visit(DeclStmtAST *){}; - virtual void Visit(IfStmtAST *){}; - virtual void Visit(WhileStmtAST *){}; - virtual void Visit(RetStmtAST *){}; - virtual void Visit(AssignStmtAST *){}; - virtual void Visit(ForStmtAST *){}; - virtual void Visit(SQLStmtAST *){}; - virtual void Visit(DynamicSQLStmtAST *){}; + virtual void Visit(AbstractAST *ast) = 0; + virtual void Visit(StmtAST *ast) = 0; + virtual void Visit(ExprAST *ast) = 0; + virtual void Visit(FunctionAST *ast) = 0; + virtual void Visit(ValueExprAST *ast) = 0; + virtual void Visit(VariableExprAST *ast) = 0; + virtual void Visit(BinaryExprAST *ast) = 0; + virtual void Visit(IsNullExprAST *ast) = 0; + virtual void Visit(CallExprAST *ast) = 0; + virtual void Visit(MemberExprAST *ast) = 0; + virtual void Visit(SeqStmtAST *ast) = 0; + virtual void Visit(DeclStmtAST *ast) = 0; + virtual void Visit(IfStmtAST *ast) = 0; + virtual void Visit(WhileStmtAST *ast) = 0; + virtual void Visit(RetStmtAST *ast) = 0; + virtual void Visit(AssignStmtAST *ast) = 0; + virtual void Visit(ForStmtAST *ast) = 0; + virtual void Visit(SQLStmtAST *ast) = 0; + virtual void Visit(DynamicSQLStmtAST *ast) = 0; }; + } // namespace udf } // namespace ast } // namespace execution diff --git a/src/include/execution/compiler/udf/udf_codegen.h b/src/include/execution/compiler/udf/udf_codegen.h index 3ec7888a41..c0f10ee5f2 100644 --- a/src/include/execution/compiler/udf/udf_codegen.h +++ b/src/include/execution/compiler/udf/udf_codegen.h @@ -14,13 +14,10 @@ class CatalogAccessor; namespace noisepage { namespace execution { -namespace compiler { -namespace udf { - -// TODO(Kyle): Is distinguishing the standard codegen -// namespace stuff from the UDF stuff here going to be -// an issue (i.e. disambiguation)? +// Forward declarations +namespace ast { +namespace udf { class AbstractAST; class StmtAST; class ExprAST; @@ -36,56 +33,58 @@ class WhileStmtAST; class RetStmtAST; class AssignStmtAST; class SQLStmtAST; +class FunctionAST; +class IsNullExprAST; class DynamicSQLStmtAST; class ForStmtAST; +} // namespace udf +} // namespace ast -class UDFCodegen : ASTNodeVisitor { +namespace compiler { +namespace udf { + +// TODO(Kyle): Is distinguishing the standard codegen +// namespace stuff from the UDF stuff here going to be +// an issue (i.e. disambiguation)? + +class UDFCodegen : ast::udf::ASTNodeVisitor { public: - UDFCodegen(catalog::CatalogAccessor *accessor, FunctionBuilder *fb, parser::udf::UDFASTContext *udf_ast_context, + UDFCodegen(catalog::CatalogAccessor *accessor, FunctionBuilder *fb, ast::udf::UDFASTContext *udf_ast_context, CodeGen *codegen, catalog::db_oid_t db_oid); ~UDFCodegen(){}; catalog::type_oid_t GetCatalogTypeOidFromSQLType(execution::ast::BuiltinType::Kind type); - void GenerateUDF(AbstractAST *); - void Visit(AbstractAST *) override; - void Visit(FunctionAST *) override; - void Visit(StmtAST *) override; - void Visit(ExprAST *) override; - void Visit(ValueExprAST *) override; - void Visit(VariableExprAST *) override; - void Visit(BinaryExprAST *) override; - void Visit(CallExprAST *) override; - void Visit(IsNullExprAST *) override; - void Visit(SeqStmtAST *) override; - void Visit(DeclStmtAST *) override; - void Visit(IfStmtAST *) override; - void Visit(WhileStmtAST *) override; - void Visit(RetStmtAST *) override; - void Visit(AssignStmtAST *) override; - void Visit(SQLStmtAST *) override; - void Visit(DynamicSQLStmtAST *) override; - void Visit(ForStmtAST *) override; - void Visit(MemberExprAST *) override; + void GenerateUDF(ast::udf::AbstractAST *); + + void Visit(ast::udf::AbstractAST *) override; + void Visit(ast::udf::FunctionAST *) override; + void Visit(ast::udf::StmtAST *) override; + void Visit(ast::udf::ExprAST *) override; + void Visit(ast::udf::ValueExprAST *) override; + void Visit(ast::udf::VariableExprAST *) override; + void Visit(ast::udf::BinaryExprAST *) override; + void Visit(ast::udf::CallExprAST *) override; + void Visit(ast::udf::IsNullExprAST *) override; + void Visit(ast::udf::SeqStmtAST *) override; + void Visit(ast::udf::DeclStmtAST *) override; + void Visit(ast::udf::IfStmtAST *) override; + void Visit(ast::udf::WhileStmtAST *) override; + void Visit(ast::udf::RetStmtAST *) override; + void Visit(ast::udf::AssignStmtAST *) override; + void Visit(ast::udf::SQLStmtAST *) override; + void Visit(ast::udf::DynamicSQLStmtAST *) override; + void Visit(ast::udf::ForStmtAST *) override; + void Visit(ast::udf::MemberExprAST *) override; - execution::ast::File *Finish() { - auto fn = fb_->Finish(); - //// util::RegionVector decls_reg_vec{decls->begin(), decls->end(), codegen.Region()}; - execution::util::RegionVector decls({fn}, codegen_->GetAstContext()->GetRegion()); - // for(auto decl : aux_decls_){ - // decls.push_back(decl); - // } - decls.insert(decls.begin(), aux_decls_.begin(), aux_decls_.end()); - auto file = codegen_->GetAstContext()->GetNodeFactory()->NewFile({0, 0}, std::move(decls)); - return file; - } + execution::ast::File *Finish(); static const char *GetReturnParamString(); private: catalog::CatalogAccessor *accessor_; FunctionBuilder *fb_; - UDFASTContext *udf_ast_context_; + ast::udf::UDFASTContext *udf_ast_context_; CodeGen *codegen_; type::TypeId current_type_{type::TypeId::INVALID}; execution::ast::Expr *dst_; diff --git a/src/include/execution/functions/function_context.h b/src/include/execution/functions/function_context.h index d39029242f..39018160c9 100644 --- a/src/include/execution/functions/function_context.h +++ b/src/include/execution/functions/function_context.h @@ -6,8 +6,12 @@ #include "catalog/catalog_defs.h" #include "common/managed_pointer.h" +#include "execution/ast/ast.h" #include "execution/ast/builtins.h" +#include "execution/ast/context.h" +#include "execution/util/region.h" #include "type/type_id.h" + namespace noisepage::execution::functions { /** @@ -43,6 +47,19 @@ class FunctionContext { is_builtin_{true}, builtin_{builtin}, is_exec_ctx_required_{is_exec_ctx_required} {} + + FunctionContext(std::string func_name, type::TypeId func_ret_type, std::vector &&args_type, + std::unique_ptr ast_region, std::unique_ptr ast_context, ast::File *file, + bool is_exec_ctx_required = true) + : func_name_(std::move(func_name)), + func_ret_type_(func_ret_type), + args_type_(std::move(args_type)), + is_builtin_{false}, + is_exec_ctx_required_{is_exec_ctx_required}, + ast_region_{std::move(ast_region)}, + ast_context_{std::move(ast_context)}, + file_{file} {} + /** * @return The name of the function represented by this context object */ @@ -80,6 +97,31 @@ class FunctionContext { return is_exec_ctx_required_; } + /** + * @return returns the main functiondecl of this udf (to be used only if not builtin) + */ + common::ManagedPointer GetMainFunctionDecl() const { + NOISEPAGE_ASSERT(!IsBuiltin(), "Getting a non-builtin from a builtin function"); + return common::ManagedPointer( + reinterpret_cast(file_->Declarations().back())); + } + + /** + * @return returns the file with the functiondecl and supporting decls (to be used only if not builtin) + */ + ast::File *GetFile() const { + NOISEPAGE_ASSERT(!IsBuiltin(), "Getting a non-builtin from a builtin function"); + return file_; + } + + /** + * TODO(Kyle): Document. + */ + ast::Context *GetASTContext() const { + NOISEPAGE_ASSERT(!IsBuiltin(), "Getting a non-builtin from a builtin function"); + return ast_context_.get(); + } + private: std::string func_name_; type::TypeId func_ret_type_; @@ -87,6 +129,10 @@ class FunctionContext { bool is_builtin_; ast::Builtin builtin_; bool is_exec_ctx_required_; + + std::unique_ptr ast_region_; + std::unique_ptr ast_context_; + ast::File *file_; }; } // namespace noisepage::execution::functions diff --git a/src/include/execution/vm/bytecode_generator.h b/src/include/execution/vm/bytecode_generator.h index ce14ff3f0b..16277e2caf 100644 --- a/src/include/execution/vm/bytecode_generator.h +++ b/src/include/execution/vm/bytecode_generator.h @@ -155,6 +155,9 @@ class BytecodeGenerator final : public ast::AstVisitor { void VisitExpressionForTest(ast::Expr *expr, BytecodeLabel *then_label, BytecodeLabel *else_label, TestFallthrough fallthrough); + // Visit the body of a break statement + void VisitBreakStatement(ast::BreakStmt *break_stmt); + // Visit the body of an iteration statement void VisitIterationStatement(ast::IterationStmt *iteration, LoopBuilder *loop_builder); diff --git a/src/include/parser/udf/udf_parser.h b/src/include/parser/udf/udf_parser.h index 1d441b6909..fd3d4a5ef3 100644 --- a/src/include/parser/udf/udf_parser.h +++ b/src/include/parser/udf/udf_parser.h @@ -3,43 +3,53 @@ #include #include -#include "ast_nodes.h" #include "catalog/catalog_accessor.h" +#include "execution/ast/udf/udf_ast_context.h" +#include "execution/ast/udf/udf_ast_nodes.h" #include "parser/expression_util.h" #include "parser/postgresparser.h" -#include "parser/udf/udf_ast_context.h" // TODO(Kyle): Do we want to place UDF parsing in its own namespace? namespace noisepage { + +// Forward declaration +namespace execution::ast::udf { +class FunctionAST; +} + namespace parser { namespace udf { +/** + * Namespace alias to make below more manageable. + */ +namespace udfexec = execution::ast::udf; -class FunctionAST; class PLpgSQLParser { public: - PLpgSQLParser(common::ManagedPointer udf_ast_context, + PLpgSQLParser(common::ManagedPointer udf_ast_context, const common::ManagedPointer accessor, catalog::db_oid_t db_oid) : udf_ast_context_(udf_ast_context), accessor_(accessor), db_oid_(db_oid) {} - std::unique_ptr ParsePLpgSQL(std::vector &¶m_names, - std::vector &¶m_types, const std::string &func_body, - common::ManagedPointer ast_context); + std::unique_ptr ParsePLpgSQL(std::vector &¶m_names, + std::vector &¶m_types, + const std::string &func_body, + common::ManagedPointer ast_context); private: - std::unique_ptr ParseBlock(const nlohmann::json &block); - std::unique_ptr ParseFunction(const nlohmann::json &block); - std::unique_ptr ParseDecl(const nlohmann::json &decl); - std::unique_ptr ParseIf(const nlohmann::json &branch); - std::unique_ptr ParseWhile(const nlohmann::json &loop); - std::unique_ptr ParseFor(const nlohmann::json &loop); - std::unique_ptr ParseSQL(const nlohmann::json &sql_stmt); - std::unique_ptr ParseDynamicSQL(const nlohmann::json &sql_stmt); + std::unique_ptr ParseBlock(const nlohmann::json &block); + std::unique_ptr ParseFunction(const nlohmann::json &block); + std::unique_ptr ParseDecl(const nlohmann::json &decl); + std::unique_ptr ParseIf(const nlohmann::json &branch); + std::unique_ptr ParseWhile(const nlohmann::json &loop); + std::unique_ptr ParseFor(const nlohmann::json &loop); + std::unique_ptr ParseSQL(const nlohmann::json &sql_stmt); + std::unique_ptr ParseDynamicSQL(const nlohmann::json &sql_stmt); // Feed the expression (as a sql string) to our parser then transform the // noisepage expression into ast node - std::unique_ptr ParseExprSQL(const std::string expr_sql_str); - std::unique_ptr ParseExpr(common::ManagedPointer); + std::unique_ptr ParseExprSQL(const std::string expr_sql_str); + std::unique_ptr ParseExpr(common::ManagedPointer); - common::ManagedPointer udf_ast_context_; + common::ManagedPointer udf_ast_context_; const common::ManagedPointer accessor_; catalog::db_oid_t db_oid_; // common::ManagedPointer sql_parser_; diff --git a/src/parser/udf/udf_parser.cpp b/src/parser/udf/udf_parser.cpp index e904d0c6d3..735d3781ae 100644 --- a/src/parser/udf/udf_parser.cpp +++ b/src/parser/udf/udf_parser.cpp @@ -1,6 +1,7 @@ #include #include "binder/bind_node_visitor.h" +#include "execution/ast/udf/udf_ast_nodes.h" #include "loggers/parser_logger.h" #include "parser/udf/udf_parser.h" @@ -13,7 +14,9 @@ namespace noisepage { namespace parser { namespace udf { + using namespace nlohmann; +using namespace execution::ast::udf; // TODO(Kyle): constexpr // TODO(Kyle): Define elsewhere? @@ -82,8 +85,8 @@ std::unique_ptr PLpgSQLParser::ParsePLpgSQL(std::vectorSetVariableType(udf_name, param_types[i++]); } const auto function = function_list[0][kPLpgSQL_function]; - std::unique_ptr function_ast( - new FunctionAST(ParseFunction(function), std::move(param_names), std::move(param_types))); + auto function_ast = + std::make_unique(ParseFunction(function), std::move(param_names), std::move(param_types)); return function_ast; } @@ -245,46 +248,47 @@ std::unique_ptr PLpgSQLParser::ParseFor(const nlohmann::json &loop) { return std::unique_ptr(new ForStmtAST(std::move(var_vec), std::move(parse_result), std::move(body_stmt))); } +// TODO(Kyle): Implement std::unique_ptr PLpgSQLParser::ParseSQL(const nlohmann::json &sql_stmt) { - PARSER_LOG_DEBUG("ParseSQL"); - auto sql_query = sql_stmt[kSqlstmt][kPLpgSQL_expr][kQuery].get(); - auto var_name = sql_stmt[kRow][kPLpgSQL_row][kFields][0][kName].get(); - auto parse_result = PostgresParser::BuildParseTree(sql_query.c_str()); - if (parse_result == nullptr) { - PARSER_LOG_DEBUG("Bad SQL statement"); - return nullptr; - } - binder::BindNodeVisitor visitor(accessor_, db_oid_); - - std::unordered_map> query_params; - - try { - // TODO(Matt): I don't think the binder should need the database name. It's already bound in the ConnectionContext - // binder::BindNodeVisitor visitor(accessor_, db_oid_); - query_params = visitor.BindAndGetUDFParams(common::ManagedPointer(parse_result), udf_ast_context_); - } catch (BinderException &b) { - PARSER_LOG_DEBUG("Bad SQL statement"); - return nullptr; - } - - // check to see if a record type can be bound to this - type::TypeId type; - auto ret = udf_ast_context_->GetVariableType(var_name, &type); - if (!ret) { - throw PARSER_EXCEPTION("PL/pgSQL parser : Didn't declare variable"); - } - if (type == type::TypeId::INVALID) { - std::vector> elems; - auto sel = parse_result->GetStatement(0).CastManagedPointerTo()->GetSelectColumns(); - for (auto col : sel) { - elems.emplace_back(col->GetAliasName(), col->GetReturnValueType()); - } - udf_ast_context_->SetRecordType(var_name, std::move(elems)); - } - - return std::unique_ptr( - new SQLStmtAST(std::move(parse_result), std::move(var_name), std::move(query_params))); - // return nullptr; + throw NOT_IMPLEMENTED_EXCEPTION("ParseSQL Not Implemented"); + // auto sql_query = sql_stmt[kSqlstmt][kPLpgSQL_expr][kQuery].get(); + // auto var_name = sql_stmt[kRow][kPLpgSQL_row][kFields][0][kName].get(); + // auto parse_result = PostgresParser::BuildParseTree(sql_query.c_str()); + // if (parse_result == nullptr) { + // PARSER_LOG_DEBUG("Bad SQL statement"); + // return nullptr; + // } + // binder::BindNodeVisitor visitor(accessor_, db_oid_); + + // std::unordered_map> query_params; + + // try { + // // TODO(Matt): I don't think the binder should need the database name. It's already bound in the + // ConnectionContext + // // binder::BindNodeVisitor visitor(accessor_, db_oid_); + // query_params = visitor.BindAndGetUDFParams(common::ManagedPointer(parse_result), udf_ast_context_); + // } catch (BinderException &b) { + // PARSER_LOG_DEBUG("Bad SQL statement"); + // return nullptr; + // } + + // // check to see if a record type can be bound to this + // type::TypeId type; + // auto ret = udf_ast_context_->GetVariableType(var_name, &type); + // if (!ret) { + // throw PARSER_EXCEPTION("PL/pgSQL parser : Didn't declare variable"); + // } + // if (type == type::TypeId::INVALID) { + // std::vector> elems; + // auto sel = parse_result->GetStatement(0).CastManagedPointerTo()->GetSelectColumns(); + // for (auto col : sel) { + // elems.emplace_back(col->GetAliasName(), col->GetReturnValueType()); + // } + // udf_ast_context_->SetRecordType(var_name, std::move(elems)); + // } + + // return std::unique_ptr( + // new SQLStmtAST(std::move(parse_result), std::move(var_name), std::move(query_params))); } std::unique_ptr PLpgSQLParser::ParseDynamicSQL(const nlohmann::json &sql_stmt) { @@ -340,8 +344,8 @@ std::unique_ptr PLpgSQLParser::ParseExpr(common::ManagedPointer(new IsNullExprAST(true, ParseExpr(expr->GetChild(0)))); } throw PARSER_EXCEPTION("PL/pgSQL parser : Expression type not supported"); - return nullptr; } + } // namespace udf } // namespace parser } // namespace noisepage \ No newline at end of file From 421d8734cd40fc8a7e017653a9ae5fbec7756fc1 Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Sat, 3 Apr 2021 20:35:33 -0400 Subject: [PATCH 010/139] integrate binder, builds successfully --- src/binder/bind_node_visitor.cpp | 64 +++++++++++++++++-- src/include/binder/bind_node_visitor.h | 15 +++++ src/include/binder/binder_sherpa.h | 9 +++ .../expression/column_value_expression.h | 11 ++++ 4 files changed, 94 insertions(+), 5 deletions(-) diff --git a/src/binder/bind_node_visitor.cpp b/src/binder/bind_node_visitor.cpp index febeb64b39..c9b66aa935 100644 --- a/src/binder/bind_node_visitor.cpp +++ b/src/binder/bind_node_visitor.cpp @@ -55,7 +55,7 @@ void BindNodeVisitor::BindNameToNode( common::ManagedPointer parse_result, const common::ManagedPointer> parameters, const common::ManagedPointer> desired_parameter_types) { - NOISEPAGE_ASSERT(parse_result != nullptr, "We shouldn't be tring to bind something without a ParseResult."); + NOISEPAGE_ASSERT(parse_result != nullptr, "We shouldn't be trying to bind something without a ParseResult."); sherpa_ = std::make_unique(parse_result, parameters, desired_parameter_types); NOISEPAGE_ASSERT(sherpa_->GetParseResult()->GetStatements().size() == 1, "Binder can only bind one at a time."); sherpa_->GetParseResult()->GetStatement(0)->Accept( @@ -64,6 +64,19 @@ void BindNodeVisitor::BindNameToNode( BindNodeVisitor::~BindNodeVisitor() = default; +std::unordered_map> BindNodeVisitor::BindAndGetUDFParams( + common::ManagedPointer parse_result, + common::ManagedPointer udf_ast_context) { + // TODO(Kyle): Revisit this. + NOISEPAGE_ASSERT(parse_result != nullptr, "We shouldn't be trying to bind something without a ParseResult."); + sherpa_ = std::make_unique(parse_result, nullptr, nullptr); + NOISEPAGE_ASSERT(sherpa_->GetParseResult()->GetStatements().size() == 1, "Binder can only bind one at a time."); + udf_ast_context_ = udf_ast_context; + sherpa_->GetParseResult()->GetStatement(0)->Accept( + common::ManagedPointer(this).CastManagedPointerTo()); + return udf_params_; +} + void BindNodeVisitor::Visit(common::ManagedPointer node) { BINDER_LOG_TRACE("Visiting AnalyzeStatement ..."); SqlNodeVisitor::Visit(node); @@ -549,8 +562,18 @@ void BindNodeVisitor::Visit(common::ManagedPointerSetColumnPosTuple(expr)) { + if (udf_ast_context_ != nullptr && udf_ast_context_->GetVariableType(expr->GetColumnName(), &the_type)) { + expr->SetReturnValueType(the_type); + auto idx = 0; + if (udf_params_.count(expr->GetColumnName()) == 0) { + udf_params_[expr->GetColumnName()] = std::make_pair("", udf_params_.size()); + idx = udf_params_.size() - 1; + } + expr->SetParamIdx(idx); + } else if (context_ == nullptr || !context_->SetColumnPosTuple(expr)) { throw BINDER_EXCEPTION(fmt::format("column \"{}\" does not exist", col_name), common::ErrorCode::ERRCODE_UNDEFINED_COLUMN); } @@ -562,9 +585,23 @@ void BindNodeVisitor::Visit(common::ManagedPointerCheckNestedTableColumn(table_name, col_name, expr)) { - throw BINDER_EXCEPTION(fmt::format("Invalid table reference {}", expr->GetTableName()), - common::ErrorCode::ERRCODE_UNDEFINED_TABLE); + } else if (udf_ast_context_ != nullptr && udf_ast_context_->GetVariableType(expr->GetTableName(), &the_type)) { + // record type + NOISEPAGE_ASSERT(the_type == type::TypeId::INVALID, "unknown type"); + auto &fields = udf_ast_context_->GetRecordType(expr->GetTableName()); + auto it = std::find_if(fields.begin(), fields.end(), [=](auto p) { return p.first == expr->GetColumnName(); }); + auto idx = 0; + if (it != fields.end()) { + if (udf_params_.count(expr->GetColumnName()) == 0) { + udf_params_[expr->GetColumnName()] = std::make_pair(expr->GetTableName(), udf_params_.size()); + idx = udf_params_.size() - 1; + } + expr->SetReturnValueType(it->second); + expr->SetParamIdx(idx); + } else if (context_ == nullptr || !context_->CheckNestedTableColumn(table_name, col_name, expr)) { + throw BINDER_EXCEPTION(fmt::format("Invalid table reference {}", expr->GetTableName()), + common::ErrorCode::ERRCODE_UNDEFINED_TABLE); + } } } } @@ -592,6 +629,20 @@ void BindNodeVisitor::Visit(common::ManagedPointer expr->SetChild(i, child->GetChild(0)); } } + + for (auto i = 0UL; i < expr->GetChildrenSize(); ++i) { + auto child = expr->GetChild(i); + if (child->GetExpressionType() == parser::ExpressionType::COLUMN_VALUE) { + auto index = child.CastManagedPointerTo()->GetParamIdx(); + if (index >= 0) { + // replace with PVE + std::unique_ptr pve = std::make_unique(index); + pve->SetReturnValueType(child->GetReturnValueType()); + expr->SetChild(i, common::ManagedPointer(pve)); + sherpa_->GetParseResult()->AddExpression(std::move(pve)); + } + } + } } void BindNodeVisitor::Visit(common::ManagedPointer expr) { @@ -655,6 +706,9 @@ void BindNodeVisitor::Visit(common::ManagedPointer e void BindNodeVisitor::Visit(common::ManagedPointer expr) { BINDER_LOG_TRACE("Visiting ParameterValueExpression ..."); SqlNodeVisitor::Visit(expr); + if (sherpa_ == nullptr || sherpa_->GetParameters() == nullptr) { + return; + } const common::ManagedPointer param = common::ManagedPointer(&((*(sherpa_->GetParameters()))[expr->GetValueIdx()])); const auto desired_type = sherpa_->GetDesiredType(expr.CastManagedPointerTo()); diff --git a/src/include/binder/bind_node_visitor.h b/src/include/binder/bind_node_visitor.h index 25d6815dd0..66315bec21 100644 --- a/src/include/binder/bind_node_visitor.h +++ b/src/include/binder/bind_node_visitor.h @@ -6,6 +6,9 @@ #include "binder/sql_node_visitor.h" #include "catalog/catalog_defs.h" +#include "execution/ast/udf/udf_ast_context.h" +#include "parser/postgresparser.h" +#include "parser/select_statement.h" #include "type/type_id.h" namespace noisepage { @@ -47,6 +50,13 @@ class BindNodeVisitor final : public SqlNodeVisitor { /** Destructor. Must be defined due to forward declaration. */ ~BindNodeVisitor() final; + /** + * TODO(Kyle): Document. + */ + std::unordered_map> BindAndGetUDFParams( + common::ManagedPointer parse_result, + common::ManagedPointer udf_ast_context); + /** * Perform binding on the passed in tree. Bind the relation names to oids * @param parse_result Result generated by the parser. A collection of statements and expressions in the query @@ -100,6 +110,11 @@ class BindNodeVisitor final : public SqlNodeVisitor { /** Current context of the query or subquery */ common::ManagedPointer context_ = nullptr; + /** Context for UDF AST */ + common::ManagedPointer udf_ast_context_{}; + /** Parameters for UDF */ + std::unordered_map> udf_params_; + /** Catalog accessor */ const common::ManagedPointer catalog_accessor_; const catalog::db_oid_t db_oid_; diff --git a/src/include/binder/binder_sherpa.h b/src/include/binder/binder_sherpa.h index bd29fb07aa..8d3e0373d3 100644 --- a/src/include/binder/binder_sherpa.h +++ b/src/include/binder/binder_sherpa.h @@ -45,6 +45,15 @@ class BinderSherpa { */ common::ManagedPointer> GetParameters() const { return parameters_; } + /** + * Add a parameter to the binder sherpa state. + * @param param The parameter expression. + */ + void AddParameter(const parser::ConstantValueExpression param) { + parameters_->push_back(param); + desired_parameter_types_->push_back(param.GetReturnValueType()); + } + /** * @param expr The expression whose type constraints we want to look up. * @return The previously recorded type constraints, or the expression's current return value type if none exist. diff --git a/src/include/parser/expression/column_value_expression.h b/src/include/parser/expression/column_value_expression.h index cb68983575..9baebae0d9 100644 --- a/src/include/parser/expression/column_value_expression.h +++ b/src/include/parser/expression/column_value_expression.h @@ -111,6 +111,14 @@ class ColumnValueExpression : public AbstractExpression { /** @return column oid */ catalog::col_oid_t GetColumnOid() const { return column_oid_; } + // TODO(Kyle): Why are we narrowing here? + + /** @return parameter index */ + std::int32_t GetParamIdx() const { return param_idx_; } + + /** @brief set the parameter index */ + void SetParamIdx(std::uint32_t param_idx) { param_idx_ = static_cast(param_idx); } + /** * Get Column Full Name [tbl].[col] */ @@ -195,6 +203,9 @@ class ColumnValueExpression : public AbstractExpression { /** OID of the column */ catalog::col_oid_t column_oid_ = catalog::INVALID_COLUMN_OID; + + /** parameter index */ + std::int32_t param_idx_{-1}; }; DEFINE_JSON_HEADER_DECLARATIONS(ColumnValueExpression); From b8d458c90b1976cd6f25ffd65d2418edda706088 Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Sat, 3 Apr 2021 21:20:33 -0400 Subject: [PATCH 011/139] integrate execution exec --- src/execution/exec/execution_context.cpp | 4 ---- .../execution/exec/execution_context.h | 21 ++++++++++++++++--- src/include/execution/exec/output.h | 5 +++-- src/include/execution/vm/bytecode_handlers.h | 10 ++++----- .../expression/constant_value_expression.h | 10 +++++++++ .../expression/constant_value_expression.cpp | 5 +++++ src/self_driving/planning/pilot_util.cpp | 8 +++++-- src/traffic_cop/traffic_cop.cpp | 13 +++++------- 8 files changed, 51 insertions(+), 25 deletions(-) diff --git a/src/execution/exec/execution_context.cpp b/src/execution/exec/execution_context.cpp index 7ee3e2e7ac..137d44baf4 100644 --- a/src/execution/exec/execution_context.cpp +++ b/src/execution/exec/execution_context.cpp @@ -195,10 +195,6 @@ void ExecutionContext::InitializeParallelOUFeatureVector(selfdriving::ExecOUFeat } } -const parser::ConstantValueExpression &ExecutionContext::GetParam(const uint32_t param_idx) const { - return (*params_)[param_idx]; -} - void ExecutionContext::RegisterHook(size_t hook_idx, HookFn hook) { NOISEPAGE_ASSERT(hook_idx < hooks_.capacity(), "Incorrect number of reserved hooks"); hooks_[hook_idx] = hook; diff --git a/src/include/execution/exec/execution_context.h b/src/include/execution/exec/execution_context.h index fbdbe670e6..da040d0341 100644 --- a/src/include/execution/exec/execution_context.h +++ b/src/include/execution/exec/execution_context.h @@ -11,6 +11,7 @@ #include "execution/sql/memory_tracker.h" #include "execution/sql/runtime_types.h" #include "execution/sql/thread_state_container.h" +#include "execution/sql/value.h" #include "execution/util/region.h" #include "metrics/metrics_defs.h" #include "planner/plannodes/output_schema.h" @@ -180,6 +181,16 @@ class EXPORT ExecutionContext { */ void InitializeParallelOUFeatureVector(selfdriving::ExecOUFeatureVector *ouvec, pipeline_id_t pipeline_id); + // TODO(Kyle): Document + revisit this. + + void StartParams() { udf_param_stack_.push_back({}); } + + void PopParams() { udf_param_stack_.pop_back(); } + + void AddParam(common::ManagedPointer val) { + udf_param_stack_.back().push_back(val.CastManagedPointerTo()); + } + /** * @return the db oid */ @@ -203,7 +214,7 @@ class EXPORT ExecutionContext { * Set the execution parameters. * @param params The execution parameters. */ - void SetParams(common::ManagedPointer> params) { + void SetParams(common::ManagedPointer>> params) { params_ = params; } @@ -211,7 +222,9 @@ class EXPORT ExecutionContext { * @param param_idx index of parameter to access * @return immutable parameter at provided index */ - const parser::ConstantValueExpression &GetParam(uint32_t param_idx) const; + common::ManagedPointer GetParam(uint32_t param_idx) const { + return udf_param_stack_.empty() ? (*params_)[param_idx] : udf_param_stack_.back()[param_idx]; + } /** * Set the PipelineOperatingUnits @@ -347,12 +360,14 @@ class EXPORT ExecutionContext { common::ManagedPointer accessor_; common::ManagedPointer metrics_manager_; - common::ManagedPointer> params_; + common::ManagedPointer>> params_; uint8_t execution_mode_; uint32_t rows_affected_ = 0; common::ManagedPointer replication_manager_; + std::vector>> udf_param_stack_; + bool memory_use_override_ = false; uint32_t memory_use_override_value_ = 0; uint32_t num_concurrent_estimate_ = 0; diff --git a/src/include/execution/exec/output.h b/src/include/execution/exec/output.h index 6882786402..21742a7acd 100644 --- a/src/include/execution/exec/output.h +++ b/src/include/execution/exec/output.h @@ -89,8 +89,9 @@ class EXPORT OutputBuffer { private: sql::MemoryPool *memory_pool_; - uint32_t num_tuples_; - uint32_t tuple_size_; + // TODO(Kyle): Tanuj made this atomic, does it need to be? + std::uint32_t num_tuples_; + std::uint32_t tuple_size_; byte *tuples_; /** diff --git a/src/include/execution/vm/bytecode_handlers.h b/src/include/execution/vm/bytecode_handlers.h index 056660dbe3..3576278646 100644 --- a/src/include/execution/vm/bytecode_handlers.h +++ b/src/include/execution/vm/bytecode_handlers.h @@ -2119,15 +2119,13 @@ VM_OP_WARM void OpAbortTxn(noisepage::execution::exec::ExecutionContext *exec_ct } // Parameter calls +// TODO(Kyle): this used to have a conditional check; was it safe to remove? #define GEN_SCALAR_PARAM_GET(Name, SqlType) \ VM_OP_HOT void OpGetParam##Name(noisepage::execution::sql::SqlType *ret, \ noisepage::execution::exec::ExecutionContext *exec_ctx, uint32_t param_idx) { \ - const auto &cve = exec_ctx->GetParam(param_idx); \ - if (cve.IsNull()) { \ - ret->is_null_ = true; \ - } else { \ - *ret = cve.Get##SqlType(); \ - } \ + const auto &val = \ + *reinterpret_cast(exec_ctx->GetParam(param_idx).Get()); \ + *ret = val; \ } GEN_SCALAR_PARAM_GET(Bool, BoolVal) diff --git a/src/include/parser/expression/constant_value_expression.h b/src/include/parser/expression/constant_value_expression.h index a6b5caceb7..80a3582d5d 100644 --- a/src/include/parser/expression/constant_value_expression.h +++ b/src/include/parser/expression/constant_value_expression.h @@ -106,6 +106,11 @@ class ConstantValueExpression : public AbstractExpression { } } + // TODO(Kyle): Is this safe? + common::ManagedPointer GetVal() const { + return common::ManagedPointer(&std::get(value_)); + } + /** * @return copy of the underlying Val */ @@ -227,6 +232,11 @@ class ConstantValueExpression : public AbstractExpression { template T Peek() const; + /** + * TODO(Kyle): Document. + */ + const execution::sql::Val *PeekPtr() const; + void Accept(common::ManagedPointer v) override; /** @return A string representation of this ConstantValueExpression. */ diff --git a/src/parser/expression/constant_value_expression.cpp b/src/parser/expression/constant_value_expression.cpp index 17ea24da3e..a39ec07cfa 100644 --- a/src/parser/expression/constant_value_expression.cpp +++ b/src/parser/expression/constant_value_expression.cpp @@ -91,6 +91,11 @@ T ConstantValueExpression::Peek() const { UNREACHABLE("Invalid type for Peek."); } +const execution::sql::Val *ConstantValueExpression::PeekPtr() const { + // TODO(Kyle): seems unsafe. + return reinterpret_cast(&value_); +} + ConstantValueExpression &ConstantValueExpression::operator=(const ConstantValueExpression &other) { if (this != &other) { // self-assignment check expected // AbstractExpression fields we need copied over diff --git a/src/self_driving/planning/pilot_util.cpp b/src/self_driving/planning/pilot_util.cpp index 911a65740b..267f3395dd 100644 --- a/src/self_driving/planning/pilot_util.cpp +++ b/src/self_driving/planning/pilot_util.cpp @@ -225,8 +225,12 @@ const std::list &PilotUtil::Collec auto exec_ctx = std::make_unique( db_oid, common::ManagedPointer(txn), callback, out_plan->GetOutputSchema().Get(), common::ManagedPointer(accessor), exec_settings, metrics_manager, DISABLED); - - exec_ctx->SetParams(common::ManagedPointer>(¶ms)); + // TODO(Kyle): It sucks to call this in a loop... better way? + std::vector> param_values{}; + param_values.reserve(params.size()); + std::transform(params.cbegin(), params.cend(), std::back_inserter(param_values), + [](const parser::ConstantValueExpression &cve) { return common::ManagedPointer{cve.PeekPtr()}; }); + exec_ctx->SetParams(common::ManagedPointer(¶m_values)); exec_query->Run(common::ManagedPointer(exec_ctx), execution::vm::ExecutionMode::Interpret); txn_manager->Abort(txn); } diff --git a/src/traffic_cop/traffic_cop.cpp b/src/traffic_cop/traffic_cop.cpp index c06d61968c..6aa380aa32 100644 --- a/src/traffic_cop/traffic_cop.cpp +++ b/src/traffic_cop/traffic_cop.cpp @@ -446,14 +446,11 @@ TrafficCopResult TrafficCop::RunExecutableQuery(const common::ManagedPointerGetDatabaseOid(), connection_ctx->Transaction(), callback, physical_plan->GetOutputSchema().Get(), connection_ctx->Accessor(), exec_settings, metrics, replication_manager_); - exec_ctx->SetParams(portal->Parameters()); - - // TODO(Kyle): Refactor to algorithm - // std::vector> params{}; - // for (auto &cve : *(portal->Parameters())){ - // params.push_back(common::ManagedPointer(cve.PeekPtr())); - // } - // exec_ctx->SetParams(common::ManagedPointer(¶ms)); + std::vector> params{}; + params.reserve(portal->Parameters()->size()); + std::transform(portal->Parameters()->cbegin(), portal->Parameters()->cend(), std::back_inserter(params), + [](const parser::ConstantValueExpression &cve) { return common::ManagedPointer{cve.PeekPtr()}; }); + exec_ctx->SetParams(common::ManagedPointer(¶ms)); const auto exec_query = portal->GetStatement()->GetExecutableQuery(); From c5fc040bc0469ef984da53fea7ceb9e239fed2c8 Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Sat, 3 Apr 2021 21:50:55 -0400 Subject: [PATCH 012/139] integrate sema --- src/execution/sema/scope.cpp | 14 +++ src/execution/sema/sema_builtin.cpp | 162 ++++++++++++++++++--------- src/execution/sema/sema_checking.cpp | 6 + src/execution/sema/sema_decl.cpp | 3 + src/execution/sema/sema_expr.cpp | 47 +++++++- src/execution/sema/sema_stmt.cpp | 15 +++ src/include/execution/sema/scope.h | 10 ++ 7 files changed, 204 insertions(+), 53 deletions(-) diff --git a/src/execution/sema/scope.cpp b/src/execution/sema/scope.cpp index 16baf64cd5..f2357eb56f 100644 --- a/src/execution/sema/scope.cpp +++ b/src/execution/sema/scope.cpp @@ -29,4 +29,18 @@ ast::Type *Scope::LookupLocal(ast::Identifier name) const { return (iter == decls_.end() ? nullptr : iter->second); } +Scope::Kind Scope::GetKind() const { return scope_kind_; } + +std::vector> Scope::GetLocals() const { + std::vector> locals; + auto scope = this; + do { + for (auto it : scope->decls_) { + locals.emplace_back(it.first, it.second); + } + scope = scope->outer_; + } while (scope->scope_kind_ != Scope::Kind::Function); + return locals; +} + } // namespace noisepage::execution::sema diff --git a/src/execution/sema/sema_builtin.cpp b/src/execution/sema/sema_builtin.cpp index 949e52cea8..4d7514fadb 100644 --- a/src/execution/sema/sema_builtin.cpp +++ b/src/execution/sema/sema_builtin.cpp @@ -2073,22 +2073,27 @@ void Sema::CheckBuiltinPtrCastCall(ast::CallExpr *call) { return; } + if (call->Arguments()[0]->GetType() != nullptr && call->Arguments()[1]->GetType() != nullptr && + call->Arguments()[0]->GetType()->IsPointerType() && call->Arguments()[1]->GetType()->IsPointerType()) { + return; + } + // The first argument will be a UnaryOpExpr with the '*' (star) op. This is // because parsing function calls assumes expression arguments, not types. So, // something like '*Type', which would be the first argument to @ptrCast, will // get parsed as a dereference expression before a type expression. // TODO(pmenon): Fix the above to parse correctly - auto unary_op = call->Arguments()[0]->SafeAs(); - if (unary_op == nullptr || unary_op->Op() != parsing::Token::Type::STAR) { - GetErrorReporter()->Report(call->Position(), ErrorMessages::kBadArgToPtrCast, call->Arguments()[0]->GetType(), 1); - return; + if (!call->Arguments()[0]->Is()) { + auto unary_op = call->Arguments()[0]->SafeAs(); + if (unary_op == nullptr || unary_op->Op() != parsing::Token::Type::STAR) { + GetErrorReporter()->Report(call->Position(), ErrorMessages::kBadArgToPtrCast, call->Arguments()[0]->GetType(), 1); + return; + } + call->SetArgument( + 0, GetContext()->GetNodeFactory()->NewPointerType(call->Arguments()[0]->Position(), unary_op->Input())); } - // Replace the unary with a PointerTypeRepr node and resolve it - call->SetArgument( - 0, GetContext()->GetNodeFactory()->NewPointerType(call->Arguments()[0]->Position(), unary_op->Input())); - for (auto *arg : call->Arguments()) { auto *resolved_type = Resolve(arg); if (resolved_type == nullptr) { @@ -2836,9 +2841,10 @@ void Sema::CheckBuiltinAbortCall(ast::CallExpr *call) { } void Sema::CheckBuiltinParamCall(ast::CallExpr *call, ast::Builtin builtin) { - if (!CheckArgCount(call, 2)) { - return; - } + // TODO(Kyle): Revisit. + // if (!CheckArgCount(call, 1)) { + // return; + // } // first argument is an exec ctx auto exec_ctx_kind = ast::BuiltinType::ExecutionContext; @@ -2848,48 +2854,92 @@ void Sema::CheckBuiltinParamCall(ast::CallExpr *call, ast::Builtin builtin) { } // second argument is the index of the parameter - if (!call->Arguments()[1]->GetType()->IsIntegerType()) { - ReportIncorrectCallArg(call, 0, GetBuiltinType(ast::BuiltinType::Kind::Uint32)); - return; - } - - // Type output sql value - ast::BuiltinType::Kind sql_type; - switch (builtin) { - case ast::Builtin::GetParamBool: { - sql_type = ast::BuiltinType::Boolean; - break; - } - case ast::Builtin::GetParamTinyInt: - case ast::Builtin::GetParamSmallInt: - case ast::Builtin::GetParamInt: - case ast::Builtin::GetParamBigInt: { - sql_type = ast::BuiltinType::Integer; - break; - } - case ast::Builtin::GetParamReal: - case ast::Builtin::GetParamDouble: { - sql_type = ast::BuiltinType::Real; - break; - } - case ast::Builtin::GetParamDate: { - sql_type = ast::BuiltinType::Date; - break; + if (builtin < ast::Builtin::StartNewParams) { + if (!call->Arguments()[1]->GetType()->IsIntegerType()) { + ReportIncorrectCallArg(call, 1, GetBuiltinType(ast::BuiltinType::Kind::Uint32)); + return; } - case ast::Builtin::GetParamTimestamp: { - sql_type = ast::BuiltinType::Timestamp; - break; + + // Type output sql value + ast::BuiltinType::Kind sql_type; + switch (builtin) { + case ast::Builtin::GetParamBool: { + sql_type = ast::BuiltinType::Boolean; + break; + } + case ast::Builtin::GetParamTinyInt: + case ast::Builtin::GetParamSmallInt: + case ast::Builtin::GetParamInt: + case ast::Builtin::GetParamBigInt: { + sql_type = ast::BuiltinType::Integer; + break; + } + case ast::Builtin::GetParamReal: + case ast::Builtin::GetParamDouble: { + sql_type = ast::BuiltinType::Real; + break; + } + case ast::Builtin::GetParamDate: { + sql_type = ast::BuiltinType::Date; + break; + } + case ast::Builtin::GetParamTimestamp: { + sql_type = ast::BuiltinType::Timestamp; + break; + } + case ast::Builtin::GetParamString: { + sql_type = ast::BuiltinType::StringVal; + break; + } + default: + UNREACHABLE("Undefined parameter call!!"); } - case ast::Builtin::GetParamString: { - sql_type = ast::BuiltinType::StringVal; - break; + // Return sql type + call->SetType(ast::BuiltinType::Get(GetContext(), sql_type)); + return; + } else { + if (builtin > ast::Builtin::FinishNewParams) { + ast::BuiltinType::Kind add_sql_type; + switch (builtin) { + case ast::Builtin::AddParamBool: { + add_sql_type = ast::BuiltinType::Boolean; + break; + } + case ast::Builtin::AddParamTinyInt: + case ast::Builtin::AddParamSmallInt: + case ast::Builtin::AddParamInt: + case ast::Builtin::AddParamBigInt: { + add_sql_type = ast::BuiltinType::Integer; + break; + } + case ast::Builtin::AddParamReal: + case ast::Builtin::AddParamDouble: { + add_sql_type = ast::BuiltinType::Real; + break; + } + case ast::Builtin::AddParamDate: { + add_sql_type = ast::BuiltinType::Date; + break; + } + case ast::Builtin::AddParamTimestamp: { + add_sql_type = ast::BuiltinType::Timestamp; + break; + } + case ast::Builtin::AddParamString: { + add_sql_type = ast::BuiltinType::StringVal; + break; + } + default: { + UNREACHABLE("Undefined parameter call!!"); + } + } + if (call->Arguments()[1]->GetType() != GetBuiltinType(add_sql_type)) { + ReportIncorrectCallArg(call, 1, GetBuiltinType(add_sql_type)); + return; + } } - default: - UNREACHABLE("Undefined parameter call!!"); } - - // Return sql type - call->SetType(ast::BuiltinType::Get(GetContext(), sql_type)); + call->SetType(ast::BuiltinType::Get(GetContext(), ast::BuiltinType::Nil)); } void Sema::CheckBuiltinStringCall(ast::CallExpr *call, ast::Builtin builtin) { @@ -3700,7 +3750,19 @@ void Sema::CheckBuiltinCall(ast::CallExpr *call) { case ast::Builtin::GetParamDouble: case ast::Builtin::GetParamDate: case ast::Builtin::GetParamTimestamp: - case ast::Builtin::GetParamString: { + case ast::Builtin::GetParamString: + case ast::Builtin::AddParamBool: + case ast::Builtin::AddParamTinyInt: + case ast::Builtin::AddParamSmallInt: + case ast::Builtin::AddParamInt: + case ast::Builtin::AddParamBigInt: + case ast::Builtin::AddParamReal: + case ast::Builtin::AddParamDouble: + case ast::Builtin::AddParamDate: + case ast::Builtin::AddParamTimestamp: + case ast::Builtin::AddParamString: + case ast::Builtin::StartNewParams: + case ast::Builtin::FinishNewParams: { CheckBuiltinParamCall(call, builtin); break; } diff --git a/src/execution/sema/sema_checking.cpp b/src/execution/sema/sema_checking.cpp index a97accec57..27c2d6143f 100644 --- a/src/execution/sema/sema_checking.cpp +++ b/src/execution/sema/sema_checking.cpp @@ -269,6 +269,12 @@ bool Sema::CheckAssignmentConstraints(ast::Type *target_type, ast::Expr **expr) return true; } + if (target_type->IsLambdaType() && (*expr)->GetType()->IsLambdaType()) { + auto fn_type = (*expr)->GetType()->As()->GetFunctionType(); + auto target_fn = target_type->As()->GetFunctionType(); + return fn_type->IsEqual(target_fn); + } + // Integer expansion if (target_type->IsIntegerType() && (*expr)->GetType()->IsIntegerType()) { if (target_type->GetSize() > (*expr)->GetType()->GetSize()) { diff --git a/src/execution/sema/sema_decl.cpp b/src/execution/sema/sema_decl.cpp index 71beaf7e81..aa79334a09 100644 --- a/src/execution/sema/sema_decl.cpp +++ b/src/execution/sema/sema_decl.cpp @@ -24,6 +24,9 @@ void Sema::VisitVariableDecl(ast::VariableDecl *node) { } if (node->HasInitialValue()) { + if (node->Initial()->GetKind() == ast::AstNode::Kind::LambdaExpr) { + node->Initial()->As()->name_ = node->Name(); + } initializer_type = Resolve(node->Initial()); } diff --git a/src/execution/sema/sema_expr.cpp b/src/execution/sema/sema_expr.cpp index 332bc8f381..bb5284298f 100644 --- a/src/execution/sema/sema_expr.cpp +++ b/src/execution/sema/sema_expr.cpp @@ -83,6 +83,11 @@ void Sema::VisitCallExpr(ast::CallExpr *node) { return; } + // TODO(Kyle): This seems weird + if (node->GetType() != nullptr) { + return; + } + // Resolve the function type ast::Type *type = Resolve(node->Function()); if (type == nullptr) { @@ -91,13 +96,30 @@ void Sema::VisitCallExpr(ast::CallExpr *node) { // Check that the resolved function type is actually a function auto *func_type = type->SafeAs(); + auto *struct_type = type->SafeAs(); + auto lambda_adjustment = 1; if (func_type == nullptr) { - GetErrorReporter()->Report(node->Position(), ErrorMessages::kNonFunction); - return; + if (struct_type != nullptr) { + func_type = struct_type->GetFunctionType(); + // TODO(Kyle): find a better way to see if sema has processed this already + ast::IdentifierExpr *last_arg = nullptr; + if (!node->Arguments().empty()) { + last_arg = node->Arguments().back()->SafeAs(); + } + if (last_arg != nullptr && last_arg->Name() == node->GetFuncName()) { + // already processed + lambda_adjustment = 0; + } + } else { + GetErrorReporter()->Report(node->Position(), ErrorMessages::kNonFunction); + return; + } } // Check argument count matches - if (!CheckArgCount(node, func_type->GetNumParams())) { + // TODO(Kyle): Refactor this, gross. + if (!CheckArgCount( + node, struct_type != nullptr ? func_type->GetNumParams() - lambda_adjustment : func_type->GetNumParams())) { return; } @@ -133,6 +155,10 @@ void Sema::VisitCallExpr(ast::CallExpr *node) { } } + if (struct_type != nullptr && lambda_adjustment > 0) { + node->PushArgument(GetContext()->GetNodeFactory()->NewIdentifierExpr(SourcePosition(), node->GetFuncName())); + } + if (has_errors) { return; } @@ -141,6 +167,7 @@ void Sema::VisitCallExpr(ast::CallExpr *node) { node->SetType(func_type->GetReturnType()); } +// TODO(Kyle): Implement this void Sema::VisitLambdaExpr(ast::LambdaExpr *node) { // make struct type // node->SetType(Resolve(node->GetFunctionLitExpr()->TypeRepr())); @@ -253,6 +280,15 @@ void Sema::VisitFunctionLitExpr(ast::FunctionLitExpr *node) { // The function scope FunctionSemaScope function_scope(this, node); + if (node->IsLambda()) { + auto ¶ms = func_type->GetParams(); + auto captures = params[params.size() - 1]; + auto capture_type = captures.type_->As(); + for (auto field : capture_type->GetFieldsWithoutPadding()) { + GetCurrentScope()->Declare(field.name_, field.type_->GetPointeeType()->ReferenceTo()); + } + } + // Declare function parameters in scope for (const auto ¶m : func_type->GetParams()) { GetCurrentScope()->Declare(param.name_, param.type_); @@ -288,6 +324,11 @@ void Sema::VisitIdentifierExpr(ast::IdentifierExpr *node) { return; } + if (auto *type = GetCurrentScope()->Lookup(node->Name())) { + node->SetType(type); + return; + } + // Error GetErrorReporter()->Report(node->Position(), ErrorMessages::kUndefinedVariable, node->Name()); } diff --git a/src/execution/sema/sema_stmt.cpp b/src/execution/sema/sema_stmt.cpp index 9b004a064f..c7790fef4a 100644 --- a/src/execution/sema/sema_stmt.cpp +++ b/src/execution/sema/sema_stmt.cpp @@ -24,6 +24,11 @@ void Sema::VisitAssignmentStmt(ast::AssignmentStmt *node) { if (source != node->Source()) { node->SetSource(source); } + + if (src_type->IsFunctionType()) { + // this is a lambda function assignment + node->Source()->As()->name_ = node->Destination()->As()->Name(); + } } void Sema::VisitBlockStmt(ast::BlockStmt *node) { @@ -57,6 +62,16 @@ void Sema::VisitForStmt(ast::ForStmt *node) { return; } // If the resolved type isn't a boolean, it's an error + if (cond_type->IsSqlBooleanType()) { + auto context = GetContext(); + auto factory = context->GetNodeFactory(); + auto args = util::RegionVector({node->Condition()}, context->GetRegion()); + node->SetCond(factory->NewBuiltinCallExpr( + factory->NewIdentifierExpr(node->Position(), + GetContext()->GetBuiltinFunction(execution::ast::Builtin::SqlToBool)), + std::move(args))); + cond_type = Resolve(node->Condition()); + } if (!cond_type->IsBoolType()) { error_reporter_->Report(node->Condition()->Position(), ErrorMessages::kNonBoolForCondition); } diff --git a/src/include/execution/sema/scope.h b/src/include/execution/sema/scope.h index 34533825d6..3a3411d43a 100644 --- a/src/include/execution/sema/scope.h +++ b/src/include/execution/sema/scope.h @@ -67,6 +67,16 @@ class Scope { */ ast::Type *LookupLocal(ast::Identifier name) const; + /** + * TODO(Kyle): Document. + */ + Kind GetKind() const; + + /** + * TODO(Kyle): Document. + */ + std::vector> GetLocals() const; + /** * @return the parent scope */ From 6b1ae5997cd90e26c6b389d3a1024aba335acd82 Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Sun, 4 Apr 2021 11:17:15 -0400 Subject: [PATCH 013/139] integrate udf lambda parsing --- src/execution/parsing/parser.cpp | 59 +++++++++++++++++++++- src/execution/parsing/scanner.cpp | 2 + src/include/execution/parsing/parser.h | 4 ++ src/include/execution/parsing/token.h | 1 + src/include/execution/sema/error_message.h | 3 +- 5 files changed, 67 insertions(+), 2 deletions(-) diff --git a/src/execution/parsing/parser.cpp b/src/execution/parsing/parser.cpp index 592c91eaca..f8daaedb29 100644 --- a/src/execution/parsing/parser.cpp +++ b/src/execution/parsing/parser.cpp @@ -425,11 +425,45 @@ ast::Expr *Parser::ParseUnaryOpExpr() { return ParsePrimaryExpr(); } +ast::Expr *Parser::ParseLambdaExpr() { + Expect(Token::Type::LAMBDA); + + const SourcePosition &position = scanner_->CurrentPosition(); + + util::RegionVector captures(Region()); + + Expect(Token::Type::LEFT_BRACKET); + + while (Peek() != Token::Type::RIGHT_BRACKET) { + if (Matches(Token::Type::IDENTIFIER)) { + auto var = GetSymbol(); + captures.push_back(new (Region()) ast::IdentifierExpr(position, var)); + } + + if (!Matches(Token::Type::COMMA)) { + break; + } + } + + Expect(Token::Type::RIGHT_BRACKET); + + // The function literal + auto *fun = ParseFunctionLitExpr()->As(); + + // Create declaration + // ast::FunctionDecl *decl = node_factory_->NewFunctionDecl(position, name fun); + auto *lambda = node_factory_->NewLambdaExpr(position, fun, std::move(captures)); + + // Done + return lambda; +} + ast::Expr *Parser::ParsePrimaryExpr() { - // PrimaryExpr = Operand | CallExpr | MemberExpr | IndexExpr ; + // PrimaryExpr = Operand | CallExpr | MemberExpr | IndexExpr | LambdaExpr ; // CallExpr = PrimaryExpr '(' (Expr)* ') ; // MemberExpr = PrimaryExpr '.' Expr // IndexExpr = PrimaryExpr '[' Expr ']' + // LambdaExpr = lambda (FunctionLitExpr) ast::Expr *result = ParseOperand(); @@ -538,6 +572,10 @@ ast::Expr *Parser::ParseOperand() { Expect(Token::Type::RIGHT_PAREN); return expr; } + case Token::Type::LAMBDA: { + return ParseLambdaExpr(); + break; + } default: { break; } @@ -584,6 +622,9 @@ ast::Expr *Parser::ParseType() { case Token::Type::STRUCT: { return ParseStructType(); } + case Token::Type::LAMBDA: { + return ParseLambdaType(); + } default: { break; } @@ -728,4 +769,20 @@ ast::Expr *Parser::ParseMapType() { return node_factory_->NewMapType(position, key_type, value_type); } +ast::Expr *Parser::ParseLambdaType() { + // LambdaType = 'lambda' '[' FunctionExpr ']' ; + + const SourcePosition &position = scanner_->CurrentPosition(); + + Consume(Token::Type::LAMBDA); + + Expect(Token::Type::LEFT_BRACKET); + + ast::Expr *fn_type = ParseFunctionType(); + + Expect(Token::Type::RIGHT_BRACKET); + + return node_factory_->NewLambdaType(position, fn_type); +} + } // namespace noisepage::execution::parsing diff --git a/src/execution/parsing/scanner.cpp b/src/execution/parsing/scanner.cpp index 16c0e18f12..f349f80bb6 100644 --- a/src/execution/parsing/scanner.cpp +++ b/src/execution/parsing/scanner.cpp @@ -298,6 +298,8 @@ Token::Type Scanner::ScanIdentifierOrKeyword() { GROUP_START('i') \ GROUP_ELEM("if", Token::Type::IF) \ GROUP_ELEM("in", Token::Type::IN) \ + GROUP_START('l') \ + GROUP_ELEM("lambda", Token::Type::LAMBDA) \ GROUP_START('m') \ GROUP_ELEM("map", Token::Type::MAP) \ GROUP_START('n') \ diff --git a/src/include/execution/parsing/parser.h b/src/include/execution/parsing/parser.h index e5a35c1ded..4006034de4 100644 --- a/src/include/execution/parsing/parser.h +++ b/src/include/execution/parsing/parser.h @@ -121,6 +121,8 @@ class Parser { ast::Expr *ParseUnaryOpExpr(); + ast::Expr *ParseLambdaExpr(); + ast::Expr *ParsePrimaryExpr(); ast::Expr *ParseOperand(); @@ -139,6 +141,8 @@ class Parser { ast::Expr *ParseMapType(); + ast::Expr *ParseLambdaType(); + private: // The source code scanner Scanner *scanner_; diff --git a/src/include/execution/parsing/token.h b/src/include/execution/parsing/token.h index 16a8e635dd..91a4129fad 100644 --- a/src/include/execution/parsing/token.h +++ b/src/include/execution/parsing/token.h @@ -64,6 +64,7 @@ namespace noisepage::execution::parsing { K(FUN, "fun", 0) \ K(IF, "if", 0) \ K(IN, "in", 0) \ + K(LAMBDA, "lambda", 0) \ K(MAP, "map", 0) \ K(NIL, "nil", 0) \ K(RETURN, "return", 0) \ diff --git a/src/include/execution/sema/error_message.h b/src/include/execution/sema/error_message.h index 8c0a693384..0b7db9fd46 100644 --- a/src/include/execution/sema/error_message.h +++ b/src/include/execution/sema/error_message.h @@ -95,7 +95,8 @@ namespace sema { "indexIteratorFree() expects (*IndexIterator) argument " \ "types. Received type '%0' in position %1", \ (ast::Type *, uint32_t)) \ - F(IsValNullExpectsSqlValue, "@isValNull() expects a SQL value input, received type '%0'", (ast::Type *)) + F(IsValNullExpectsSqlValue, "@isValNull() expects a SQL value input, received type '%0'", (ast::Type *)) \ + F(NoScopeToBreak, "There is no scope to break from in position", ()) /// Define the ErrorMessageId enumeration enum class ErrorMessageId : uint16_t { From 7fc02a7c6824d7fde09a2464aeb2e4b77c6b39a4 Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Sun, 4 Apr 2021 13:17:42 -0400 Subject: [PATCH 014/139] pull in compiler additions, some places to revisit here --- src/execution/compiler/codegen.cpp | 10 + .../compiler/compilation_context.cpp | 39 ++- src/execution/compiler/executable_query.cpp | 81 ++++-- .../compiler/executable_query_builder.cpp | 2 +- .../expression/expression_translator.cpp | 12 + .../expression/function_translator.cpp | 52 ++++ src/execution/compiler/function_builder.cpp | 46 ++- .../compiler/operator/operator_translator.cpp | 13 + .../compiler/operator/output_translator.cpp | 38 ++- src/execution/compiler/pipeline.cpp | 269 ++++++++++++++++-- src/include/execution/compiler/ast_fwd.h | 1 + src/include/execution/compiler/codegen.h | 13 + .../execution/compiler/compilation_context.h | 27 +- .../execution/compiler/executable_query.h | 62 +++- .../expression/expression_translator.h | 11 + .../compiler/expression/function_translator.h | 19 ++ .../execution/compiler/function_builder.h | 35 ++- .../compiler/operator/operator_translator.h | 18 +- src/include/execution/compiler/pipeline.h | 96 ++++++- 19 files changed, 736 insertions(+), 108 deletions(-) diff --git a/src/execution/compiler/codegen.cpp b/src/execution/compiler/codegen.cpp index 17c191a1db..21db0a7f32 100644 --- a/src/execution/compiler/codegen.cpp +++ b/src/execution/compiler/codegen.cpp @@ -194,6 +194,10 @@ ast::Expr *CodeGen::Float32Type() const { return BuiltinType(ast::BuiltinType::F ast::Expr *CodeGen::Float64Type() const { return BuiltinType(ast::BuiltinType::Float64); } +ast::Expr *CodeGen::LambdaType(ast::Expr *fn_type) { + return context_->GetNodeFactory()->NewLambdaType(position_, fn_type); +} + ast::Expr *CodeGen::PointerType(ast::Expr *base_type_repr) const { // Create the type representation auto *type_repr = context_->GetNodeFactory()->NewPointerType(position_, base_type_repr); @@ -367,6 +371,12 @@ ast::Expr *CodeGen::AccessStructMember(ast::Expr *object, ast::Identifier member return context_->GetNodeFactory()->NewMemberExpr(position_, object, MakeExpr(member)); } +ast::Stmt *CodeGen::Break() { + ast::Stmt *break_stmt = context_->GetNodeFactory()->NewBreakStmt(position_); + NewLine(); + return break_stmt; +} + ast::Stmt *CodeGen::Return() { return Return(nullptr); } ast::Stmt *CodeGen::Return(ast::Expr *ret) { diff --git a/src/execution/compiler/compilation_context.cpp b/src/execution/compiler/compilation_context.cpp index 1f953d4cd5..ab756a16eb 100644 --- a/src/execution/compiler/compilation_context.cpp +++ b/src/execution/compiler/compilation_context.cpp @@ -79,16 +79,21 @@ std::atomic unique_ids{0}; } // namespace CompilationContext::CompilationContext(ExecutableQuery *query, catalog::CatalogAccessor *accessor, - const CompilationMode mode, const exec::ExecutionSettings &settings) + const CompilationMode mode, const exec::ExecutionSettings &settings, + ast::LambdaExpr *output_callback) : unique_id_(unique_ids++), query_(query), mode_(mode), codegen_(query_->GetContext(), accessor), query_state_var_(codegen_.MakeIdentifier("queryState")), - query_state_type_(codegen_.MakeIdentifier("QueryState")), + query_state_type_(codegen_.MakeIdentifier( + output_callback == nullptr ? "QueryState" : output_callback->GetName().GetString() + "QueryState")), query_state_(query_state_type_, [this](CodeGen *codegen) { return codegen->MakeExpr(query_state_var_); }), + output_callback_(output_callback), counters_enabled_(settings.GetIsCountersEnabled()), - pipeline_metrics_enabled_(settings.GetIsPipelineMetricsEnabled()) {} + pipeline_metrics_enabled_(output_callback ? false : settings.GetIsPipelineMetricsEnabled()) {} + +// TODO(Kyle): Why disable pipeline metrics whenever we have an output callback? ast::FunctionDecl *CompilationContext::GenerateInitFunction() { const auto name = codegen_.MakeIdentifier(GetFunctionPrefix() + "_Init"); @@ -155,6 +160,10 @@ void CompilationContext::GeneratePlan(const planner::AbstractPlanNode &plan) { std::vector execution_order; main_pipeline.CollectDependencies(&execution_order); for (auto *pipeline : execution_order) { + if (pipeline->IsPrepared()) { + continue; + } + // Extract and record the translators. // Pipelines require obtaining feature IDs, but features don't exist until translators are extracted. // Therefore translator extraction must happen before pipelines are generated. @@ -173,7 +182,7 @@ void CompilationContext::GeneratePlan(const planner::AbstractPlanNode &plan) { } main_builder.DeclareAll(pipeline_decls); } - pipeline->GeneratePipeline(&main_builder); + pipeline->GeneratePipeline(&main_builder, query_id_t{unique_id_}, output_callback_); } // Register the tear-down function. @@ -190,18 +199,22 @@ void CompilationContext::GeneratePlan(const planner::AbstractPlanNode &plan) { // static std::unique_ptr CompilationContext::Compile(const planner::AbstractPlanNode &plan, const exec::ExecutionSettings &exec_settings, - catalog::CatalogAccessor *accessor, - const CompilationMode mode, - common::ManagedPointer query_text) { + catalog::CatalogAccessor *accessor, CompilationMode mode, + common::ManagedPointer query_text, + ast::LambdaExpr *output_callback, + common::ManagedPointer context) { // The query we're generating code for. - auto query = std::make_unique(plan, exec_settings); + auto query = std::make_unique(plan, exec_settings, context.Get()); // TODO(Lin): Hacking... remove this after getting the counters in query->SetQueryText(query_text); // Generate the plan for the query - CompilationContext ctx(query.get(), accessor, mode, exec_settings); + CompilationContext ctx(query.get(), accessor, mode, exec_settings, output_callback); ctx.GeneratePlan(plan); + // TODO(Kyle): hacking + query->SetQueryStateType(ctx.query_state_.GetType()); + // Done return query; } @@ -219,8 +232,9 @@ void CompilationContext::PrepareOut(const planner::AbstractPlanNode &plan, Pipel } void CompilationContext::Prepare(const planner::AbstractPlanNode &plan, Pipeline *pipeline) { - std::unique_ptr translator; + NOISEPAGE_ASSERT(ops_.find(&plan) == ops_.end(), "plan already prepared"); + std::unique_ptr translator; switch (plan.GetPlanNodeType()) { case planner::PlanNodeType::AGGREGATE: { const auto &aggregation = dynamic_cast(plan); @@ -410,7 +424,10 @@ ExpressionTranslator *CompilationContext::LookupTranslator(const parser::Abstrac return nullptr; } -std::string CompilationContext::GetFunctionPrefix() const { return "Query" + std::to_string(unique_id_); } +std::string CompilationContext::GetFunctionPrefix() const { + return output_callback_ == nullptr ? "Query" + std::to_string(unique_id_) + : output_callback_->GetName().GetString() + "Query" + std::to_string(unique_id_); +} util::RegionVector CompilationContext::QueryParams() const { ast::Expr *state_type = codegen_.PointerType(codegen_.MakeExpr(query_state_type_)); diff --git a/src/execution/compiler/executable_query.cpp b/src/execution/compiler/executable_query.cpp index 75c103504e..cf32ac204d 100644 --- a/src/execution/compiler/executable_query.cpp +++ b/src/execution/compiler/executable_query.cpp @@ -4,6 +4,7 @@ #include "common/error/error_code.h" #include "common/error/exception.h" +#include "execution/ast/ast.h" #include "execution/ast/ast_dump.h" #include "execution/ast/context.h" #include "execution/compiler/compiler.h" @@ -23,8 +24,8 @@ namespace noisepage::execution::compiler { //===----------------------------------------------------------------------===// ExecutableQuery::Fragment::Fragment(std::vector &&functions, std::vector &&teardown_fn, - std::unique_ptr module) - : functions_(std::move(functions)), teardown_fn_(std::move(teardown_fn)), module_(std::move(module)) {} + std::unique_ptr module, ast::File *file) + : functions_(std::move(functions)), teardown_fn_(std::move(teardown_fn)), module_(std::move(module)), file_(file) {} ExecutableQuery::Fragment::~Fragment() = default; @@ -77,27 +78,42 @@ void ExecutableQuery::SetPipelineOperatingUnits(std::unique_ptr("errors_region")), - context_region_(std::make_unique("context_region")), - errors_(std::make_unique(errors_region_.get())), - ast_context_(std::make_unique(context_region_.get(), errors_.get())), - query_state_size_(0), - pipeline_operating_units_(nullptr), - query_id_(query_identifier++) {} +ExecutableQuery::ExecutableQuery(const planner::AbstractPlanNode &plan, const exec::ExecutionSettings &exec_settings, + ast::Context *context) + : plan_{plan}, + exec_settings_{exec_settings}, + context_region_{std::make_unique("context_region")}, + errors_region_{std::make_unique("errors_region")}, + errors_{std::make_unique(errors_region_.get())}, + ast_context_{context}, + query_state_size_{0}, + pipeline_operating_units_{nullptr}, + query_id_{query_identifier++} { + if (ast_context_ == nullptr) { + ast_context_ = new ast::Context(context_region_.get(), errors_.get()); + owned = true; + } else { + owned = false; + } +} ExecutableQuery::ExecutableQuery(const std::string &contents, const common::ManagedPointer exec_ctx, bool is_file, - size_t query_state_size, const exec::ExecutionSettings &exec_settings) + std::size_t query_state_size, const exec::ExecutionSettings &exec_settings, + ast::Context *context) // TODO(WAN): Giant hack for the plan. The whole point is that you have no plan. - : plan_(reinterpret_cast(exec_settings)), exec_settings_(exec_settings) { - context_region_ = std::make_unique("context_region"); - errors_region_ = std::make_unique("error_region"); - errors_ = std::make_unique(errors_region_.get()); - ast_context_ = std::make_unique(context_region_.get(), errors_.get()); - + : plan_{reinterpret_cast(exec_settings)}, + exec_settings_{exec_settings}, + context_region_{std::make_unique("context_region")}, + errors_region_{std::make_unique("error_region")}, + errors_{std::make_unique(errors_region_.get())} { + if (context) { + ast_context_ = context; + owned = false; + } else { + ast_context_ = new ast::Context(context_region_.get(), errors_.get()); + owned = true; + } // Let's scan the source std::string source; if (is_file) { @@ -106,19 +122,21 @@ ExecutableQuery::ExecutableQuery(const std::string &contents, EXECUTION_LOG_ERROR("There was an error reading file '{}': {}", contents, error.message()); return; } - // Copy the source into a temporary, compile, and run source = (*file)->getBuffer().str(); } else { source = contents; } - auto input = Compiler::Input("tpl_source", ast_context_.get(), &source); + auto input = Compiler::Input("tpl_source", ast_context_, &source); auto module = compiler::Compiler::RunCompilationSimple(input); std::vector functions{"main"}; std::vector teardown_functions; - auto fragment = std::make_unique(std::move(functions), std::move(teardown_functions), std::move(module)); + + // TODO(Kyle): bad API + auto fragment = + std::make_unique(std::move(functions), std::move(teardown_functions), std::move(module), nullptr); std::vector> fragments; fragments.emplace_back(std::move(fragment)); @@ -132,8 +150,12 @@ ExecutableQuery::ExecutableQuery(const std::string &contents, } // Needed because we forward-declare classes used as template types to std::unique_ptr<> -ExecutableQuery::~ExecutableQuery() = default; - +ExecutableQuery::~ExecutableQuery() { + // TODO(Kyle): This is a bad ownership model, revisit + if (owned) { + delete ast_context_; + } +} void ExecutableQuery::Setup(std::vector> &&fragments, const std::size_t query_state_size, std::unique_ptr pipeline_operating_units) { NOISEPAGE_ASSERT( @@ -164,10 +186,15 @@ void ExecutableQuery::Run(common::ManagedPointer exec_ct for (const auto &fragment : fragments_) { fragment->Run(query_state.get(), mode); } +} - // We do not currently re-use ExecutionContexts. However, this is unset to help ensure - // we don't *intentionally* retain any dangling pointers. - exec_ctx->SetQueryState(nullptr); +std::vector ExecutableQuery::GetDecls() const { + std::vector decls; + for (auto &f : fragments_) { + auto frag_decls = f->GetFile()->Declarations(); + decls.insert(decls.end(), frag_decls.begin(), frag_decls.end()); + } + return decls; } } // namespace noisepage::execution::compiler diff --git a/src/execution/compiler/executable_query_builder.cpp b/src/execution/compiler/executable_query_builder.cpp index 574790c6eb..e76d130fbf 100644 --- a/src/execution/compiler/executable_query_builder.cpp +++ b/src/execution/compiler/executable_query_builder.cpp @@ -68,7 +68,7 @@ std::unique_ptr ExecutableQueryFragmentBuilder::Compi teardown_names.push_back(decl->Name().GetString()); } return std::make_unique(std::move(step_functions_), std::move(teardown_names), - std::move(module)); + std::move(module), generated_file); } } // namespace noisepage::execution::compiler diff --git a/src/execution/compiler/expression/expression_translator.cpp b/src/execution/compiler/expression/expression_translator.cpp index 043772dced..3b732435a7 100644 --- a/src/execution/compiler/expression/expression_translator.cpp +++ b/src/execution/compiler/expression/expression_translator.cpp @@ -15,4 +15,16 @@ ast::Expr *ExpressionTranslator::GetExecutionContextPtr() const { return compilation_context_->GetExecutionContextPtrFromQueryState(); } +void ExpressionTranslator::DefineHelperFunctions(util::RegionVector *decls) { + for (auto child : expr_.GetChildren()) { + compilation_context_->LookupTranslator(*child)->DefineHelperFunctions(decls); + } +} + +void ExpressionTranslator::DefineHelperStructs(util::RegionVector *decls) { + for (auto child : expr_.GetChildren()) { + compilation_context_->LookupTranslator(*child)->DefineHelperStructs(decls); + } +} + } // namespace noisepage::execution::compiler diff --git a/src/execution/compiler/expression/function_translator.cpp b/src/execution/compiler/expression/function_translator.cpp index 2035a6a7fe..c2cc4d172d 100644 --- a/src/execution/compiler/expression/function_translator.cpp +++ b/src/execution/compiler/expression/function_translator.cpp @@ -1,6 +1,8 @@ #include "execution/compiler/expression/function_translator.h" #include "catalog/catalog_accessor.h" +#include "execution/ast/ast.h" +#include "execution/ast/ast_clone.h" #include "execution/compiler/compilation_context.h" #include "execution/compiler/work_context.h" #include "execution/functions/function_context.h" @@ -34,7 +36,57 @@ ast::Expr *FunctionTranslator::DeriveValue(WorkContext *ctx, const ColumnValuePr params.push_back(derived_expr); } + if (!func_context->IsBuiltin()) { + auto ident_expr = main_fn_; + std::vector args; + for (auto &expr : params) { + args.emplace_back(expr); + } + return GetCodeGen()->Call(ident_expr, std::move(args)); + } + return codegen->CallBuiltin(func_context->GetBuiltin(), params); } +void FunctionTranslator::DefineHelperFunctions(util::RegionVector *decls) { + ExpressionTranslator::DefineHelperFunctions(decls); + auto proc_oid = GetExpressionAs().GetProcOid(); + auto func_context = GetCodeGen()->GetCatalogAccessor()->GetFunctionContext(proc_oid); + if (func_context->IsBuiltin()) { + return; + } + auto *file = reinterpret_cast( + ast::AstClone::Clone(func_context->GetFile(), GetCodeGen()->GetAstContext()->GetNodeFactory(), "", nullptr, + GetCodeGen()->GetAstContext().Get())); + auto udf_decls = file->Declarations(); + main_fn_ = udf_decls.back()->Name(); + size_t num_added = 0; + for (ast::Decl *udf_decl : udf_decls) { + if (udf_decl->IsFunctionDecl()) { + decls->insert(decls->begin() + num_added, udf_decl->As()); + num_added++; + } + } +} + +void FunctionTranslator::DefineHelperStructs(util::RegionVector *decls) { + ExpressionTranslator::DefineHelperStructs(decls); + auto proc_oid = GetExpressionAs().GetProcOid(); + auto func_context = GetCodeGen()->GetCatalogAccessor()->GetFunctionContext(proc_oid); + if (func_context->IsBuiltin()) { + return; + } + auto *file = reinterpret_cast( + ast::AstClone::Clone(func_context->GetFile(), GetCodeGen()->GetAstContext()->GetNodeFactory(), "", nullptr, + GetCodeGen()->GetAstContext().Get())); + auto udf_decls = file->Declarations(); + size_t num_added = 0; + for (ast::Decl *udf_decl : udf_decls) { + if (udf_decl->IsStructDecl()) { + decls->insert(decls->begin() + num_added, udf_decl->As()); + num_added++; + } + } +} + } // namespace noisepage::execution::compiler diff --git a/src/execution/compiler/function_builder.cpp b/src/execution/compiler/function_builder.cpp index a517dd62cf..5413dbbbe5 100644 --- a/src/execution/compiler/function_builder.cpp +++ b/src/execution/compiler/function_builder.cpp @@ -13,7 +13,15 @@ FunctionBuilder::FunctionBuilder(CodeGen *codegen, ast::Identifier name, util::R ret_type_(ret_type), start_(codegen->GetPosition()), statements_(codegen->MakeEmptyBlock()), - decl_(nullptr) {} + is_lambda_(false) {} + +FunctionBuilder::FunctionBuilder(CodeGen *codegen, util::RegionVector &¶ms, ast::Expr *ret_type) + : codegen_(codegen), + params_(std::move(params)), + ret_type_(ret_type), + start_(codegen->GetPosition()), + statements_(codegen->MakeEmptyBlock()), + is_lambda_(true) {} FunctionBuilder::~FunctionBuilder() { Finish(); } @@ -36,8 +44,8 @@ void FunctionBuilder::Append(ast::Expr *expr) { Append(codegen_->GetFactory()->N void FunctionBuilder::Append(ast::VariableDecl *decl) { Append(codegen_->GetFactory()->NewDeclStmt(decl)); } ast::FunctionDecl *FunctionBuilder::Finish(ast::Expr *ret) { - if (decl_ != nullptr) { - return decl_; + if (decl_.fn_decl_ != nullptr) { + return decl_.fn_decl_; } NOISEPAGE_ASSERT(ret == nullptr || statements_->IsEmpty() || !statements_->GetLast()->IsReturnStmt(), @@ -58,10 +66,38 @@ ast::FunctionDecl *FunctionBuilder::Finish(ast::Expr *ret) { // Create the declaration. auto func_lit = codegen_->GetFactory()->NewFunctionLitExpr(func_type, statements_); - decl_ = codegen_->GetFactory()->NewFunctionDecl(start_, name_, func_lit); + decl_.fn_decl_ = codegen_->GetFactory()->NewFunctionDecl(start_, name_, func_lit); + + // Done + return decl_.fn_decl_; +} + +noisepage::execution::ast::LambdaExpr *FunctionBuilder::FinishLambda(util::RegionVector &&captures, + ast::Expr *ret) { + NOISEPAGE_ASSERT(is_lambda_, "Asking to finish a lambda function that's not actually a lambda function"); + if (decl_.lambda_expr_ != nullptr) { + return decl_.lambda_expr_; + } + + NOISEPAGE_ASSERT(ret == nullptr || statements_->IsEmpty() || !statements_->GetLast()->IsReturnStmt(), + "Double-return at end of function. You should either call FunctionBuilder::Finish() " + "with an explicit return expression, or use the factory to manually append a return " + "statement and call FunctionBuilder::Finish() with a null return."); + // Add the return. + if (!statements_->IsEmpty() && !statements_->GetLast()->IsReturnStmt()) { + Append(codegen_->GetFactory()->NewReturnStmt(codegen_->GetPosition(), ret)); + } + // Finalize everything. + statements_->SetRightBracePosition(codegen_->GetPosition()); + // Build the function's type. + auto func_type = codegen_->GetFactory()->NewFunctionType(start_, std::move(params_), ret_type_); + + // Create the declaration. + auto func_lit = codegen_->GetFactory()->NewFunctionLitExpr(func_type, statements_); + decl_.lambda_expr_ = codegen_->GetFactory()->NewLambdaExpr(start_, func_lit, std::move(captures)); // Done - return decl_; + return decl_.lambda_expr_; } } // namespace noisepage::execution::compiler diff --git a/src/execution/compiler/operator/operator_translator.cpp b/src/execution/compiler/operator/operator_translator.cpp index e433699f28..6573a0a1ff 100644 --- a/src/execution/compiler/operator/operator_translator.cpp +++ b/src/execution/compiler/operator/operator_translator.cpp @@ -26,6 +26,7 @@ OperatorTranslator::OperatorTranslator(const planner::AbstractPlanNode &plan, Co pipeline->RegisterStep(this); // Prepare all output expressions. for (const auto &output_column : plan.GetOutputSchema()->GetColumns()) { + compilation_context->SetCurrentOp(this); compilation_context->Prepare(*output_column.GetExpr()); } } @@ -44,6 +45,18 @@ ast::Expr *OperatorTranslator::GetOutput(WorkContext *context, uint32_t attr_idx return context->DeriveValue(*output_expression, this); } +void OperatorTranslator::DefineHelperFunctions(util::RegionVector *decls) { + for (const auto &output_column : GetPlan().GetOutputSchema()->GetColumns()) { + GetCompilationContext()->LookupTranslator(*output_column.GetExpr())->DefineHelperFunctions(decls); + } +} + +void OperatorTranslator::DefineHelperStructs(util::RegionVector *decls) { + for (const auto &output_column : GetPlan().GetOutputSchema()->GetColumns()) { + GetCompilationContext()->LookupTranslator(*output_column.GetExpr())->DefineHelperStructs(decls); + } +} + ast::Expr *OperatorTranslator::GetChildOutput(WorkContext *context, uint32_t child_idx, uint32_t attr_idx) const { // Check valid child. if (child_idx >= plan_.GetChildrenSize()) { diff --git a/src/execution/compiler/operator/output_translator.cpp b/src/execution/compiler/operator/output_translator.cpp index 6c4ac63df2..52d3a264fa 100644 --- a/src/execution/compiler/operator/output_translator.cpp +++ b/src/execution/compiler/operator/output_translator.cpp @@ -18,7 +18,8 @@ OutputTranslator::OutputTranslator(const planner::AbstractPlanNode &plan, Compil Pipeline *pipeline) : OperatorTranslator(plan, compilation_context, pipeline, selfdriving::ExecutionOperatingUnitType::OUTPUT), output_var_(GetCodeGen()->MakeFreshIdentifier("outRow")), - output_struct_(GetCodeGen()->MakeFreshIdentifier("OutputStruct")) { + output_struct_(GetCodeGen()->MakeFreshIdentifier( + "OutputStruct" + std::to_string(compilation_context->GetQueryId().UnderlyingValue()))) { // Prepare the child. compilation_context->Prepare(plan, pipeline); @@ -28,6 +29,10 @@ OutputTranslator::OutputTranslator(const planner::AbstractPlanNode &plan, Compil } void OutputTranslator::InitializePipelineState(const Pipeline &pipeline, FunctionBuilder *function) const { + if (GetCompilationContext()->GetOutputCallback()) { + return; + } + auto exec_ctx = GetExecutionContext(); auto *new_call = GetCodeGen()->CallBuiltin(ast::Builtin::ResultBufferNew, {exec_ctx}); function->Append(GetCodeGen()->Assign(output_buffer_.Get(GetCodeGen()), new_call)); @@ -36,6 +41,10 @@ void OutputTranslator::InitializePipelineState(const Pipeline &pipeline, Functio } void OutputTranslator::TearDownPipelineState(const Pipeline &pipeline, FunctionBuilder *function) const { + if (GetCompilationContext()->GetOutputCallback()) { + return; + } + auto out_buffer = output_buffer_.Get(GetCodeGen()); ast::Expr *alloc_call = GetCodeGen()->CallBuiltin(ast::Builtin::ResultBufferFree, {out_buffer}); function->Append(GetCodeGen()->MakeStmt(alloc_call)); @@ -43,20 +52,37 @@ void OutputTranslator::TearDownPipelineState(const Pipeline &pipeline, FunctionB void OutputTranslator::PerformPipelineWork(noisepage::execution::compiler::WorkContext *context, noisepage::execution::compiler::FunctionBuilder *function) const { - // First generate the call @resultBufferAllocRow(execCtx) auto out_buffer = output_buffer_.Get(GetCodeGen()); - ast::Expr *alloc_call = GetCodeGen()->CallBuiltin(ast::Builtin::ResultBufferAllocOutRow, {out_buffer}); - ast::Expr *cast_call = GetCodeGen()->PtrCast(output_struct_, alloc_call); + ast::Expr *cast_call; + auto callback = GetCompilationContext()->GetOutputCallback(); + if (callback) { + auto output = GetCodeGen()->MakeFreshIdentifier("output_row"); + auto *row_alloc = GetCodeGen()->DeclareVarNoInit(output, GetCodeGen()->MakeExpr(output_struct_)); + function->Append(row_alloc); + cast_call = GetCodeGen()->AddressOf(GetCodeGen()->MakeExpr(output)); + } else { + ast::Expr *alloc_call = GetCodeGen()->CallBuiltin(ast::Builtin::ResultBufferAllocOutRow, {out_buffer}); + cast_call = GetCodeGen()->PtrCast(output_struct_, alloc_call); + } + function->Append(GetCodeGen()->DeclareVar(output_var_, nullptr, cast_call)); const auto child_translator = GetCompilationContext()->LookupTranslator(GetPlan()); // Now fill up the output row // For each column in the output, set out.col_i = col_i + std::vector callback_args{GetExecutionContext()}; for (uint32_t attr_idx = 0; attr_idx < GetPlan().GetOutputSchema()->NumColumns(); attr_idx++) { ast::Identifier attr_name = GetCodeGen()->MakeIdentifier(OUTPUT_COL_PREFIX + std::to_string(attr_idx)); ast::Expr *lhs = GetCodeGen()->AccessStructMember(GetCodeGen()->MakeExpr(output_var_), attr_name); ast::Expr *rhs = child_translator->GetOutput(context, attr_idx); function->Append(GetCodeGen()->Assign(lhs, rhs)); + if (callback) { + callback_args.push_back(lhs); + } + } + + if (callback) { + function->Append(GetCodeGen()->Call(callback->As()->GetName(), std::move(callback_args))); } CounterAdd(function, num_output_, 1); @@ -79,6 +105,10 @@ void OutputTranslator::EndParallelPipelineWork(const Pipeline &pipeline, Functio } void OutputTranslator::FinishPipelineWork(const Pipeline &pipeline, FunctionBuilder *function) const { + if (GetCompilationContext()->GetOutputCallback()) { + return; + } + auto out_buffer = output_buffer_.Get(GetCodeGen()); function->Append(GetCodeGen()->CallBuiltin(ast::Builtin::ResultBufferFinalize, {out_buffer})); diff --git a/src/execution/compiler/pipeline.cpp b/src/execution/compiler/pipeline.cpp index cfa60a5399..3417079ea6 100644 --- a/src/execution/compiler/pipeline.cpp +++ b/src/execution/compiler/pipeline.cpp @@ -15,6 +15,7 @@ #include "loggers/execution_logger.h" #include "metrics/metrics_defs.h" #include "planner/plannodes/abstract_plan_node.h" +#include "planner/plannodes/output_schema.h" #include "spdlog/fmt/fmt.h" namespace noisepage::execution::compiler { @@ -25,14 +26,16 @@ Pipeline::Pipeline(CompilationContext *ctx) : id_(ctx->RegisterPipeline(this)), compilation_context_(ctx), codegen_(compilation_context_->GetCodeGen()), + state_var_(codegen_->MakeIdentifier("pipelineState")), + state_(codegen_->MakeIdentifier(fmt::format("P{}{}_State", ctx->GetFunctionPrefix(), id_)), + [this](CodeGen *codegen) { return codegen_->MakeExpr(state_var_); }), driver_(nullptr), parallelism_(Parallelism::Parallel), check_parallelism_(true), - state_var_(codegen_->MakeIdentifier("pipelineState")), - state_(codegen_->MakeIdentifier(fmt::format("P{}_State", id_)), - [this](CodeGen *codegen) { return codegen_->MakeExpr(state_var_); }) {} + nested_(false) {} -Pipeline::Pipeline(OperatorTranslator *op, Pipeline::Parallelism parallelism) : Pipeline(op->GetCompilationContext()) { +Pipeline::Pipeline(OperatorTranslator *op, Pipeline::Parallelism parallelism, bool consumer) + : Pipeline(op->GetCompilationContext()) { UpdateParallelism(parallelism); RegisterStep(op); } @@ -70,7 +73,8 @@ void Pipeline::RegisterExpression(ExpressionTranslator *expression) { } StateDescriptor::Entry Pipeline::DeclarePipelineStateEntry(const std::string &name, ast::Expr *type_repr) { - return state_.DeclareStateEntry(codegen_, name, type_repr); + auto &state = GetPipelineStateDescriptor(); + return state.DeclareStateEntry(codegen_, name, type_repr); } std::string Pipeline::CreatePipelineFunctionName(const std::string &func_name) const { @@ -144,7 +148,8 @@ util::RegionVector Pipeline::PipelineParams() const { // The main query parameters. util::RegionVector query_params = compilation_context_->QueryParams(); // Tag on the pipeline state. - ast::Expr *pipeline_state = codegen_->PointerType(codegen_->MakeExpr(state_.GetTypeName())); + auto &state = GetPipelineStateDescriptor(); + ast::Expr *pipeline_state = codegen_->PointerType(codegen_->MakeExpr(state.GetTypeName())); query_params.push_back(codegen_->MakeField(state_var_, pipeline_state)); return query_params; } @@ -152,6 +157,28 @@ util::RegionVector Pipeline::PipelineParams() const { void Pipeline::LinkSourcePipeline(Pipeline *dependency) { NOISEPAGE_ASSERT(dependency != nullptr, "Source cannot be null"); dependencies_.push_back(dependency); + if (std::find(dependency->nested_pipelines_.begin(), dependency->nested_pipelines_.end(), this) != + dependency->nested_pipelines_.end()) { + std::remove(dependency->nested_pipelines_.begin(), dependency->nested_pipelines_.end(), this); + } +} + +void Pipeline::LinkNestedPipeline(Pipeline *pipeline, const OperatorTranslator *op) { + NOISEPAGE_ASSERT(pipeline != nullptr, "Nested pipeline cannot be null"); + // if pipeline is in my dependencies let's not do this to avoid circularity + if (std::find(dependencies_.begin(), dependencies_.end(), pipeline) == dependencies_.end()) { + pipeline->nested_pipelines_.push_back(this); + } + if (!pipeline->nested_) { + pipeline->nested_ = true; + // add to pipeline params + size_t i = 0; + for (auto &col : op->GetPlan().GetOutputSchema()->GetColumns()) { + pipeline->extra_pipeline_params_.push_back( + codegen_->MakeField(codegen_->MakeIdentifier("row" + std::to_string(i++)), + codegen_->PointerType(codegen_->TplType(sql::GetTypeId(col.GetType()))))); + } + } } void Pipeline::CollectDependencies(std::vector *deps) { @@ -159,6 +186,20 @@ void Pipeline::CollectDependencies(std::vector *deps) { pipeline->CollectDependencies(deps); } deps->push_back(this); + for (auto *pipeline : nested_pipelines_) { + pipeline->CollectDependencies(deps); + } +} + +void Pipeline::CollectDependencies(std::vector *deps) const { + for (auto *pipeline : dependencies_) { + pipeline->CollectDependencies(deps); + } + + for (auto *pipeline : nested_pipelines_) { + pipeline->CollectDependencies(deps); + } + deps->push_back(this); } void Pipeline::Prepare(const exec::ExecutionSettings &exec_settings) { @@ -167,6 +208,7 @@ void Pipeline::Prepare(const exec::ExecutionSettings &exec_settings) { ast::Expr *type = codegen_->BuiltinType(ast::BuiltinType::ExecOUFeatureVector); oufeatures_ = DeclarePipelineStateEntry("execFeatures", type); } + // if this pipeline is nested, it doesn't own its pipeline state state_.ConstructFinalType(codegen_); // Finalize the execution mode. We choose serial execution if ANY of the below @@ -197,6 +239,8 @@ void Pipeline::Prepare(const exec::ExecutionSettings &exec_settings) { EXECUTION_LOG_TRACE("Pipeline-{}: parallel={}, vectorized={}, steps=[{}]", id_, IsParallel(), IsVectorized(), result); } + + prepared_ = true; } ast::FunctionDecl *Pipeline::GenerateSetupPipelineStateFunction() const { @@ -232,10 +276,44 @@ ast::FunctionDecl *Pipeline::GenerateTearDownPipelineStateFunction() const { return builder.Finish(); } -ast::FunctionDecl *Pipeline::GenerateInitPipelineFunction() const { +ast::FunctionDecl *Pipeline::GeneratePipelineWrapperFunction(ast::LambdaExpr *output_callback) const { + auto name = codegen_->MakeIdentifier(CreatePipelineFunctionName("RunAll")); + auto params = compilation_context_->QueryParams(); + auto run_params = params; + if (output_callback != nullptr) { + run_params.push_back(codegen_->MakeField(output_callback->GetName(), + codegen_->LambdaType(output_callback->GetFunctionLitExpr()->TypeRepr()))); + } + FunctionBuilder builder(codegen_, name, std::move(run_params), codegen_->Nil()); + { + CodeGen::CodeScope code_scope(codegen_); + ast::Identifier p_state = codegen_->MakeFreshIdentifier("pipeline_state"); + builder.Append(codegen_->DeclareVarNoInit(p_state, state_.GetType()->TypeRepr())); + auto query_state_param = builder.GetParameterByPosition(0); + auto p_state_ptr = codegen_->AddressOf(p_state); + auto lambda_call = builder.GetParameterByPosition(1); + builder.Append(codegen_->Call(GetInitPipelineFunctionName(), {query_state_param, p_state_ptr})); + builder.Append(codegen_->Call(GetRunPipelineFunctionName(), {query_state_param, p_state_ptr, lambda_call})); + builder.Append(codegen_->Call(GetTeardownPipelineFunctionName(), {query_state_param, p_state_ptr})); + } + + return builder.Finish(); +} + +ast::FunctionDecl *Pipeline::GenerateInitPipelineFunction(ast::LambdaExpr *output_callback) const { auto query_state = compilation_context_->GetQueryState(); auto name = codegen_->MakeIdentifier(CreatePipelineFunctionName("Init")); - FunctionBuilder builder(codegen_, name, compilation_context_->QueryParams(), codegen_->Nil()); + auto params = compilation_context_->QueryParams(); + ast::FieldDecl *p_state_ptr = nullptr; + auto &state = GetPipelineStateDescriptor(); + uint32_t p_state_ind = 0; + if (nested_ || output_callback) { + p_state_ptr = codegen_->MakeField(codegen_->MakeFreshIdentifier("pipeline_state"), + codegen_->PointerType(codegen_->MakeExpr(state.GetTypeName()))); + params.push_back(p_state_ptr); + p_state_ind = params.size() - 1; + } + FunctionBuilder builder(codegen_, name, std::move(params), codegen_->Nil()); { CodeGen::CodeScope code_scope(codegen_); // var tls = @execCtxGetTLS(exec_ctx) @@ -244,21 +322,34 @@ ast::FunctionDecl *Pipeline::GenerateInitPipelineFunction() const { builder.Append(codegen_->DeclareVarWithInit(tls, codegen_->ExecCtxGetTLS(exec_ctx))); // @tlsReset(tls, @sizeOf(ThreadState), init, tearDown, queryState) ast::Expr *state_ptr = query_state->GetStatePointer(codegen_); - builder.Append(codegen_->TLSReset(codegen_->MakeExpr(tls), state_.GetTypeName(), - GetSetupPipelineStateFunctionName(), GetTearDownPipelineStateFunctionName(), - state_ptr)); + if (!nested_ && !output_callback) { + builder.Append(codegen_->TLSReset(codegen_->MakeExpr(tls), state.GetTypeName(), + GetSetupPipelineStateFunctionName(), GetTearDownPipelineStateFunctionName(), + state_ptr)); + } else { + auto pipeline_state = builder.GetParameterByPosition(p_state_ind); + builder.Append(codegen_->Call(GetSetupPipelineStateFunctionName(), {state_ptr, pipeline_state})); + } } return builder.Finish(); } -ast::FunctionDecl *Pipeline::GeneratePipelineWorkFunction() const { +ast::FunctionDecl *Pipeline::GeneratePipelineWorkFunction(ast::LambdaExpr *output_callback) const { auto params = PipelineParams(); + for (auto field : extra_pipeline_params_) { + params.push_back(field); + } if (IsParallel()) { auto additional_params = driver_->GetWorkerParams(); params.insert(params.end(), additional_params.begin(), additional_params.end()); } + if (output_callback != nullptr) { + params.push_back(codegen_->MakeField(output_callback->GetName(), + codegen_->LambdaType(output_callback->GetFunctionLitExpr()->TypeRepr()))); + } + FunctionBuilder builder(codegen_, GetWorkFunctionName(), std::move(params), codegen_->Nil()); { // Begin a new code scope for fresh variables. @@ -286,10 +377,85 @@ ast::FunctionDecl *Pipeline::GeneratePipelineWorkFunction() const { return builder.Finish(); } -ast::FunctionDecl *Pipeline::GenerateRunPipelineFunction() const { +std::vector Pipeline::CallSingleRunPipelineFunction() const { + NOISEPAGE_ASSERT(!nested_, "can't call a nested pipeline like this"); + return { + codegen_->Call(GetInitPipelineFunctionName(), {compilation_context_->GetQueryState()->GetStatePointer(codegen_)}), + codegen_->Call(GetRunPipelineFunctionName(), {compilation_context_->GetQueryState()->GetStatePointer(codegen_)}), + codegen_->Call(GetTeardownPipelineFunctionName(), + {compilation_context_->GetQueryState()->GetStatePointer(codegen_)})}; +} + +void Pipeline::CallNestedRunPipelineFunction(WorkContext *ctx, const OperatorTranslator *op, + FunctionBuilder *function) const { + std::vector stmts; + auto p_state = codegen_->MakeFreshIdentifier("nested_state"); + auto p_state_ptr = codegen_->AddressOf(p_state); + + std::vector params_vec = {compilation_context_->GetQueryState()->GetStatePointer(codegen_)}; + params_vec.push_back(p_state_ptr); + + for (size_t i = 0; i < op->GetPlan().GetOutputSchema()->GetColumns().size(); i++) { + params_vec.push_back(codegen_->AddressOf(op->GetOutput(ctx, i))); + } + + function->Append(codegen_->DeclareVarNoInit(p_state, codegen_->MakeExpr(GetPipelineStateDescriptor().GetTypeName()))); + function->Append(codegen_->Call(GetInitPipelineFunctionName(), + {compilation_context_->GetQueryState()->GetStatePointer(codegen_), p_state_ptr})); + function->Append(codegen_->Call(GetRunPipelineFunctionName(), params_vec)); + function->Append(codegen_->Call(GetTeardownPipelineFunctionName(), + {compilation_context_->GetQueryState()->GetStatePointer(codegen_), p_state_ptr})); + return; +} + +std::vector Pipeline::CallRunPipelineFunction() const { + std::vector calls; + std::vector pipelines; + CollectDependencies(&pipelines); + for (auto pipeline : pipelines) { + if (!pipeline->nested_ || (pipeline == this)) { + for (auto call : CallSingleRunPipelineFunction()) { + calls.push_back(call); + } + } + } + return calls; +} + +ast::Identifier Pipeline::GetInitPipelineFunctionName() const { + return codegen_->MakeIdentifier(CreatePipelineFunctionName("Init")); +} + +ast::Identifier Pipeline::GetTeardownPipelineFunctionName() const { + return codegen_->MakeIdentifier(CreatePipelineFunctionName("TearDown")); +} + +ast::Identifier Pipeline::GetRunPipelineFunctionName() const { + return codegen_->MakeIdentifier(CreatePipelineFunctionName("Run")); +} + +ast::Expr *Pipeline::GetNestedInputArg(uint32_t index) const { + NOISEPAGE_ASSERT(nested_, "Asking for input arg on non-nested pipeline"); + NOISEPAGE_ASSERT(index < extra_pipeline_params_.size(), + "Asking for input arg on non-nested pipeline that doesn't exist"); + return codegen_->UnaryOp(parsing::Token::Type::STAR, codegen_->MakeExpr(extra_pipeline_params_[index]->Name())); +} + +ast::FunctionDecl *Pipeline::GenerateRunPipelineFunction(query_id_t query_id, ast::LambdaExpr *output_callback) const { bool started_tracker = false; auto name = codegen_->MakeIdentifier(CreatePipelineFunctionName("Run")); - FunctionBuilder builder(codegen_, name, compilation_context_->QueryParams(), codegen_->Nil()); + auto params = compilation_context_->QueryParams(); + if (nested_ || output_callback) { + params.push_back(codegen_->MakeField(state_var_, codegen_->PointerType(state_.GetTypeName()))); + } + for (auto field : extra_pipeline_params_) { + params.push_back(field); + } + if (output_callback) { + params.push_back(codegen_->MakeField(output_callback->GetName(), + codegen_->LambdaType(output_callback->GetFunctionLitExpr()->TypeRepr()))); + } + FunctionBuilder builder(codegen_, name, std::move(params), codegen_->Nil()); { // Begin a new code scope for fresh variables. CodeGen::CodeScope code_scope(codegen_); @@ -304,7 +470,8 @@ ast::FunctionDecl *Pipeline::GenerateRunPipelineFunction() const { // var pipelineState = @tlsGetCurrentThreadState(...) auto exec_ctx = compilation_context_->GetExecutionContextPtrFromQueryState(); auto tls = codegen_->ExecCtxGetTLS(exec_ctx); - auto state = codegen_->TLSAccessCurrentThreadState(tls, state_.GetTypeName()); + auto state_type = GetPipelineStateDescriptor().GetTypeName(); + auto state = codegen_->TLSAccessCurrentThreadState(tls, state_type); builder.Append(codegen_->DeclareVarWithInit(state_var_, state)); // Launch pipeline work. @@ -315,8 +482,19 @@ ast::FunctionDecl *Pipeline::GenerateRunPipelineFunction() const { InjectStartResourceTracker(&builder, false); started_tracker = true; - builder.Append( - codegen_->Call(GetWorkFunctionName(), {builder.GetParameterByPosition(0), codegen_->MakeExpr(state_var_)})); + std::vector args = {builder.GetParameterByPosition(0), codegen_->MakeExpr(state_var_)}; + if (nested_) { + size_t i = args.size(); + ast::Expr *arg = builder.GetParameterByPosition(i++); + while (arg != nullptr) { + args.push_back(arg); + arg = builder.GetParameterByPosition(i++); + } + } + if (output_callback && !nested_) { + args.push_back(codegen_->MakeExpr(output_callback->GetName())); + } + builder.Append(codegen_->Call(GetWorkFunctionName(), std::move(args))); } // TODO(abalakum): This shouldn't actually be dependent on order and the loop can be simplified @@ -334,38 +512,67 @@ ast::FunctionDecl *Pipeline::GenerateRunPipelineFunction() const { return builder.Finish(); } -ast::FunctionDecl *Pipeline::GenerateTearDownPipelineFunction() const { +ast::FunctionDecl *Pipeline::GenerateTearDownPipelineFunction(ast::LambdaExpr *output_callback) const { auto name = codegen_->MakeIdentifier(CreatePipelineFunctionName("TearDown")); - FunctionBuilder builder(codegen_, name, compilation_context_->QueryParams(), codegen_->Nil()); + auto params = compilation_context_->QueryParams(); + ast::FieldDecl *p_state_ptr = nullptr; + auto &state = GetPipelineStateDescriptor(); + uint32_t p_state_index = 0; + if (nested_ || output_callback) { + p_state_ptr = codegen_->MakeField(codegen_->MakeFreshIdentifier("pipeline_state"), + codegen_->PointerType(codegen_->MakeExpr(state.GetTypeName()))); + params.push_back(p_state_ptr); + p_state_index = params.size() - 1; + } + + FunctionBuilder builder(codegen_, name, std::move(params), codegen_->Nil()); { // Begin a new code scope for fresh variables. CodeGen::CodeScope code_scope(codegen_); - // Tear down thread local state if parallel pipeline. - ast::Expr *exec_ctx = compilation_context_->GetExecutionContextPtrFromQueryState(); - builder.Append(codegen_->TLSClear(codegen_->ExecCtxGetTLS(exec_ctx))); + if (!nested_ && !output_callback) { + // Tear down thread local state if parallel pipeline. + ast::Expr *exec_ctx = compilation_context_->GetExecutionContextPtrFromQueryState(); + builder.Append(codegen_->TLSClear(codegen_->ExecCtxGetTLS(exec_ctx))); + + auto call = codegen_->CallBuiltin(ast::Builtin::EnsureTrackersStopped, {exec_ctx}); + builder.Append(codegen_->MakeStmt(call)); + } else { + auto query_state = compilation_context_->GetQueryState(); + auto state_ptr = query_state->GetStatePointer(codegen_); - auto call = codegen_->CallBuiltin(ast::Builtin::EnsureTrackersStopped, {exec_ctx}); - builder.Append(codegen_->MakeStmt(call)); + auto pipeline_state = builder.GetParameterByPosition(p_state_index); + auto call = codegen_->Call(GetTearDownPipelineStateFunctionName(), {state_ptr, pipeline_state}); + builder.Append(codegen_->MakeStmt(call)); + } } return builder.Finish(); } -void Pipeline::GeneratePipeline(ExecutableQueryFragmentBuilder *builder) const { +void Pipeline::GeneratePipeline(ExecutableQueryFragmentBuilder *builder, query_id_t query_id, + ast::LambdaExpr *output_callback) const { // Declare the pipeline state. builder->DeclareStruct(state_.GetType()); - // Generate pipeline state initialization and tear-down functions. builder->DeclareFunction(GenerateSetupPipelineStateFunction()); builder->DeclareFunction(GenerateTearDownPipelineStateFunction()); // Generate main pipeline logic. - builder->DeclareFunction(GeneratePipelineWorkFunction()); + builder->DeclareFunction(GeneratePipelineWorkFunction(output_callback)); + builder->DeclareFunction(GenerateRunPipelineFunction(query_id, output_callback)); + builder->DeclareFunction(GenerateInitPipelineFunction(output_callback)); + auto teardown = GenerateTearDownPipelineFunction(output_callback); + builder->DeclareFunction(teardown); // Register the main init, run, tear-down functions as steps, in that order. - builder->RegisterStep(GenerateInitPipelineFunction()); - builder->RegisterStep(GenerateRunPipelineFunction()); - auto teardown = GenerateTearDownPipelineFunction(); - builder->RegisterStep(teardown); + if (output_callback) { + auto fn = GeneratePipelineWrapperFunction(output_callback); + builder->DeclareFunction(fn); + builder->RegisterStep(fn); + } else if (!nested_) { + builder->RegisterStep(GenerateInitPipelineFunction(output_callback)); + builder->RegisterStep(GenerateRunPipelineFunction(query_id, output_callback)); + builder->RegisterStep(teardown); + } builder->AddTeardownFn(teardown); } diff --git a/src/include/execution/compiler/ast_fwd.h b/src/include/execution/compiler/ast_fwd.h index 1f94ff4ec4..b12eb753ad 100644 --- a/src/include/execution/compiler/ast_fwd.h +++ b/src/include/execution/compiler/ast_fwd.h @@ -12,6 +12,7 @@ class Decl; class FieldDecl; class File; class FunctionDecl; +class LambdaExpr; class Stmt; class StructDecl; class VariableDecl; diff --git a/src/include/execution/compiler/codegen.h b/src/include/execution/compiler/codegen.h index 01c799905c..8fdb10bb6c 100644 --- a/src/include/execution/compiler/codegen.h +++ b/src/include/execution/compiler/codegen.h @@ -195,6 +195,11 @@ class CodeGen { */ [[nodiscard]] ast::Expr *Float64Type() const; + /** + * @return The type representation for a TPL lambda. + */ + [[nodiscard]] ast::Expr *LambdaType(ast::Expr *fn_type); + /** * @return The type representation for the provided builtin type. */ @@ -411,6 +416,14 @@ class CodeGen { */ [[nodiscard]] ast::Expr *AccessStructMember(ast::Expr *object, ast::Identifier member); + // TODO(Kyle): These should be in a different section? + + /** + * Create a break statement. + * @return The statement. + */ + [[nodiscard]] ast::Stmt *Break(); + /** * Create a return statement without a return value. * @return The statement. diff --git a/src/include/execution/compiler/compilation_context.h b/src/include/execution/compiler/compilation_context.h index aaa556c793..c36b7d3e18 100644 --- a/src/include/execution/compiler/compilation_context.h +++ b/src/include/execution/compiler/compilation_context.h @@ -53,7 +53,9 @@ class CompilationContext { const exec::ExecutionSettings &exec_settings, catalog::CatalogAccessor *accessor, CompilationMode mode = CompilationMode::Interleaved, - common::ManagedPointer query_text = nullptr); + common::ManagedPointer query_text = nullptr, + ast::LambdaExpr *output_callback = nullptr, + common::ManagedPointer context = nullptr); /** * Register a pipeline in this context. @@ -117,6 +119,11 @@ class CompilationContext { */ CompilationMode GetCompilationMode() const { return mode_; } + /** + * @return The output callback. + */ + ast::Expr *GetOutputCallback() const { return output_callback_; } + /** @return True if we should collect counters in TPL, used for Lin's models. */ bool IsCountersEnabled() const { return counters_enabled_; } @@ -126,10 +133,20 @@ class CompilationContext { /** @return Query Id associated with the query */ query_id_t GetQueryId() const { return query_id_t{unique_id_}; } + /** + * @brief Set the current op. + */ + void SetCurrentOp(OperatorTranslator *current_op) { current_op_ = current_op; } + + /** + * @return The current op. + */ + OperatorTranslator *GetCurrentOp() const { return current_op_; } + private: // Private to force use of static Compile() function. explicit CompilationContext(ExecutableQuery *query, catalog::CatalogAccessor *accessor, CompilationMode mode, - const exec::ExecutionSettings &exec_settings); + const exec::ExecutionSettings &exec_settings, ast::LambdaExpr *output_callback = nullptr); // Given a plan node, compile it into a compiled query object. void GeneratePlan(const planner::AbstractPlanNode &plan); @@ -164,6 +181,9 @@ class CompilationContext { StateDescriptor query_state_; StateDescriptor::Entry exec_ctx_; + // The output callback. + ast::LambdaExpr *output_callback_; + // The operator and expression translators. std::unordered_map> ops_; std::unordered_map> expressions_; @@ -176,6 +196,9 @@ class CompilationContext { // Whether pipeline metrics are enabled. bool pipeline_metrics_enabled_; + + // The current operator. + OperatorTranslator *current_op_; }; } // namespace noisepage::execution::compiler diff --git a/src/include/execution/compiler/executable_query.h b/src/include/execution/compiler/executable_query.h index 535fd0f7f6..ddc3e92381 100644 --- a/src/include/execution/compiler/executable_query.h +++ b/src/include/execution/compiler/executable_query.h @@ -66,9 +66,10 @@ class ExecutableQuery { * @param functions The name of the functions to execute, in order. * @param teardown_fns The name of the teardown functions in the module, in order. * @param module The module that contains the functions. + * @param file TODO(Kyle): this */ Fragment(std::vector &&functions, std::vector &&teardown_fns, - std::unique_ptr module); + std::unique_ptr module, ast::File *file); /** * Destructor. @@ -87,6 +88,16 @@ class ExecutableQuery { */ bool IsCompiled() const { return module_ != nullptr; } + /** + * @return The functions in the fragment. + */ + const std::vector &GetFunctions() const { return functions_; } + + /** + * @return The file. + */ + ast::File *GetFile() { return file_; }; + private: // The functions that must be run (in the provided order) to execute this // query fragment. @@ -96,14 +107,19 @@ class ExecutableQuery { // The module. std::unique_ptr module_; + + // The file. + ast::File *file_; }; /** * Create a query object. * @param plan The physical plan. * @param exec_settings The execution settings used for this query. + * @param context TODO(Kyle): this */ - ExecutableQuery(const planner::AbstractPlanNode &plan, const exec::ExecutionSettings &exec_settings); + ExecutableQuery(const planner::AbstractPlanNode &plan, const exec::ExecutionSettings &exec_settings, + ast::Context *context = nullptr); /** * This class cannot be copied or moved. @@ -142,7 +158,7 @@ class ExecutableQuery { /** * @return The AST context. */ - ast::Context *GetContext() { return ast_context_.get(); } + ast::Context *GetContext() { return ast_context_; } /** @return The execution settings used for this query. */ const exec::ExecutionSettings &GetExecutionSettings() const { return exec_settings_; } @@ -155,28 +171,61 @@ class ExecutableQuery { /** @return The Query Identifier */ query_id_t GetQueryId() { return query_id_; } + /** @brief Set the query state type */ + void SetQueryStateType(ast::StructDecl *query_state_type) { query_state_type_ = query_state_type; } + + /** @return The query state type */ + ast::StructDecl *GetQueryStateType() const { return query_state_type_; } + /** @param query_text The SQL string for this query */ void SetQueryText(common::ManagedPointer query_text) { query_text_ = query_text; } /** @return The SQL query string */ common::ManagedPointer GetQueryText() { return query_text_; } + /** + * @return The functions. + */ + const std::vector GetFunctions() const { + // TODO(Kyle): string copying, figure out something better + std::vector ret{}; + for (auto &f : fragments_) { + auto fns = f->GetFunctions(); + ret.insert(ret.end(), fns.begin(), fns.end()); + } + return ret; + } + + /** + * @return TODO(Kyle): this. + */ + std::vector GetDecls() const; + private: // The plan. const planner::AbstractPlanNode &plan_; + // The execution settings used for code generation. const exec::ExecutionSettings &exec_settings_; - std::unique_ptr errors_region_; std::unique_ptr context_region_; + std::unique_ptr errors_region_; + // The AST error reporter. std::unique_ptr errors_; + // The AST context used to generate the TPL AST. - std::unique_ptr ast_context_; + ast::Context *ast_context_; + bool owned{true}; + // The compiled query fragments that make up the query. std::vector> fragments_; + // The query state size. std::size_t query_state_size_; + // The type of the query state. + ast::StructDecl *query_state_type_; + // The pipeline operating units that were generated as part of this query. std::unique_ptr pipeline_operating_units_; @@ -184,7 +233,8 @@ class ExecutableQuery { /** Legacy constructor that creates a hardcoded fragment with main(ExecutionContext*)->int32. */ ExecutableQuery(const std::string &contents, common::ManagedPointer exec_ctx, bool is_file, - size_t query_state_size, const exec::ExecutionSettings &exec_settings); + std::size_t query_state_size, const exec::ExecutionSettings &exec_settings, + ast::Context *context = nullptr); /** * Set Pipeline Operating Units for use by mini_runners * @param units Pipeline Operating Units diff --git a/src/include/execution/compiler/expression/expression_translator.h b/src/include/execution/compiler/expression/expression_translator.h index 2e120a09d9..d8130c4c0f 100644 --- a/src/include/execution/compiler/expression/expression_translator.h +++ b/src/include/execution/compiler/expression/expression_translator.h @@ -5,6 +5,7 @@ #include "common/macros.h" #include "execution/ast/ast_fwd.h" #include "execution/compiler/expression/column_value_provider.h" +#include "execution/util/region_containers.h" namespace noisepage::parser { class AbstractExpression; @@ -47,6 +48,16 @@ class ExpressionTranslator { */ virtual ast::Expr *DeriveValue(WorkContext *ctx, const ColumnValueProvider *provider) const = 0; + /** + * TODO(Kyle): this + */ + virtual void DefineHelperFunctions(util::RegionVector *decls); + + /** + * TODO(Kyle): this + */ + virtual void DefineHelperStructs(util::RegionVector *decls); + /** * @return The expression being translated. */ diff --git a/src/include/execution/compiler/expression/function_translator.h b/src/include/execution/compiler/expression/function_translator.h index 6cdcfe18ea..f892509b7c 100644 --- a/src/include/execution/compiler/expression/function_translator.h +++ b/src/include/execution/compiler/expression/function_translator.h @@ -1,6 +1,11 @@ #pragma once #include "execution/compiler/expression/expression_translator.h" +#include "execution/functions/function_context.h" +#include "execution/util/region_containers.h" + +#include +#include namespace noisepage::parser { class FunctionExpression; @@ -27,6 +32,20 @@ class FunctionTranslator : public ExpressionTranslator { * @return The value of the expression. */ ast::Expr *DeriveValue(WorkContext *ctx, const ColumnValueProvider *provider) const override; + + /** + * TODO(Kyle): this. + */ + void DefineHelperFunctions(util::RegionVector *decls) override; + + /** + * TODO(Kyle): this. + */ + void DefineHelperStructs(util::RegionVector *decls) override; + + private: + std::vector params_; + ast::Identifier main_fn_; }; } // namespace noisepage::execution::compiler diff --git a/src/include/execution/compiler/function_builder.h b/src/include/execution/compiler/function_builder.h index 1bfb6cc611..448aa8563c 100644 --- a/src/include/execution/compiler/function_builder.h +++ b/src/include/execution/compiler/function_builder.h @@ -29,6 +29,14 @@ class FunctionBuilder { FunctionBuilder(CodeGen *codegen, ast::Identifier name, util::RegionVector &¶ms, ast::Expr *ret_type); + /** + * Create a builder for a function with the provided return type and arguments. + * @param codegen The code generation instance. + * @param params The parameters to the function. + * @param ret_type The return type representation of the function. + */ + FunctionBuilder(CodeGen *codegen, util::RegionVector &¶ms, ast::Expr *ret_type); + /** * Destructor. */ @@ -64,11 +72,26 @@ class FunctionBuilder { */ ast::FunctionDecl *Finish(ast::Expr *ret = nullptr); + /** + * Finish constructing the lambda. + * @param captures The lambda captures + * @param ret The return value, if present + * @return The lambda expression + */ + noisepage::execution::ast::LambdaExpr *FinishLambda(util::RegionVector &&captures, + ast::Expr *ret = nullptr); + /** * @return The final constructed function; null if the builder hasn't been constructed through * FunctionBuilder::Finish(). */ - ast::FunctionDecl *GetConstructedFunction() const { return decl_; } + ast::FunctionDecl *GetConstructedFunction() const { return decl_.fn_decl_; } + + /** + * @return The final constructed lambda; null if the builder hasn't been constructed through + * FunctionBuilder::FinishLambda(). + */ + ast::LambdaExpr *GetConstructedLambda() const { return decl_.lambda_expr_; } /** * @return The code generator instance. @@ -88,8 +111,16 @@ class FunctionBuilder { SourcePosition start_; // The list of generated statements making up the function. ast::BlockStmt *statements_; + + // `true` if this function is a lambda, `false` otherwise. + bool is_lambda_; + // The cached function declaration. Constructed once in Finish(). - ast::FunctionDecl *decl_; + // TODO(Kyle): This needs to be a variant... + union { + ast::FunctionDecl *fn_decl_{nullptr}; + ast::LambdaExpr *lambda_expr_; + } decl_; }; } // namespace noisepage::execution::compiler diff --git a/src/include/execution/compiler/operator/operator_translator.h b/src/include/execution/compiler/operator/operator_translator.h index ddcc1c07a1..18b8114088 100644 --- a/src/include/execution/compiler/operator/operator_translator.h +++ b/src/include/execution/compiler/operator/operator_translator.h @@ -111,14 +111,14 @@ class OperatorTranslator : public ColumnValueProvider { * declaration container. * @param decls Query-level declarations. */ - virtual void DefineHelperStructs(util::RegionVector *decls) {} + virtual void DefineHelperStructs(util::RegionVector *decls); /** * Define any helper functions required for processing. Ensure they're declared in the provided * declaration container. * @param decls Query-level declarations. */ - virtual void DefineHelperFunctions(util::RegionVector *decls) {} + virtual void DefineHelperFunctions(util::RegionVector *decls); /** * Define any helper functions that rely on pipeline's thread local state. @@ -254,6 +254,17 @@ class OperatorTranslator : public ColumnValueProvider { /** @return The address of the current tuple slot, if any. */ virtual ast::Expr *GetSlotAddress() const { UNREACHABLE("This translator does not deal with tupleslots."); } + /** + * TODO(Kyle): This. + */ + virtual void RegisterNeedValue(const OperatorTranslator *requester, uint32_t child_idx, uint32_t attr_idx) { + UNREACHABLE("not implemented"); + } + + /** @return The pipeline this translator is a part of. */ + // TODO(Kyle): Why did we change visibility of this? Protected to public + Pipeline *GetPipeline() const { return pipeline_; } + protected: /** Get the code generator instance. */ CodeGen *GetCodeGen() const; @@ -270,9 +281,6 @@ class OperatorTranslator : public ColumnValueProvider { /** Get the memory pool pointer from the execution context stored in the query state. */ ast::Expr *GetMemoryPool() const; - /** The pipeline this translator is a part of. */ - Pipeline *GetPipeline() const { return pipeline_; } - /** The plan node for this translator as its concrete type. */ template const T &GetPlanAs() const { diff --git a/src/include/execution/compiler/pipeline.h b/src/include/execution/compiler/pipeline.h index 094ca3619a..7c7ca78912 100644 --- a/src/include/execution/compiler/pipeline.h +++ b/src/include/execution/compiler/pipeline.h @@ -29,6 +29,7 @@ class ExecutableQueryFragmentBuilder; class ExpressionTranslator; class OperatorTranslator; class PipelineDriver; +class WorkContext; /** * A pipeline represents an ordered sequence of relational operators that operate on tuple data @@ -66,8 +67,9 @@ class Pipeline { * Create a pipeline with the given operator as the root. * @param op The root operator of the pipeline. * @param parallelism The operator's requested parallelism. + * @param consumer TODO(Kyle) */ - Pipeline(OperatorTranslator *op, Parallelism parallelism); + Pipeline(OperatorTranslator *op, Parallelism parallelism, bool consumer = false); /** * Register an operator in this pipeline with a customized parallelism configuration. @@ -115,6 +117,12 @@ class Pipeline { */ void LinkSourcePipeline(Pipeline *dependency); + /** + * Registers a nested pipeline. These pipelines are invoked from other pipelines and are not added to the main steps + * @param pipeline The pipeline to nest + */ + void LinkNestedPipeline(Pipeline *pipeline, const OperatorTranslator *op); + /** * Store in the provided output vector the set of all dependencies for this pipeline. In other * words, store in the output vector all pipelines that must execute (in order) before this @@ -132,8 +140,11 @@ class Pipeline { /** * Generate all functions to execute this pipeline in the provided container. * @param builder The builder for the executable query container. + * @param query_id TODO(Kyle) + * @param output_callback TODO(Kyle) */ - void GeneratePipeline(ExecutableQueryFragmentBuilder *builder) const; + void GeneratePipeline(ExecutableQueryFragmentBuilder *builder, query_id_t query_id, + ast::LambdaExpr *output_callback = nullptr) const; /** * @return True if the pipeline is parallel; false otherwise. @@ -175,6 +186,18 @@ class Pipeline { */ std::string CreatePipelineFunctionName(const std::string &func_name) const; + /** + * @return A vector of expressions that initialize, run and teardown a nested pipeline. + */ + std::vector CallSingleRunPipelineFunction() const; + + void CallNestedRunPipelineFunction(WorkContext *ctx, const OperatorTranslator *op, FunctionBuilder *function) const; + + /** + * @return A vector of expressions that do the work of running a pipeline function and its dependencies + */ + std::vector CallRunPipelineFunction() const; + /** * @return Pipeline state variable */ @@ -207,6 +230,12 @@ class Pipeline { */ ast::Expr *OUFeatureVecPtr() const { return oufeatures_.GetPtr(codegen_); } + /** @return TODO(Kyle) */ + ast::Expr *GetNestedInputArg(uint32_t index) const; + + /** @return `true` if this pipeline is prepared, `false` otherwise */ + bool IsPrepared() const { return prepared_; } + private: // Return the thread-local state initialization and tear-down function names. // This is needed when we invoke @tlsReset() from the pipeline initialization @@ -215,6 +244,9 @@ class Pipeline { ast::Identifier GetTearDownPipelineStateFunctionName() const; ast::Identifier GetWorkFunctionName() const; + // TODO(Kyle) this + ast::FunctionDecl *GeneratePipelineWrapperFunction(ast::LambdaExpr *output_callback) const; + // Generate the pipeline state initialization logic. ast::FunctionDecl *GenerateSetupPipelineStateFunction() const; @@ -222,16 +254,19 @@ class Pipeline { ast::FunctionDecl *GenerateTearDownPipelineStateFunction() const; // Generate pipeline initialization logic. - ast::FunctionDecl *GenerateInitPipelineFunction() const; + ast::FunctionDecl *GenerateInitPipelineFunction(ast::LambdaExpr *output_callback) const; // Generate the main pipeline work function. - ast::FunctionDecl *GeneratePipelineWorkFunction() const; + ast::FunctionDecl *GeneratePipelineWorkFunction(ast::LambdaExpr *output_callback) const; // Generate the main pipeline logic. - ast::FunctionDecl *GenerateRunPipelineFunction() const; + ast::FunctionDecl *GenerateRunPipelineFunction(query_id_t query_id, ast::LambdaExpr *output_callback) const; // Generate pipeline tear-down logic. - ast::FunctionDecl *GenerateTearDownPipelineFunction() const; + ast::FunctionDecl *GenerateTearDownPipelineFunction(ast::LambdaExpr *output_callback) const; + + /** @brief TODO(Kyle) */ + void MarkNested() { nested_ = true; } private: // Internals which are exposed for minirunners. @@ -241,6 +276,30 @@ class Pipeline { /** @return The vector of pipeline operators that make up the pipeline. */ const std::vector &GetTranslators() const { return steps_; } + /** @brief TODO(Kyle) */ + void InjectStartPipelineTracker(FunctionBuilder *builder) const; + + /** @brief TODO(Kyle) */ + void InjectEndResourceTracker(FunctionBuilder *builder, query_id_t query_id) const; + + /** @brief TODO(Kyle) */ + ast::Identifier GetRunPipelineFunctionName() const; + + /** @brief TODO(Kyle) */ + void CollectDependencies(std::vector *deps) const; + + /** @brief TODO(Kyle) */ + ast::Identifier GetTeardownPipelineFunctionName() const; + + /** @brief TODO(Kyle) */ + ast::Identifier GetInitPipelineFunctionName() const; + + /** @brief TODO(Kyle) */ + const StateDescriptor &GetPipelineStateDescriptor() const { return state_; } + + /** @brief TODO(Kyle) */ + StateDescriptor &GetPipelineStateDescriptor() { return state_; } + private: // A unique pipeline ID. uint32_t id_; @@ -248,24 +307,33 @@ class Pipeline { CompilationContext *compilation_context_; // The code generation instance. CodeGen *codegen_; + // Cache of common identifiers. + ast::Identifier state_var_; + // The pipeline state. + StateDescriptor state_; + // The pipeline operating unit feature vector state. + StateDescriptor::Entry oufeatures_; // Operators making up the pipeline. std::vector steps_; // The driver. PipelineDriver *driver_; + // pointer to parent pipeline (only applicable if this is a nested pipeline) + Pipeline *parent_; // Expressions participating in the pipeline. std::vector expressions_; + // All unnested pipelines this one depends on completion of. + std::vector dependencies_; + // Vector of pipelines that are nested under this pipeline + std::vector nested_pipelines_; + std::vector extra_pipeline_params_; // Configured parallelism. Parallelism parallelism_; // Whether to check for parallelism in new pipeline elements. bool check_parallelism_; - // All pipelines this one depends on completion of. - std::vector dependencies_; - // Cache of common identifiers. - ast::Identifier state_var_; - // The pipeline state. - StateDescriptor state_; - // The pipeline operating unit feature vector state. - StateDescriptor::Entry oufeatures_; + // Whether or not this is a nested pipeline. + bool nested_; + // Whether or not this pipeline is prepared. + bool prepared_{false}; }; } // namespace noisepage::execution::compiler From 98259fd0c42be811aa12973665d6c90a7cca9218 Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Sun, 4 Apr 2021 16:39:25 -0400 Subject: [PATCH 015/139] integrate vm and ddl executors --- src/execution/sql/ddl_executors.cpp | 109 +++++ src/execution/vm/bytecode_emitter.cpp | 16 + src/execution/vm/bytecode_generator.cpp | 384 +++++++++++++----- src/execution/vm/bytecode_module.cpp | 7 + src/execution/vm/llvm_engine.cpp | 32 +- src/execution/vm/module.cpp | 5 +- src/execution/vm/vm.cpp | 42 +- src/include/catalog/catalog_accessor.h | 2 + src/include/catalog/postgres/pg_language.h | 7 +- src/include/execution/sql/ddl_executors.h | 39 +- src/include/execution/vm/bytecode_emitter.h | 18 + .../execution/vm/bytecode_function_info.h | 20 + src/include/execution/vm/bytecode_generator.h | 16 +- src/include/execution/vm/bytecode_handlers.h | 31 ++ src/include/execution/vm/bytecodes.h | 13 + .../execution/vm/control_flow_builders.h | 12 + 16 files changed, 629 insertions(+), 124 deletions(-) diff --git a/src/execution/sql/ddl_executors.cpp b/src/execution/sql/ddl_executors.cpp index 5e4d5e1e79..d696d0879d 100644 --- a/src/execution/sql/ddl_executors.cpp +++ b/src/execution/sql/ddl_executors.cpp @@ -5,10 +5,21 @@ #include #include "catalog/catalog_accessor.h" +#include "catalog/postgres/pg_language.h" #include "common/macros.h" +#include "execution/ast/ast_pretty_print.h" +#include "execution/ast/context.h" +#include "execution/ast/udf/udf_ast_context.h" +#include "execution/compiler/codegen.h" +#include "execution/compiler/function_builder.h" +#include "execution/compiler/udf/udf_codegen.h" #include "execution/exec/execution_context.h" +#include "execution/sema/sema.h" +#include "loggers/execution_logger.h" #include "parser/expression/column_value_expression.h" +#include "parser/udf/udf_parser.h" #include "planner/plannodes/create_database_plan_node.h" +#include "planner/plannodes/create_function_plan_node.h" #include "planner/plannodes/create_index_plan_node.h" #include "planner/plannodes/create_namespace_plan_node.h" #include "planner/plannodes/create_table_plan_node.h" @@ -33,6 +44,104 @@ bool DDLExecutors::CreateNamespaceExecutor(const common::ManagedPointerCreateNamespace(node->GetNamespaceName()) != catalog::INVALID_NAMESPACE_OID; } +bool DDLExecutors::CreateFunctionExecutor(const common::ManagedPointer node, + const common::ManagedPointer accessor) { + // Request permission from the Catalog to see if this a valid namespace name + NOISEPAGE_ASSERT(node->GetUDFLanguage() == parser::PLType::PL_PGSQL, "Unsupported language"); + NOISEPAGE_ASSERT(node->GetFunctionBody().size() >= 1, "Unsupported function body?"); + + // I don't like how we have to separate the two here + std::vector param_type_ids; + std::vector param_types; + for (auto t : node->GetFunctionParameterTypes()) { + param_type_ids.push_back(parser::FuncParameter::DataTypeToTypeId(t)); + param_types.push_back(accessor->GetTypeOidFromTypeId(parser::FuncParameter::DataTypeToTypeId(t))); + } + auto body = node->GetFunctionBody().front(); + auto proc_id = accessor->CreateProcedure( + node->GetFunctionName(), catalog::postgres::PgLanguage::PLPGSQL_LANGUAGE_OID, node->GetNamespaceOid(), + node->GetFunctionParameterNames(), param_types, param_types, {}, + accessor->GetTypeOidFromTypeId(parser::ReturnType::DataTypeToTypeId(node->GetReturnType())), body, false); + if (proc_id == catalog::INVALID_PROC_OID) { + return false; + } + + // make the context here using the body + ast::udf::UDFASTContext udf_ast_context{}; + // parser::udf::UDFContext udf_context; + parser::udf::PLpgSQLParser udf_parser((common::ManagedPointer(&udf_ast_context)), accessor, node->GetDatabaseOid()); + std::unique_ptr ast{}; + try { + ast = udf_parser.ParsePLpgSQL(node->GetFunctionParameterNames(), std::move(param_type_ids), body, + (common::ManagedPointer(&udf_ast_context))); + } catch (Exception &e) { + return false; + } + + auto region = new util::Region(node->GetFunctionName()); + sema::ErrorReporter error_reporter(region); + auto ast_context = new ast::Context(region, &error_reporter); + + compiler::CodeGen codegen(ast_context, accessor.Get()); + util::RegionVector fn_params{codegen.GetAstContext()->GetRegion()}; + // auto ret_name = parser::udf::UDFCodegen::GetReturnParamString(); + // auto ret_type = parser::ReturnType::DataTypeToTypeId(node->GetReturnType()); + // fn_params.emplace_back(codegen.MakeField(ast::Identifier{ret_name}, + // codegen.PointerType(codegen.TplType(ret_type)))); + fn_params.emplace_back( + codegen.MakeField(codegen.MakeFreshIdentifier("executionCtx"), + codegen.PointerType(codegen.BuiltinType(ast::BuiltinType::ExecutionContext)))); + + for (size_t i = 0; i < node->GetFunctionParameterNames().size(); i++) { + auto name = node->GetFunctionParameterNames()[i]; + auto type = parser::ReturnType::DataTypeToTypeId(node->GetFunctionParameterTypes()[i]); + // auto name_alloc = reinterpret_cast(codegen.GetAstContext()->GetRegion()->Allocate(name.length()+1)); + // std::memcpy(name_alloc, name.c_str(), name.length() + 1); + fn_params.emplace_back(codegen.MakeField(ast_context->GetIdentifier(name), + // codegen.PointerType( + codegen.TplType(execution::sql::GetTypeId(type)) + // ) + )); + } + + auto name = node->GetFunctionName(); + // char *name_alloc = reinterpret_cast(codegen.GetAstContext()->GetRegion()->Allocate(name.length() + 1)); + // std::memcpy(name_alloc, name.c_str(), name.length() + 1); + + compiler::FunctionBuilder fb{ + &codegen, codegen.MakeFreshIdentifier(name), std::move(fn_params), + // codegen.PointerType( + codegen.TplType(execution::sql::GetTypeId(parser::ReturnType::DataTypeToTypeId(node->GetReturnType()))) + // ) + }; + compiler::udf::UDFCodegen udf_codegen{accessor.Get(), &fb, &udf_ast_context, &codegen, node->GetDatabaseOid()}; + udf_codegen.GenerateUDF(ast->body.get()); + auto fn = fb.Finish(); + //// util::RegionVector decls_reg_vec{decls->begin(), decls->end(), codegen.Region()}; + util::RegionVector decls({fn}, codegen.GetAstContext()->GetRegion()); + auto file = udf_codegen.Finish(); + + { + sema::Sema type_check(codegen.GetAstContext().Get()); + type_check.GetErrorReporter()->Reset(); + type_check.Run(file); + EXECUTION_LOG_ERROR("Errors: \n {}", type_check.GetErrorReporter()->SerializeErrors()); + execution::ast::AstPrettyPrint::Dump(std::cout, file); + // NOISEPAGE_ASSERT(!bad, "bad function"); + } + + auto udf_context = new functions::FunctionContext( + node->GetFunctionName(), parser::ReturnType::DataTypeToTypeId(node->GetReturnType()), std::move(param_type_ids), + std::unique_ptr(region), std::unique_ptr(ast_context), file); + if (!accessor->SetFunctionContextPointer(proc_id, udf_context)) { + delete udf_context; + return false; + } + + accessor->GetTxn()->RegisterAbortAction([=]() { delete udf_context; }); + return true; +} + bool DDLExecutors::CreateTableExecutor(const common::ManagedPointer node, const common::ManagedPointer accessor, const catalog::db_oid_t connection_db) { diff --git a/src/execution/vm/bytecode_emitter.cpp b/src/execution/vm/bytecode_emitter.cpp index 3bf53d40f3..c8bc2f3d03 100644 --- a/src/execution/vm/bytecode_emitter.cpp +++ b/src/execution/vm/bytecode_emitter.cpp @@ -25,6 +25,10 @@ void BytecodeEmitter::EmitAssign(Bytecode bytecode, LocalVar dest, LocalVar src) EmitAll(bytecode, dest, src); } +void BytecodeEmitter::EmitAssignN(LocalVar dest, LocalVar src, uint32_t len) { + EmitAll(Bytecode::AssignN, dest, src.AddressOf(), len); +} + void BytecodeEmitter::EmitAssignImm1(LocalVar dest, int8_t val) { EmitAll(Bytecode::AssignImm1, dest, val); } void BytecodeEmitter::EmitAssignImm2(LocalVar dest, int16_t val) { EmitAll(Bytecode::AssignImm2, dest, val); } @@ -62,6 +66,18 @@ void BytecodeEmitter::EmitCall(FunctionId func_id, const std::vector & } } +std::function BytecodeEmitter::DeferedEmitCall(const std::vector ¶ms) { + NOISEPAGE_ASSERT(Bytecodes::GetNthOperandSize(Bytecode::Call, 1) == OperandSize::Short, + "Expected argument count to be 2-byte short"); + NOISEPAGE_ASSERT(params.size() < std::numeric_limits::max(), "Too many parameters!"); + auto bc_insert_index = bytecode_->size() + sizeof(Bytecode); + EmitAll(Bytecode::Call, std::numeric_limits::max(), static_cast(params.size())); + for (LocalVar local : params) { + EmitImpl(local); + } + return [=](FunctionId func_id) { EmitScalarValue(static_cast(func_id), bc_insert_index); }; +} + void BytecodeEmitter::EmitReturn() { EmitImpl(Bytecode::Return); } void BytecodeEmitter::Bind(BytecodeLabel *label) { diff --git a/src/execution/vm/bytecode_generator.cpp b/src/execution/vm/bytecode_generator.cpp index 44f4a189a3..da9f34103b 100644 --- a/src/execution/vm/bytecode_generator.cpp +++ b/src/execution/vm/bytecode_generator.cpp @@ -135,6 +135,7 @@ class BytecodeGenerator::BytecodePositionScope { // Bytecode Generator begins // --------------------------------------------------------- +// TODO(Kyle): reserve here on functions? BytecodeGenerator::BytecodeGenerator() noexcept : emitter_(&code_) {} void BytecodeGenerator::VisitIfStmt(ast::IfStmt *node) { @@ -164,7 +165,8 @@ void BytecodeGenerator::VisitIterationStatement(ast::IterationStmt *iteration, L } void BytecodeGenerator::VisitForStmt(ast::ForStmt *node) { - LoopBuilder loop_builder(this); + LoopBuilder *prev = current_loop_; + LoopBuilder loop_builder{this, prev}; if (node->Init() != nullptr) { Visit(node->Init()); @@ -177,7 +179,9 @@ void BytecodeGenerator::VisitForStmt(ast::ForStmt *node) { VisitExpressionForTest(node->Condition(), &loop_body_label, loop_builder.GetBreakLabel(), TestFallthrough::Then); } + current_loop_ = &loop_builder; VisitIterationStatement(node, &loop_builder); + current_loop_ = prev; if (node->Next() != nullptr) { Visit(node->Next()); @@ -205,7 +209,8 @@ void BytecodeGenerator::VisitFunctionDecl(ast::FunctionDecl *node) { auto *func_type = node->TypeRepr()->GetType()->As(); // Allocate the function - FunctionInfo *func_info = AllocateFunc(node->Name().GetData(), func_type); + auto *func_info = AllocateFunc(node->Name().GetData(), func_type); + EnterFunction(func_info->GetId()); { // Visit the body of the function. We use this handy scope object to track @@ -215,6 +220,11 @@ void BytecodeGenerator::VisitFunctionDecl(ast::FunctionDecl *node) { BytecodePositionScope position_scope(this, func_info); Visit(node->Function()); } + + // TODO(Kyle): what is this doing? + for (auto f : func_info->actions_) { + f(); + } } void BytecodeGenerator::VisitLambdaExpr(ast::LambdaExpr *node) { @@ -279,6 +289,8 @@ void BytecodeGenerator::VisitLambdaExpr(ast::LambdaExpr *node) { // }); } +// TODO(Kyle): Do we need a VisitLambdaDecl()? + void BytecodeGenerator::VisitLambdaTypeRepr(ast::LambdaTypeRepr *node) { UNREACHABLE("Should not visit type-representation nodes!"); } @@ -290,9 +302,58 @@ void BytecodeGenerator::VisitIdentifierExpr(ast::IdentifierExpr *node) { const std::string local_name = node->Name().GetData(); LocalVar local = GetCurrentFunction()->LookupLocal(local_name); + std::string suffix{}; + bool capture = false; + + if (local.IsInvalid() && GetCurrentFunction()->is_lambda_) { + local = GetCurrentFunction()->LookupLocal(local_name + "ptr").ValueOf(); + suffix = "ptr"; + if (!local.IsInvalid()) { + if (GetExecutionResult()->IsRValue()) { + auto local_val = GetCurrentFunction()->NewLocal(node->GetType(), ""); + GetEmitter()->EmitDerefN(local_val, local.ValueOf(), node->GetType()->GetSize()); + local = local_val; + } + } + } + + if (local.IsInvalid()) { + NOISEPAGE_ASSERT(GetCurrentFunction()->is_lambda_, "Not a lambda and variable not found"); + + // TODO(Kyle): modularize this fetch of capture struct + auto params = GetCurrentFunction()->func_type_->GetParams(); + auto captures = GetCurrentFunction()->func_type_->GetCapturesType(); + for (auto field : captures->GetFieldsWithoutPadding()) { + // TODO(Kyle): cache these + if (field.name_.GetString() == local_name) { + auto captures_local = GetCurrentFunction()->LookupLocal("captures"); + auto local_ptr = GetCurrentFunction()->NewLocal(field.type_->PointerTo()); + GetEmitter()->EmitLea(local_ptr, captures_local.ValueOf(), captures->GetOffsetOfFieldByName(field.name_)); + auto local_ptr_2 = GetCurrentFunction()->NewLocal(field.type_, local_name + "ptr"); + GetEmitter()->EmitDerefN(local_ptr_2, local_ptr.ValueOf(), field.type_->GetSize()); + local = local_ptr_2; + suffix = "ptr"; + if (GetExecutionResult()->IsRValue()) { + local = GetCurrentFunction()->NewLocal(field.type_->GetPointeeType(), ""); + GetEmitter()->EmitDerefN(local, local_ptr_2.ValueOf(), field.type_->GetPointeeType()->GetSize()); + suffix = "val"; + } + local = local.ValueOf(); + break; + } + } + capture = true; + } + NOISEPAGE_ASSERT(!local.IsInvalid(), "Local not found"); if (GetExecutionResult()->IsLValue()) { - GetExecutionResult()->SetDestination(local); + // TODO(Kyle): crappy names + auto *local_info_2 = GetCurrentFunction()->LookupLocalInfoByOffset(local.GetOffset()); + if (local_info_2->GetType()->IsPointerType() && local_info_2->GetType()->GetPointeeType()->IsSqlValueType()) { + GetExecutionResult()->SetDestination(local.ValueOf()); + } else { + GetExecutionResult()->SetDestination(local); + } return; } @@ -310,34 +371,39 @@ void BytecodeGenerator::VisitIdentifierExpr(ast::IdentifierExpr *node) { // If the local we want the R-Value of is a parameter, we can't take its // pointer for the deref, so we use an assignment. Otherwise, a deref is good. - if (auto *local_info = GetCurrentFunction()->LookupLocalInfoByName(local_name); local_info->IsParameter()) { - BuildAssign(dest, local.ValueOf(), node->GetType()); + auto *local_info = GetCurrentFunction()->LookupLocalInfoByOffset(local.GetOffset()); + if (local_info->IsParameter()) { + if (local_info->GetType()->IsPointerType() && local_info->GetType()->GetPointeeType()->IsSqlValueType() && + GetExecutionResult()->IsRValue()) { + BuildDeref(dest, local.ValueOf(), node->GetType()); + } else { + BuildAssign(dest, local.ValueOf(), node->GetType()); + } } else { BuildDeref(dest, local, node->GetType()); } - GetExecutionResult()->SetDestination(dest); + GetExecutionResult()->SetDestination(capture ? dest.ValueOf() : dest); } void BytecodeGenerator::VisitImplicitCastExpr(ast::ImplicitCastExpr *node) { + LocalVar input = VisitExpressionForRValue(node->Input()); + switch (node->GetCastKind()) { case ast::CastKind::SqlBoolToBool: { LocalVar dest = GetExecutionResult()->GetOrCreateDestination(node->GetType()); - LocalVar input = VisitExpressionForSQLValue(node->Input()); GetEmitter()->Emit(Bytecode::ForceBoolTruth, dest, input); GetExecutionResult()->SetDestination(dest.ValueOf()); break; } case ast::CastKind::BoolToSqlBool: { LocalVar dest = GetExecutionResult()->GetOrCreateDestination(node->GetType()); - LocalVar input = VisitExpressionForRValue(node->Input()); GetEmitter()->Emit(Bytecode::InitBool, dest, input); GetExecutionResult()->SetDestination(dest); break; } case ast::CastKind::IntToSqlInt: { LocalVar dest = GetExecutionResult()->GetOrCreateDestination(node->GetType()); - LocalVar input = VisitExpressionForRValue(node->Input()); ast::Expr *arg = node->Input(); Bytecode bytecode = Bytecode::InitInteger; @@ -355,7 +421,6 @@ void BytecodeGenerator::VisitImplicitCastExpr(ast::ImplicitCastExpr *node) { } case ast::CastKind::BitCast: case ast::CastKind::IntegralCast: { - LocalVar input = VisitExpressionForRValue(node->Input()); // As an optimization, we only issue a new assignment if the input and // output types of the cast have different sizes. if (node->Input()->GetType()->GetSize() != node->GetType()->GetSize()) { @@ -369,15 +434,13 @@ void BytecodeGenerator::VisitImplicitCastExpr(ast::ImplicitCastExpr *node) { } case ast::CastKind::FloatToSqlReal: { LocalVar dest = GetExecutionResult()->GetOrCreateDestination(node->GetType()); - LocalVar input = VisitExpressionForRValue(node->Input()); - GetEmitter()->Emit(Bytecode::InitReal, dest, input); + GetEmitter()->Emit(Bytecode::InitReal, dest.AddressOf(), input.AddressOf()); GetExecutionResult()->SetDestination(dest); break; } case ast::CastKind::SqlIntToSqlReal: { LocalVar dest = GetExecutionResult()->GetOrCreateDestination(node->GetType()); - LocalVar input = VisitExpressionForSQLValue(node->Input()); - GetEmitter()->Emit(Bytecode::IntegerToReal, dest, input); + GetEmitter()->Emit(Bytecode::IntegerToReal, dest.AddressOf(), input.AddressOf()); GetExecutionResult()->SetDestination(dest); break; } @@ -540,7 +603,7 @@ void BytecodeGenerator::VisitLogicalNotExpr(ast::UnaryOpExpr *op) { GetEmitter()->EmitUnaryOp(Bytecode::Not, dest, input); GetExecutionResult()->SetDestination(dest.ValueOf()); } else if (op->GetType()->IsSqlBooleanType()) { - input = VisitExpressionForSQLValue(op->Input()); + input = VisitExpressionForLValue(op->Input()); GetEmitter()->EmitUnaryOp(Bytecode::NotSql, dest, input); GetExecutionResult()->SetDestination(dest); } @@ -575,7 +638,7 @@ void BytecodeGenerator::VisitReturnStmt(ast::ReturnStmt *node) { if (node->Ret() != nullptr) { LocalVar rv = GetCurrentFunction()->GetReturnValueLocal(); if (node->Ret()->GetType()->IsSqlValueType()) { - LocalVar result = VisitExpressionForSQLValue(node->Ret()); + LocalVar result = VisitExpressionForLValue(node->Ret()); BuildDeref(rv.ValueOf(), result, node->Ret()->GetType()); } else { LocalVar result = VisitExpressionForRValue(node->Ret()); @@ -648,17 +711,17 @@ void BytecodeGenerator::VisitSqlConversionCall(ast::CallExpr *call, ast::Builtin break; } case ast::Builtin::SqlToBool: { - auto input = VisitExpressionForSQLValue(call->Arguments()[0]); + auto input = VisitExpressionForLValue(call->Arguments()[0]); GetEmitter()->Emit(Bytecode::ForceBoolTruth, dest, input); GetExecutionResult()->SetDestination(dest.ValueOf()); break; } -#define GEN_CASE(Builtin, Bytecode) \ - case Builtin: { \ - auto input = VisitExpressionForSQLValue(call->Arguments()[0]); \ - GetEmitter()->Emit(Bytecode, dest, input); \ - break; \ +#define GEN_CASE(Builtin, Bytecode) \ + case Builtin: { \ + auto input = VisitExpressionForRValue(call->Arguments()[0]); \ + GetEmitter()->Emit(Bytecode, dest, input); \ + break; \ } GEN_CASE(ast::Builtin::ConvertBoolToInteger, Bytecode::BoolToInteger); GEN_CASE(ast::Builtin::ConvertIntegerToReal, Bytecode::IntegerToReal); @@ -680,7 +743,7 @@ void BytecodeGenerator::VisitNullValueCall(ast::CallExpr *call, UNUSED_ATTRIBUTE switch (builtin) { case ast::Builtin::IsValNull: { LocalVar result = GetExecutionResult()->GetOrCreateDestination(call->GetType()); - LocalVar input = VisitExpressionForSQLValue(call->Arguments()[0]); + LocalVar input = VisitExpressionForLValue(call->Arguments()[0]); GetEmitter()->Emit(Bytecode::ValIsNull, result, input); GetExecutionResult()->SetDestination(result.ValueOf()); break; @@ -703,15 +766,15 @@ void BytecodeGenerator::VisitNullValueCall(ast::CallExpr *call, UNUSED_ATTRIBUTE void BytecodeGenerator::VisitSqlStringLikeCall(ast::CallExpr *call) { auto dest = GetExecutionResult()->GetOrCreateDestination(call->GetType()); - auto input = VisitExpressionForSQLValue(call->Arguments()[0]); - auto pattern = VisitExpressionForSQLValue(call->Arguments()[1]); + auto input = VisitExpressionForLValue(call->Arguments()[0]); + auto pattern = VisitExpressionForLValue(call->Arguments()[1]); GetEmitter()->Emit(Bytecode::Like, dest, input, pattern); GetExecutionResult()->SetDestination(dest); } void BytecodeGenerator::VisitBuiltinDateFunctionCall(ast::CallExpr *call, ast::Builtin builtin) { auto dest = GetExecutionResult()->GetOrCreateDestination(call->GetType()); - auto input = VisitExpressionForSQLValue(call->Arguments()[0]); + auto input = VisitExpressionForLValue(call->Arguments()[0]); auto date_type = sql::DatePartType(call->Arguments()[1]->As()->Arguments()[0]->As()->Int64Val()); @@ -909,13 +972,13 @@ void BytecodeGenerator::VisitBuiltinVPICall(ast::CallExpr *call, ast::Builtin bu #define GEN_CASE(BuiltinName, Bytecode) \ case ast::Builtin::BuiltinName: { \ - auto input = VisitExpressionForSQLValue(call->Arguments()[1]); \ + auto input = VisitExpressionForLValue(call->Arguments()[1]); \ auto col_idx = call->Arguments()[2]->As()->Int64Val(); \ GetEmitter()->EmitVPISet(Bytecode, vpi, input, col_idx); \ break; \ } \ case ast::Builtin::BuiltinName##Null: { \ - auto input = VisitExpressionForSQLValue(call->Arguments()[1]); \ + auto input = VisitExpressionForLValue(call->Arguments()[1]); \ auto col_idx = call->Arguments()[2]->As()->Int64Val(); \ GetEmitter()->EmitVPISet(Bytecode##Null, vpi, input, col_idx); \ break; \ @@ -953,7 +1016,7 @@ void BytecodeGenerator::VisitBuiltinHashCall(ast::CallExpr *call) { for (uint32_t idx = 0; idx < call->NumArgs(); idx++) { NOISEPAGE_ASSERT(call->Arguments()[idx]->GetType()->IsSqlValueType(), "Input to hash must be a SQL value type"); - LocalVar input = VisitExpressionForSQLValue(call->Arguments()[idx]); + LocalVar input = VisitExpressionForLValue(call->Arguments()[idx]); const auto *type = call->Arguments()[idx]->GetType()->As(); switch (type->GetKind()) { case ast::BuiltinType::Integer: @@ -1026,7 +1089,7 @@ void BytecodeGenerator::VisitBuiltinVectorFilterCall(ast::CallExpr *call, ast::B #define GEN_CASE(BYTECODE) \ LocalVar left_col = VisitExpressionForRValue(call->Arguments()[2]); \ if (!call->Arguments()[3]->GetType()->IsIntegerType()) { \ - LocalVar right_val = VisitExpressionForSQLValue(call->Arguments()[3]); \ + LocalVar right_val = VisitExpressionForLValue(call->Arguments()[3]); \ GetEmitter()->Emit(BYTECODE##Val, exec_ctx, vector_projection, left_col, right_val, tid_list); \ } else { \ LocalVar right_col = VisitExpressionForRValue(call->Arguments()[3]); \ @@ -1907,7 +1970,7 @@ void BytecodeGenerator::VisitBuiltinThreadStateContainerCall(ast::CallExpr *call void BytecodeGenerator::VisitBuiltinTrigCall(ast::CallExpr *call, ast::Builtin builtin) { LocalVar dest = GetExecutionResult()->GetOrCreateDestination(call->GetType()); - LocalVar src = VisitExpressionForSQLValue(call->Arguments()[0]); + LocalVar src = VisitExpressionForLValue(call->Arguments()[0]); switch (builtin) { case ast::Builtin::ACos: { @@ -1935,7 +1998,7 @@ void BytecodeGenerator::VisitBuiltinTrigCall(ast::CallExpr *call, ast::Builtin b break; } case ast::Builtin::ATan2: { - LocalVar src2 = VisitExpressionForSQLValue(call->Arguments()[1]); + LocalVar src2 = VisitExpressionForLValue(call->Arguments()[1]); GetEmitter()->Emit(Bytecode::Atan2, dest, src, src2); break; } @@ -1976,7 +2039,7 @@ void BytecodeGenerator::VisitBuiltinTrigCall(ast::CallExpr *call, ast::Builtin b break; } case ast::Builtin::Exp: { - src = VisitExpressionForSQLValue(call->Arguments()[1]); + src = VisitExpressionForLValue(call->Arguments()[1]); GetEmitter()->Emit(Bytecode::Exp, dest, src); break; } @@ -1993,12 +2056,12 @@ void BytecodeGenerator::VisitBuiltinTrigCall(ast::CallExpr *call, ast::Builtin b break; } case ast::Builtin::Round2: { - LocalVar src2 = VisitExpressionForSQLValue(call->Arguments()[1]); + LocalVar src2 = VisitExpressionForLValue(call->Arguments()[1]); GetEmitter()->Emit(Bytecode::Round2, dest, src, src2); break; } case ast::Builtin::Pow: { - LocalVar src2 = VisitExpressionForSQLValue(call->Arguments()[1]); + LocalVar src2 = VisitExpressionForLValue(call->Arguments()[1]); GetEmitter()->Emit(Bytecode::Pow, dest, src, src2); break; } @@ -2016,13 +2079,13 @@ void BytecodeGenerator::VisitBuiltinArithmeticCall(ast::CallExpr *call, ast::Bui switch (builtin) { case ast::Builtin::Abs: { - LocalVar src = VisitExpressionForSQLValue(call->Arguments()[0]); + LocalVar src = VisitExpressionForLValue(call->Arguments()[0]); GetEmitter()->Emit(is_integer_math ? Bytecode::AbsInteger : Bytecode::AbsReal, dest, src); break; } case ast::Builtin::Mod: { - LocalVar first_input = VisitExpressionForSQLValue(call->Arguments()[0]); - LocalVar second_input = VisitExpressionForSQLValue(call->Arguments()[1]); + LocalVar first_input = VisitExpressionForLValue(call->Arguments()[0]); + LocalVar second_input = VisitExpressionForLValue(call->Arguments()[1]); if (!is_integer_math) { NOISEPAGE_ASSERT(call->Arguments()[0]->GetType()->IsSpecificBuiltin(ast::BuiltinType::Real) && call->Arguments()[1]->GetType()->IsSpecificBuiltin(ast::BuiltinType::Real), @@ -2501,8 +2564,22 @@ void BytecodeGenerator::VisitBuiltinStorageInterfaceCall(ast::CallExpr *call, as void BytecodeGenerator::VisitBuiltinParamCall(ast::CallExpr *call, ast::Builtin builtin) { LocalVar exec_ctx = VisitExpressionForRValue(call->Arguments()[0]); - LocalVar param_idx = VisitExpressionForRValue(call->Arguments()[1]); - LocalVar ret = GetExecutionResult()->GetOrCreateDestination(call->GetType()); + LocalVar param_idx{}; + if (builtin != ast::Builtin::StartNewParams && builtin != ast::Builtin::FinishNewParams) { + param_idx = VisitExpressionForRValue(call->Arguments()[1]); + } + LocalVar ret; + if (builtin < ast::Builtin::StartNewParams) { + ret = GetExecutionResult()->GetOrCreateDestination(call->GetType()); + } else { + if (builtin != ast::Builtin::StartNewParams && builtin != ast::Builtin::FinishNewParams) { + if (call->Arguments()[1]->GetType()->IsPointerType()) { + param_idx = VisitExpressionForRValue(call->Arguments()[1]); + } else { + param_idx = VisitExpressionForLValue(call->Arguments()[1]); + } + } + } switch (builtin) { case ast::Builtin::GetParamBool: GetEmitter()->Emit(Bytecode::GetParamBool, ret, exec_ctx, param_idx); @@ -2534,6 +2611,42 @@ void BytecodeGenerator::VisitBuiltinParamCall(ast::CallExpr *call, ast::Builtin case ast::Builtin::GetParamString: GetEmitter()->Emit(Bytecode::GetParamString, ret, exec_ctx, param_idx); break; + case ast::Builtin::AddParamBool: + GetEmitter()->Emit(Bytecode::AddParamBool, exec_ctx, param_idx); + break; + case ast::Builtin::AddParamTinyInt: + GetEmitter()->Emit(Bytecode::AddParamTinyInt, exec_ctx, param_idx); + break; + case ast::Builtin::AddParamSmallInt: + GetEmitter()->Emit(Bytecode::AddParamSmallInt, exec_ctx, param_idx); + break; + case ast::Builtin::AddParamInt: + GetEmitter()->Emit(Bytecode::AddParamInt, exec_ctx, param_idx); + break; + case ast::Builtin::AddParamBigInt: + GetEmitter()->Emit(Bytecode::AddParamBigInt, exec_ctx, param_idx); + break; + case ast::Builtin::AddParamReal: + GetEmitter()->Emit(Bytecode::AddParamReal, exec_ctx, param_idx); + break; + case ast::Builtin::AddParamDouble: + GetEmitter()->Emit(Bytecode::AddParamDouble, exec_ctx, param_idx); + break; + case ast::Builtin::AddParamDate: + GetEmitter()->Emit(Bytecode::AddParamDateVal, exec_ctx, param_idx); + break; + case ast::Builtin::AddParamTimestamp: + GetEmitter()->Emit(Bytecode::AddParamTimestampVal, exec_ctx, param_idx); + break; + case ast::Builtin::AddParamString: + GetEmitter()->Emit(Bytecode::AddParamString, exec_ctx, param_idx); + break; + case ast::Builtin::StartNewParams: + GetEmitter()->Emit(Bytecode::StartNewParams, exec_ctx); + break; + case ast::Builtin::FinishNewParams: + GetEmitter()->Emit(Bytecode::FinishParams, exec_ctx); + break; default: UNREACHABLE("Impossible parameter call!"); } @@ -2544,35 +2657,35 @@ void BytecodeGenerator::VisitBuiltinStringCall(ast::CallExpr *call, ast::Builtin LocalVar ret = GetExecutionResult()->GetOrCreateDestination(call->GetType()); switch (builtin) { case ast::Builtin::SplitPart: { - LocalVar input_string = VisitExpressionForSQLValue(call->Arguments()[1]); - LocalVar delim = VisitExpressionForSQLValue(call->Arguments()[2]); - LocalVar field = VisitExpressionForSQLValue(call->Arguments()[3]); + LocalVar input_string = VisitExpressionForRValue(call->Arguments()[1]); + LocalVar delim = VisitExpressionForRValue(call->Arguments()[2]); + LocalVar field = VisitExpressionForRValue(call->Arguments()[3]); GetEmitter()->Emit(Bytecode::SplitPart, ret, exec_ctx, input_string, delim, field); break; } case ast::Builtin::Chr: { // input_string here is a integer type number - LocalVar input_string = VisitExpressionForSQLValue(call->Arguments()[1]); + LocalVar input_string = VisitExpressionForRValue(call->Arguments()[1]); GetEmitter()->Emit(Bytecode::Chr, ret, exec_ctx, input_string); break; } case ast::Builtin::CharLength: { - LocalVar input_string = VisitExpressionForSQLValue(call->Arguments()[1]); + LocalVar input_string = VisitExpressionForRValue(call->Arguments()[1]); GetEmitter()->Emit(Bytecode::CharLength, ret, exec_ctx, input_string); break; } case ast::Builtin::ASCII: { - LocalVar input_string = VisitExpressionForSQLValue(call->Arguments()[1]); + LocalVar input_string = VisitExpressionForRValue(call->Arguments()[1]); GetEmitter()->Emit(Bytecode::ASCII, ret, exec_ctx, input_string); break; } case ast::Builtin::Lower: { - LocalVar input_string = VisitExpressionForSQLValue(call->Arguments()[1]); + LocalVar input_string = VisitExpressionForRValue(call->Arguments()[1]); GetEmitter()->Emit(Bytecode::Lower, ret, exec_ctx, input_string); break; } case ast::Builtin::Upper: { - LocalVar input_string = VisitExpressionForSQLValue(call->Arguments()[1]); + LocalVar input_string = VisitExpressionForRValue(call->Arguments()[1]); GetEmitter()->Emit(Bytecode::Upper, ret, exec_ctx, input_string); break; } @@ -2581,73 +2694,73 @@ void BytecodeGenerator::VisitBuiltinStringCall(ast::CallExpr *call, ast::Builtin break; } case ast::Builtin::StartsWith: { - LocalVar input_string = VisitExpressionForSQLValue(call->Arguments()[1]); - LocalVar start_str = VisitExpressionForSQLValue(call->Arguments()[2]); + LocalVar input_string = VisitExpressionForRValue(call->Arguments()[1]); + LocalVar start_str = VisitExpressionForRValue(call->Arguments()[2]); GetEmitter()->Emit(Bytecode::StartsWith, ret, exec_ctx, input_string, start_str); break; } case ast::Builtin::Substring: { - LocalVar input_string = VisitExpressionForSQLValue(call->Arguments()[1]); - LocalVar start_ind = VisitExpressionForSQLValue(call->Arguments()[2]); - LocalVar length = VisitExpressionForSQLValue(call->Arguments()[3]); + LocalVar input_string = VisitExpressionForRValue(call->Arguments()[1]); + LocalVar start_ind = VisitExpressionForRValue(call->Arguments()[2]); + LocalVar length = VisitExpressionForRValue(call->Arguments()[3]); GetEmitter()->Emit(Bytecode::Substring, ret, exec_ctx, input_string, start_ind, length); break; } case ast::Builtin::Reverse: { - LocalVar input_string = VisitExpressionForSQLValue(call->Arguments()[1]); + LocalVar input_string = VisitExpressionForRValue(call->Arguments()[1]); GetEmitter()->Emit(Bytecode::Reverse, ret, exec_ctx, input_string); break; } case ast::Builtin::Left: { - LocalVar input_string = VisitExpressionForSQLValue(call->Arguments()[1]); - LocalVar len = VisitExpressionForSQLValue(call->Arguments()[2]); + LocalVar input_string = VisitExpressionForRValue(call->Arguments()[1]); + LocalVar len = VisitExpressionForRValue(call->Arguments()[2]); GetEmitter()->Emit(Bytecode::Left, ret, exec_ctx, input_string, len); break; } case ast::Builtin::Right: { - LocalVar input_string = VisitExpressionForSQLValue(call->Arguments()[1]); - LocalVar len = VisitExpressionForSQLValue(call->Arguments()[2]); + LocalVar input_string = VisitExpressionForRValue(call->Arguments()[1]); + LocalVar len = VisitExpressionForRValue(call->Arguments()[2]); GetEmitter()->Emit(Bytecode::Right, ret, exec_ctx, input_string, len); break; } case ast::Builtin::Repeat: { - LocalVar input_string = VisitExpressionForSQLValue(call->Arguments()[1]); - LocalVar num_repeat = VisitExpressionForSQLValue(call->Arguments()[2]); + LocalVar input_string = VisitExpressionForRValue(call->Arguments()[1]); + LocalVar num_repeat = VisitExpressionForRValue(call->Arguments()[2]); GetEmitter()->Emit(Bytecode::Repeat, ret, exec_ctx, input_string, num_repeat); break; } case ast::Builtin::Trim: { - LocalVar input_string = VisitExpressionForSQLValue(call->Arguments()[1]); + LocalVar input_string = VisitExpressionForRValue(call->Arguments()[1]); GetEmitter()->Emit(Bytecode::Trim, ret, exec_ctx, input_string); break; } case ast::Builtin::Trim2: { - LocalVar input_string = VisitExpressionForSQLValue(call->Arguments()[1]); - LocalVar trim_str = VisitExpressionForSQLValue(call->Arguments()[2]); + LocalVar input_string = VisitExpressionForRValue(call->Arguments()[1]); + LocalVar trim_str = VisitExpressionForRValue(call->Arguments()[2]); GetEmitter()->Emit(Bytecode::Trim2, ret, exec_ctx, input_string, trim_str); break; } case ast::Builtin::Position: { - LocalVar input_string = VisitExpressionForSQLValue(call->Arguments()[1]); - LocalVar sub_string = VisitExpressionForSQLValue(call->Arguments()[2]); + LocalVar input_string = VisitExpressionForRValue(call->Arguments()[1]); + LocalVar sub_string = VisitExpressionForRValue(call->Arguments()[2]); GetEmitter()->Emit(Bytecode::Position, ret, exec_ctx, input_string, sub_string); break; } case ast::Builtin::Length: { - LocalVar input_string = VisitExpressionForSQLValue(call->Arguments()[1]); + LocalVar input_string = VisitExpressionForRValue(call->Arguments()[1]); GetEmitter()->Emit(Bytecode::Length, ret, exec_ctx, input_string); break; } case ast::Builtin::InitCap: { - LocalVar input_string = VisitExpressionForSQLValue(call->Arguments()[1]); + LocalVar input_string = VisitExpressionForRValue(call->Arguments()[1]); GetEmitter()->Emit(Bytecode::InitCap, ret, exec_ctx, input_string); break; } case ast::Builtin::Lpad: { - LocalVar input_string = VisitExpressionForSQLValue(call->Arguments()[1]); - LocalVar len = VisitExpressionForSQLValue(call->Arguments()[2]); + LocalVar input_string = VisitExpressionForRValue(call->Arguments()[1]); + LocalVar len = VisitExpressionForRValue(call->Arguments()[2]); if (call->NumArgs() == 4) { - LocalVar pad = VisitExpressionForSQLValue(call->Arguments()[3]); + LocalVar pad = VisitExpressionForRValue(call->Arguments()[3]); GetEmitter()->Emit(Bytecode::LPad3Arg, ret, exec_ctx, input_string, len, pad); } else { GetEmitter()->Emit(Bytecode::LPad2Arg, ret, exec_ctx, input_string, len); @@ -2655,10 +2768,10 @@ void BytecodeGenerator::VisitBuiltinStringCall(ast::CallExpr *call, ast::Builtin break; } case ast::Builtin::Rpad: { - LocalVar input_string = VisitExpressionForSQLValue(call->Arguments()[1]); - LocalVar len = VisitExpressionForSQLValue(call->Arguments()[2]); + LocalVar input_string = VisitExpressionForRValue(call->Arguments()[1]); + LocalVar len = VisitExpressionForRValue(call->Arguments()[2]); if (call->NumArgs() == 4) { - LocalVar pad = VisitExpressionForSQLValue(call->Arguments()[3]); + LocalVar pad = VisitExpressionForRValue(call->Arguments()[3]); GetEmitter()->Emit(Bytecode::RPad3Arg, ret, exec_ctx, input_string, len, pad); } else { GetEmitter()->Emit(Bytecode::RPad2Arg, ret, exec_ctx, input_string, len); @@ -2666,21 +2779,21 @@ void BytecodeGenerator::VisitBuiltinStringCall(ast::CallExpr *call, ast::Builtin break; } case ast::Builtin::Ltrim: { - LocalVar input_string = VisitExpressionForSQLValue(call->Arguments()[1]); + LocalVar input_string = VisitExpressionForRValue(call->Arguments()[1]); if (call->NumArgs() == 2) { GetEmitter()->Emit(Bytecode::LTrim1Arg, ret, exec_ctx, input_string); } else { - LocalVar chars = VisitExpressionForSQLValue(call->Arguments()[2]); + LocalVar chars = VisitExpressionForRValue(call->Arguments()[2]); GetEmitter()->Emit(Bytecode::LTrim2Arg, ret, exec_ctx, input_string, chars); } break; } case ast::Builtin::Rtrim: { - LocalVar input_string = VisitExpressionForSQLValue(call->Arguments()[1]); + LocalVar input_string = VisitExpressionForRValue(call->Arguments()[1]); if (call->NumArgs() == 2) { GetEmitter()->Emit(Bytecode::RTrim1Arg, ret, exec_ctx, input_string); } else { - LocalVar chars = VisitExpressionForSQLValue(call->Arguments()[2]); + LocalVar chars = VisitExpressionForRValue(call->Arguments()[2]); GetEmitter()->Emit(Bytecode::RTrim2Arg, ret, exec_ctx, input_string, chars); } break; @@ -2695,7 +2808,7 @@ void BytecodeGenerator::VisitBuiltinStringCall(ast::CallExpr *call, ast::Builtin auto arr_elem_ptr = GetCurrentFunction()->NewLocal(string_type->PointerTo()->PointerTo()); for (uint32_t i = 0; i < num_inputs; i++) { GetEmitter()->EmitLea(arr_elem_ptr, inputs, i * 8); - LocalVar input_string = VisitExpressionForSQLValue(call->Arguments()[i + 1]); + LocalVar input_string = VisitExpressionForLValue(call->Arguments()[i + 1]); GetEmitter()->EmitAssign(Bytecode::Assign8, arr_elem_ptr.ValueOf(), input_string); } @@ -3150,7 +3263,19 @@ void BytecodeGenerator::VisitBuiltinCallExpr(ast::CallExpr *call) { case ast::Builtin::GetParamDouble: case ast::Builtin::GetParamDate: case ast::Builtin::GetParamTimestamp: - case ast::Builtin::GetParamString: { + case ast::Builtin::GetParamString: + case ast::Builtin::AddParamBool: + case ast::Builtin::AddParamTinyInt: + case ast::Builtin::AddParamSmallInt: + case ast::Builtin::AddParamInt: + case ast::Builtin::AddParamBigInt: + case ast::Builtin::AddParamReal: + case ast::Builtin::AddParamDouble: + case ast::Builtin::AddParamDate: + case ast::Builtin::AddParamTimestamp: + case ast::Builtin::AddParamString: + case ast::Builtin::StartNewParams: + case ast::Builtin::FinishNewParams: { VisitBuiltinParamCall(call, builtin); break; } @@ -3314,11 +3439,15 @@ void BytecodeGenerator::VisitBuiltinIndexIteratorCall(ast::CallExpr *call, ast:: void BytecodeGenerator::VisitRegularCallExpr(ast::CallExpr *call) { bool caller_wants_result = GetExecutionResult() != nullptr; - NOISEPAGE_ASSERT(!caller_wants_result || GetExecutionResult()->IsRValue(), "Calls can only be R-Values!"); - + NOISEPAGE_ASSERT(!caller_wants_result || GetExecutionResult()->IsRValue() || + (GetExecutionResult()->IsLValue() && call->GetType()->IsSqlValueType()), + "Calls can only be R-Values!"); std::vector params; - auto *func_type = call->Function()->GetType()->As(); + auto *func_type = call->Function()->GetType()->SafeAs(); + if (func_type == nullptr) { + func_type = call->Function()->GetType()->SafeAs()->GetFunctionType(); + } if (!func_type->GetReturnType()->IsNilType()) { LocalVar ret_val; @@ -3327,6 +3456,9 @@ void BytecodeGenerator::VisitRegularCallExpr(ast::CallExpr *call) { // Let the caller know where the result value is GetExecutionResult()->SetDestination(ret_val.ValueOf()); + if (GetExecutionResult()->IsLValue()) { + GetExecutionResult()->SetDestination(ret_val.AddressOf()); + } } else { ret_val = GetCurrentFunction()->NewLocal(func_type->GetReturnType()); } @@ -3336,12 +3468,21 @@ void BytecodeGenerator::VisitRegularCallExpr(ast::CallExpr *call) { } // Collect non-return-value parameters as usual - for (uint32_t i = 0; i < func_type->GetNumParams(); i++) { - params.push_back(VisitExpressionForRValue(call->Arguments()[i])); + for (uint32_t i = 0; i < call->Arguments().size(); i++) { + if (func_type->GetParams()[i].type_->IsSqlValueType()) { + params.push_back(VisitExpressionForLValue(call->Arguments()[i])); + } else { + params.push_back(VisitExpressionForRValue(call->Arguments()[i])); + } } // Emit call const auto func_id = LookupFuncIdByName(call->GetFuncName().GetData()); + if (func_id == FunctionInfo::K_INVALID_FUNC_ID) { + auto action = GetEmitter()->DeferedEmitCall(params); + deferred_function_create_actions_[call->GetFuncName().GetString()].push_back(action); + return; + } NOISEPAGE_ASSERT(func_id != FunctionInfo::K_INVALID_FUNC_ID, "Function not found!"); GetEmitter()->EmitCall(func_id, params); } @@ -3349,10 +3490,18 @@ void BytecodeGenerator::VisitRegularCallExpr(ast::CallExpr *call) { void BytecodeGenerator::VisitCallExpr(ast::CallExpr *node) { ast::CallExpr::CallKind call_kind = node->GetCallKind(); - if (call_kind == ast::CallExpr::CallKind::Builtin) { - VisitBuiltinCallExpr(node); - } else { - VisitRegularCallExpr(node); + switch (call_kind) { + case ast::CallExpr::CallKind::Builtin: { + VisitBuiltinCallExpr(node); + break; + } + case ast::CallExpr::CallKind::Regular: { + VisitRegularCallExpr(node); + break; + } + default: { + UNREACHABLE("Unknown Call Kind"); + } } } @@ -3368,7 +3517,7 @@ void BytecodeGenerator::VisitFile(ast::File *node) { } void BytecodeGenerator::VisitLitExpr(ast::LitExpr *node) { - NOISEPAGE_ASSERT(GetExecutionResult()->IsRValue(), "Literal expressions cannot be R-Values!"); + NOISEPAGE_ASSERT(GetExecutionResult()->IsRValue(), "Literal expressions cannot be L-Values!"); LocalVar target = GetExecutionResult()->GetOrCreateDestination(node->GetType()); @@ -3518,8 +3667,8 @@ void BytecodeGenerator::VisitPrimitiveArithmeticExpr(ast::BinaryOpExpr *node) { void BytecodeGenerator::VisitSqlArithmeticExpr(ast::BinaryOpExpr *node) { LocalVar dest = GetExecutionResult()->GetOrCreateDestination(node->GetType()); - LocalVar left = VisitExpressionForSQLValue(node->Left()); - LocalVar right = VisitExpressionForSQLValue(node->Right()); + LocalVar left = VisitExpressionForLValue(node->Left()); + LocalVar right = VisitExpressionForLValue(node->Right()); const bool is_integer_math = node->GetType()->IsSpecificBuiltin(ast::BuiltinType::Integer); @@ -3605,8 +3754,8 @@ void BytecodeGenerator::VisitBinaryOpExpr(ast::BinaryOpExpr *node) { void BytecodeGenerator::VisitSqlCompareOpExpr(ast::ComparisonOpExpr *compare) { LocalVar dest = GetExecutionResult()->GetOrCreateDestination(compare->GetType()); - LocalVar left = VisitExpressionForSQLValue(compare->Left()); - LocalVar right = VisitExpressionForSQLValue(compare->Right()); + LocalVar left = VisitExpressionForLValue(compare->Left()); + LocalVar right = VisitExpressionForLValue(compare->Right()); NOISEPAGE_ASSERT(compare->Left()->GetType() == compare->Right()->GetType(), "Left and right input types to comparison are not equal"); @@ -3749,8 +3898,10 @@ void BytecodeGenerator::BuildAssign(LocalVar dest, LocalVar val, ast::Type *dest GetEmitter()->EmitAssign(Bytecode::Assign2, dest, val); } else if (size == 4) { GetEmitter()->EmitAssign(Bytecode::Assign4, dest, val); - } else { + } else if (size == 8 && dest_type != ast::BuiltinType::Get(dest_type->GetContext(), ast::BuiltinType::Date)) { GetEmitter()->EmitAssign(Bytecode::Assign8, dest, val); + } else { + GetEmitter()->EmitAssignN(dest, val, size); } } @@ -3870,6 +4021,36 @@ void BytecodeGenerator::VisitMapTypeRepr(ast::MapTypeRepr *node) { FunctionInfo *BytecodeGenerator::AllocateFunc(const std::string &func_name, ast::FunctionType *const func_type) { // Allocate function const auto func_id = static_cast(functions_.size()); + functions_.emplace_back(func_id, std::string(func_name), func_type); + FunctionInfo *func = &functions_.back(); + + // Register return type + if (auto *return_type = func_type->GetReturnType(); !return_type->IsNilType()) { + func->NewParameterLocal(return_type->PointerTo(), "hiddenRv"); + } + + // Register parameters + for (const auto ¶m : func_type->GetParams()) { + if (param.type_->IsSqlValueType()) { + func->NewParameterLocal(param.type_->PointerTo(), param.name_.GetData()); + } else { + func->NewParameterLocal(param.type_, param.name_.GetData()); + } + } + + // Cache + func_map_[func->GetName()] = func->GetId(); + for (auto action : deferred_function_create_actions_[func->GetName()]) { + action(func->GetId()); + } + + return func; +} + +FunctionInfo *BytecodeGenerator::AllocateFunc(const std::string &func_name, ast::FunctionType *func_type, + LocalVar captures, ast::Type *capture_type) { + // Allocate function + const auto func_id = static_cast(functions_.size()); functions_.emplace_back(func_id, func_name, func_type); FunctionInfo *func = &functions_.back(); @@ -3878,6 +4059,9 @@ FunctionInfo *BytecodeGenerator::AllocateFunc(const std::string &func_name, ast: func->NewParameterLocal(return_type->PointerTo(), "hiddenRv"); } + // lambda captures + func->NewParameterLocal(capture_type->PointerTo(), "hiddenCaptures"); + // Register parameters for (const auto ¶m : func_type->GetParams()) { func->NewParameterLocal(param.type_, param.name_.GetData()); @@ -3885,7 +4069,9 @@ FunctionInfo *BytecodeGenerator::AllocateFunc(const std::string &func_name, ast: // Cache func_map_[func->GetName()] = func->GetId(); - + for (auto action : deferred_function_create_actions_[func->GetName()]) { + action(func->GetId()); + } return func; } @@ -3943,12 +4129,6 @@ LocalVar BytecodeGenerator::VisitExpressionForRValue(ast::Expr *expr) { return scope.GetDestination(); } -LocalVar BytecodeGenerator::VisitExpressionForSQLValue(ast::Expr *expr) { return VisitExpressionForLValue(expr); } - -void BytecodeGenerator::VisitExpressionForSQLValue(ast::Expr *expr, LocalVar dest) { - VisitExpressionForRValue(expr, dest); -} - void BytecodeGenerator::VisitExpressionForRValue(ast::Expr *expr, LocalVar dest) { RValueResultScope scope(this, dest); Visit(expr); @@ -3957,7 +4137,7 @@ void BytecodeGenerator::VisitExpressionForRValue(ast::Expr *expr, LocalVar dest) void BytecodeGenerator::VisitExpressionForTest(ast::Expr *expr, BytecodeLabel *then_label, BytecodeLabel *else_label, TestFallthrough fallthrough) { // Evaluate the expression - LocalVar cond = VisitExpressionForRValue(expr); + LocalVar cond = VisitExpressionForRValue(expr).ValueOf(); switch (fallthrough) { case TestFallthrough::Then: { diff --git a/src/execution/vm/bytecode_module.cpp b/src/execution/vm/bytecode_module.cpp index 37e3f33118..8a97ed02b5 100644 --- a/src/execution/vm/bytecode_module.cpp +++ b/src/execution/vm/bytecode_module.cpp @@ -170,6 +170,13 @@ void PrettyPrintFuncCode(std::ostream &os, const BytecodeModule &module, const F break; } case OperandType::FunctionId: { + auto fn_id = iter->GetFunctionIdOperand(i); + if (fn_id == FunctionInfo::K_INVALID_FUNC_ID) { + os << "func=<" + << "unresolved lambda" + << ">"; + break; + } auto target = module.GetFuncInfoById(iter->GetFunctionIdOperand(i)); os << "func=<" << target->GetName() << ">"; break; diff --git a/src/execution/vm/llvm_engine.cpp b/src/execution/vm/llvm_engine.cpp index e123d85d13..7b3bcefd29 100644 --- a/src/execution/vm/llvm_engine.cpp +++ b/src/execution/vm/llvm_engine.cpp @@ -287,8 +287,13 @@ llvm::FunctionType *LLVMEngine::TypeMap::GetLLVMFunctionType(const ast::Function // for (const auto ¶m_info : func_type->GetParams()) { - llvm::Type *param_type = GetLLVMType(param_info.type_); - param_types.push_back(param_type); + // TODO(Kyle): make this read from bytecode stuff instead to avoid this + if (param_info.type_->IsSqlValueType()) { + param_types.push_back(GetLLVMType(param_info.type_->PointerTo())); + } else { + llvm::Type *param_type = GetLLVMType(param_info.type_); + param_types.push_back(param_type); + } } return llvm::FunctionType::get(return_type, param_types, false); @@ -335,6 +340,16 @@ LLVMEngine::FunctionLocalsMap::FunctionLocalsMap(const FunctionInfo &func_info, params_[param.GetOffset()] = &*arg_iter; } + if (func_info.IsLambda()) { + auto capture_type = type_map->GetLLVMType(func_info.GetFuncType()->GetCapturesType()->PointerTo()); + auto capture_local = func_locals[local_idx - 1]; + auto capture_param = params_[capture_local.GetOffset()]; + auto new_capture_param = ir_builder->CreateBitCast(capture_param, capture_type); + params_[capture_local.GetOffset()] = new_capture_param; + } + + auto calling_context = func_info; + // Allocate all local variables up front. for (; local_idx < func_info.GetLocals().size(); local_idx++) { const LocalInfo &local_info = func_locals[local_idx]; @@ -346,7 +361,13 @@ LLVMEngine::FunctionLocalsMap::FunctionLocalsMap(const FunctionInfo &func_info, llvm::Value *LLVMEngine::FunctionLocalsMap::GetArgumentById(LocalVar var) { if (auto iter = params_.find(var.GetOffset()); iter != params_.end()) { - return iter->second; + auto val = iter->second; + if ((var.GetAddressMode() == LocalVar::AddressMode::Address) && llvm::isa(val)) { + auto new_val = ir_builder_->CreateAlloca(val->getType()); + ir_builder_->CreateStore(val, new_val); + val = new_val; + } + return val; } if (auto iter = locals_.find(var.GetOffset()); iter != locals_.end()) { @@ -636,6 +657,11 @@ void LLVMEngine::CompiledModuleBuilder::BuildSimpleCFG(const FunctionInfo &func_ void LLVMEngine::CompiledModuleBuilder::DefineFunction(const FunctionInfo &func_info, llvm::IRBuilder<> *ir_builder) { llvm::LLVMContext &ctx = ir_builder->getContext(); llvm::Function *func = llvm_module_->getFunction(func_info.GetName()); + if (func->getName().str().find("inline") != std::string::npos) { + func->setLinkage(llvm::Function::LinkOnceAnyLinkage); + func->addFnAttr(llvm::Attribute::AlwaysInline); + } + llvm::BasicBlock *first_bb = llvm::BasicBlock::Create(ctx, "BB0", func); llvm::BasicBlock *entry_bb = llvm::BasicBlock::Create(ctx, "EntryBB", func, first_bb); diff --git a/src/execution/vm/module.cpp b/src/execution/vm/module.cpp index 81a6e94629..1aea32b09e 100644 --- a/src/execution/vm/module.cpp +++ b/src/execution/vm/module.cpp @@ -279,7 +279,10 @@ void Module::CompileToMachineCode() { // previous implementation. for (const auto &func_info : bytecode_module_->GetFunctionsInfo()) { auto *jit_function = jit_module_->GetFunctionPointer(func_info.GetName()); - NOISEPAGE_ASSERT(jit_function != nullptr, "Missing function in compiled module!"); + // TODO(Kyle): Why is this OK now? + if (jit_function == nullptr) { + continue; + } functions_[func_info.GetId()].store(jit_function, std::memory_order_relaxed); } }); diff --git a/src/execution/vm/vm.cpp b/src/execution/vm/vm.cpp index 4a4e27990d..8e60346d52 100644 --- a/src/execution/vm/vm.cpp +++ b/src/execution/vm/vm.cpp @@ -413,6 +413,14 @@ void VM::Interpret(const uint8_t *ip, Frame *frame) { // NOLINT GEN_ASSIGN(int64_t, 8); #undef GEN_ASSIGN + OP(AssignN) : { + auto *dest = frame->LocalAt(READ_LOCAL_ID()); + auto *src = frame->LocalAt(READ_LOCAL_ID()); + auto len = READ_UIMM4(); + OpAssignN(dest, src, len); + DISPATCH_NEXT(); + } + OP(AssignImm4F) : { auto *dest = frame->LocalAt(READ_LOCAL_ID()); OpAssignImm4F(dest, READ_IMM4F()); @@ -2207,6 +2215,38 @@ void VM::Interpret(const uint8_t *ip, Frame *frame) { // NOLINT GEN_PARAM_GET(String, StringVal) #undef GEN_PARAM_GET +#define GEN_PARAM_ADD(Name, SqlType) \ + OP(AddParam##Name) : { \ + auto *exec_ctx = frame->LocalAt(READ_LOCAL_ID()); \ + auto *ret = frame->LocalAt(READ_LOCAL_ID()); \ + OpAddParam##Name(exec_ctx, ret); \ + DISPATCH_NEXT(); \ + } + + GEN_PARAM_ADD(Bool, BoolVal) + GEN_PARAM_ADD(TinyInt, Integer) + GEN_PARAM_ADD(SmallInt, Integer) + GEN_PARAM_ADD(Int, Integer) + GEN_PARAM_ADD(BigInt, Integer) + GEN_PARAM_ADD(Real, Real) + GEN_PARAM_ADD(Double, Real) + GEN_PARAM_ADD(DateVal, DateVal) + GEN_PARAM_ADD(TimestampVal, TimestampVal) + GEN_PARAM_ADD(String, StringVal) +#undef GEN_PARAM_ADD + + OP(StartNewParams) : { + auto *exec_ctx = frame->LocalAt(READ_LOCAL_ID()); + OpStartNewParams(exec_ctx); + DISPATCH_NEXT(); + } + + OP(FinishParams) : { + auto *exec_ctx = frame->LocalAt(READ_LOCAL_ID()); + OpFinishParams(exec_ctx); + DISPATCH_NEXT(); + } + // ------------------------------------------------------- // Trig functions // ------------------------------------------------------- @@ -2673,7 +2713,7 @@ const uint8_t *VM::ExecuteCall(const uint8_t *ip, VM::Frame *caller) { const LocalVar param = LocalVar::Decode(READ_LOCAL_ID()); const void *param_ptr = caller->PtrToLocalAt(param); if (param.GetAddressMode() == LocalVar::AddressMode::Address) { - std::memcpy(raw_frame + param_info.GetOffset(), ¶m_ptr, param_info.GetSize()); + std::memcpy(raw_frame + param_info.GetOffset(), ¶m_ptr, sizeof(void *)); } else { std::memcpy(raw_frame + param_info.GetOffset(), param_ptr, param_info.GetSize()); } diff --git a/src/include/catalog/catalog_accessor.h b/src/include/catalog/catalog_accessor.h index 4c7927e7b5..cc20efb916 100644 --- a/src/include/catalog/catalog_accessor.h +++ b/src/include/catalog/catalog_accessor.h @@ -345,6 +345,8 @@ class EXPORT CatalogAccessor { */ common::ManagedPointer GetProcCtxPtr(proc_oid_t proc_oid); + // TODO(Kyle): Make these functions consistent + /** * Sets the proc context pointer column of proc_oid to func_context * @param proc_oid The proc_oid whose pointer column we are setting here diff --git a/src/include/catalog/postgres/pg_language.h b/src/include/catalog/postgres/pg_language.h index a22ece3c4a..c9ca1bddde 100644 --- a/src/include/catalog/postgres/pg_language.h +++ b/src/include/catalog/postgres/pg_language.h @@ -18,6 +18,10 @@ class PgProcImpl; class PgLanguage { private: friend class storage::RecoveryManager; + // TODO(Kyle): How do we want to expose these constants? + // This is a friend because the DDL executor needs to access + // the OID for the PL/pgSQL language... + friend class execution::sql::DDLExecutors; friend class Builder; friend class PgLanguageImpl; @@ -38,7 +42,8 @@ class PgLanguage { static constexpr CatalogColumnDef LANNAME{col_oid_t{2}}; // VARCHAR (skey) static constexpr CatalogColumnDef LANISPL{col_oid_t{3}}; // BOOLEAN (skey) static constexpr CatalogColumnDef LANPLTRUSTED{col_oid_t{4}}; // BOOLEAN (skey) - // TODO(tanujnay112): Make these foreign keys when we implement pg_proc + + // TODO(Kyle): Make these foreign keys when we implement pg_proc static constexpr CatalogColumnDef LANPLCALLFOID{ col_oid_t{5}}; // INTEGER (skey) (fkey: pg_proc) static constexpr CatalogColumnDef LANINLINE{col_oid_t{6}}; // INTEGER (skey) (fkey: pg_proc) diff --git a/src/include/execution/sql/ddl_executors.h b/src/include/execution/sql/ddl_executors.h index 85bfafa3c8..1300682ed6 100644 --- a/src/include/execution/sql/ddl_executors.h +++ b/src/include/execution/sql/ddl_executors.h @@ -12,6 +12,7 @@ class CreateIndexPlanNode; class CreateViewPlanNode; class DropDatabasePlanNode; class DropNamespacePlanNode; +class CreateFunctionPlanNode; class DropTablePlanNode; class DropIndexPlanNode; } // namespace noisepage::planner @@ -31,7 +32,7 @@ class DDLExecutors { DDLExecutors() = delete; /** - * @param node node to executed + * @param node node to execute * @param accessor accessor to use for execution * @return true if operation succeeded, false otherwise */ @@ -39,61 +40,69 @@ class DDLExecutors { common::ManagedPointer accessor); /** - * @param node node to executed + * @param node node to execute * @param accessor accessor to use for execution - * @return true if operation succeeded, false otherwise + * @return `true` if operation succeeds, `false` otherwise */ static bool CreateNamespaceExecutor(common::ManagedPointer node, common::ManagedPointer accessor); /** - * @param node node to executed + * @param node node to execute + * @param exec_ctx accessor to use for execution + * @return `true` if the operation succeeds, `false` otherwise + */ + static bool CreateFunctionExecutor(common::ManagedPointer node, + common::ManagedPointer accessor); + + /** + * @param node node to execute * @param accessor accessor to use for execution * @param connection_db database for the current connection - * @return true if operation succeeded, false otherwise + * @return `true` if operation succeeds, `false` otherwise */ static bool CreateTableExecutor(common::ManagedPointer node, common::ManagedPointer accessor, catalog::db_oid_t connection_db); /** - * @param node node to executed + * @param node node to execute * @param accessor accessor to use for execution - * @return true if operation succeeded, false otherwise + * @return `true` if operation succeeds, `false` otherwise */ static bool CreateIndexExecutor(common::ManagedPointer node, common::ManagedPointer accessor); /** - * @param node node to executed + * @param node node to execute * @param accessor accessor to use for execution * @param connection_db database for the current connection - * @return true if operation succeeded, false otherwise + * @return `true` if operation succeeds, `false` otherwise */ static bool DropDatabaseExecutor(common::ManagedPointer node, common::ManagedPointer accessor, catalog::db_oid_t connection_db); /** - * @param node node to executed + * @param node node to execute * @param accessor accessor to use for execution - * @return true if operation succeeded, false otherwise + * @return `true` if operation succeeds, `false` otherwise */ static bool DropNamespaceExecutor(common::ManagedPointer node, common::ManagedPointer accessor); /** - * @param node node to executed + * @param node node to execute * @param accessor accessor to use for execution - * @return true if operation succeeded, false otherwise + * @return `true` if operation succeeds, `false` otherwise */ static bool DropTableExecutor(common::ManagedPointer node, common::ManagedPointer accessor); /** - * @param node node to executed + * @param node node to execute * @param accessor accessor to use for execution - * @return true if operation succeeded, false otherwise + * @return `true` if operation succeeds, `false` otherwise */ static bool DropIndexExecutor(common::ManagedPointer node, common::ManagedPointer accessor); diff --git a/src/include/execution/vm/bytecode_emitter.h b/src/include/execution/vm/bytecode_emitter.h index 33f0afaeb1..164b487864 100644 --- a/src/include/execution/vm/bytecode_emitter.h +++ b/src/include/execution/vm/bytecode_emitter.h @@ -66,6 +66,11 @@ class BytecodeEmitter { */ void EmitAssign(Bytecode bytecode, LocalVar dest, LocalVar src); + /** + * TODO(Kyle): this. + */ + void EmitAssignN(LocalVar dest, LocalVar src, uint32_t len); + /** * Emit assignment code for 1 byte values. * @param dest destination variable @@ -167,6 +172,11 @@ class BytecodeEmitter { */ void EmitCall(FunctionId func_id, const std::vector ¶ms); + /** + * TODO(Kyle): this. + */ + std::function DeferedEmitCall(const std::vector ¶ms); + /** * Emit a return bytecode */ @@ -424,6 +434,14 @@ class BytecodeEmitter { *reinterpret_cast(&*(bytecode_->end() - sizeof(T))) = val; } + /** + * TODO(Kyle): this. + */ + template + auto EmitScalarValue(const T val, std::size_t index) -> std::enable_if_t> { + *reinterpret_cast(&*(bytecode_->begin() + index)) = val; + } + /** Emit a bytecode */ void EmitImpl(const Bytecode bytecode) { EmitScalarValue(Bytecodes::ToByte(bytecode)); } diff --git a/src/include/execution/vm/bytecode_function_info.h b/src/include/execution/vm/bytecode_function_info.h index 1a169117bc..19cbd577e8 100644 --- a/src/include/execution/vm/bytecode_function_info.h +++ b/src/include/execution/vm/bytecode_function_info.h @@ -1,6 +1,7 @@ #pragma once #include +#include #include #include #include @@ -287,6 +288,16 @@ class FunctionInfo { */ uint32_t GetParamsCount() const noexcept { return num_params_; } + /** + * TODO(Kyle): this. + */ + void DeferAction(const std::function action) { actions_.push_back(action); } + + /** + * TODO(Kyle): this. + */ + bool IsLambda() const { return is_lambda_; } + private: friend class BytecodeGenerator; @@ -302,6 +313,15 @@ class FunctionInfo { // Allocate a new local variable in the function. LocalVar NewLocal(ast::Type *type, const std::string &name, LocalInfo::Kind kind); + // TODO(Kyle): this + LocalVar captures_; + + // TODO(Kyle): this + bool is_lambda_{false}; + + // TODO(Kyle): this + std::vector> actions_; + private: // The ID of the function in the module. IDs are unique within a module. FunctionId id_; diff --git a/src/include/execution/vm/bytecode_generator.h b/src/include/execution/vm/bytecode_generator.h index 16277e2caf..4876144f38 100644 --- a/src/include/execution/vm/bytecode_generator.h +++ b/src/include/execution/vm/bytecode_generator.h @@ -70,6 +70,10 @@ class BytecodeGenerator final : public ast::AstVisitor { // Allocate a new function ID FunctionInfo *AllocateFunc(const std::string &func_name, ast::FunctionType *func_type); + // Allocate a new function ID with captures. + FunctionInfo *AllocateFunc(const std::string &func_name, ast::FunctionType *func_type, LocalVar captures, + ast::Type *capture_type); + void VisitAbortTxn(ast::CallExpr *call); // ONLY FOR TESTING! @@ -189,7 +193,9 @@ class BytecodeGenerator final : public ast::AstVisitor { void SetExecutionResult(ExpressionResultScope *exec_result) { execution_result_ = exec_result; } // Access the current function that's being generated. May be NULL. - FunctionInfo *GetCurrentFunction() { return &functions_.back(); } + FunctionInfo *GetCurrentFunction() { return &functions_[current_fn_]; } + + void EnterFunction(FunctionId id) { current_fn_ = id; } private: // The data section of the module @@ -206,14 +212,22 @@ class BytecodeGenerator final : public ast::AstVisitor { // Information about all generated functions std::vector functions_; + // The ID of the current function. + FunctionId current_fn_{0}; + // Cache of function names to IDs for faster lookup std::unordered_map func_map_; + std::unordered_map>> deferred_function_create_actions_; // Emitter to write bytecode into the code section BytecodeEmitter emitter_; // RAII struct to capture semantics of expression evaluation ExpressionResultScope *execution_result_{nullptr}; + + // The loop builder for the current loop. + // TODO(Kyle): seems messy. + LoopBuilder *current_loop_{nullptr}; }; } // namespace noisepage::execution::vm diff --git a/src/include/execution/vm/bytecode_handlers.h b/src/include/execution/vm/bytecode_handlers.h index 3576278646..f0d228cd6a 100644 --- a/src/include/execution/vm/bytecode_handlers.h +++ b/src/include/execution/vm/bytecode_handlers.h @@ -170,6 +170,10 @@ VM_OP_HOT void OpAssign4(int32_t *dest, int32_t src) { *dest = src; } VM_OP_HOT void OpAssign8(int64_t *dest, int64_t src) { *dest = src; } +VM_OP_HOT void OpAssignN(noisepage::byte *dest, const noisepage::byte *const src, uint32_t len) { + std::memcpy(dest, src, len); +} + VM_OP_HOT void OpAssignImm1(int8_t *dest, int8_t src) { *dest = src; } VM_OP_HOT void OpAssignImm2(int16_t *dest, int16_t src) { *dest = src; } @@ -1408,6 +1412,9 @@ VM_OP_WARM void OpSorterIteratorSkipRows(noisepage::execution::sql::SorterIterat VM_OP void OpSorterIteratorFree(noisepage::execution::sql::SorterIterator *iter); +VM_OP void OpPushParamContext(noisepage::execution::exec::ExecutionContext **new_ctx, + noisepage::execution::exec::ExecutionContext *ctx); + // --------------------------------------------------------- // Output // --------------------------------------------------------- @@ -2140,6 +2147,30 @@ GEN_SCALAR_PARAM_GET(TimestampVal, TimestampVal) GEN_SCALAR_PARAM_GET(String, StringVal) #undef GEN_SCALAR_PARAM_GET +// Parameter calls +#define GEN_SCALAR_PARAM_ADD(Name, SqlType, typeId) \ + VM_OP_HOT void OpAddParam##Name(noisepage::execution::exec::ExecutionContext *exec_ctx, \ + noisepage::execution::sql::SqlType *ret) { \ + exec_ctx->AddParam(noisepage::common::ManagedPointer( \ + reinterpret_cast(ret))); \ + } + +GEN_SCALAR_PARAM_ADD(Bool, BoolVal, BOOLEAN) +GEN_SCALAR_PARAM_ADD(TinyInt, Integer, TINYINT) +GEN_SCALAR_PARAM_ADD(SmallInt, Integer, SMALLINT) +GEN_SCALAR_PARAM_ADD(Int, Integer, INTEGER) +GEN_SCALAR_PARAM_ADD(BigInt, Integer, BIGINT) +GEN_SCALAR_PARAM_ADD(Real, Real, DECIMAL) +GEN_SCALAR_PARAM_ADD(Double, Real, DECIMAL) +GEN_SCALAR_PARAM_ADD(DateVal, DateVal, DATE) +GEN_SCALAR_PARAM_ADD(TimestampVal, TimestampVal, TIMESTAMP) +GEN_SCALAR_PARAM_ADD(String, StringVal, VARCHAR) +#undef GEN_SCALAR_PARAM_ADD + +VM_OP_HOT void OpStartNewParams(noisepage::execution::exec::ExecutionContext *exec_ctx) { exec_ctx->StartParams(); } + +VM_OP_HOT void OpFinishParams(noisepage::execution::exec::ExecutionContext *exec_ctx) { exec_ctx->PopParams(); } + // --------------------------------- // Replication functions // --------------------------------- diff --git a/src/include/execution/vm/bytecodes.h b/src/include/execution/vm/bytecodes.h index cc58c4a4a8..111039127e 100644 --- a/src/include/execution/vm/bytecodes.h +++ b/src/include/execution/vm/bytecodes.h @@ -84,6 +84,7 @@ namespace noisepage::execution::vm { F(Assign2, OperandType::Local, OperandType::Local) \ F(Assign4, OperandType::Local, OperandType::Local) \ F(Assign8, OperandType::Local, OperandType::Local) \ + F(AssignN, OperandType::Local, OperandType::Local, OperandType::UImm4) \ F(AssignImm1, OperandType::Local, OperandType::Imm1) \ F(AssignImm2, OperandType::Local, OperandType::Imm2) \ F(AssignImm4, OperandType::Local, OperandType::Imm4) \ @@ -765,6 +766,18 @@ namespace noisepage::execution::vm { F(GetParamDateVal, OperandType::Local, OperandType::Local, OperandType::Local) \ F(GetParamTimestampVal, OperandType::Local, OperandType::Local, OperandType::Local) \ F(GetParamString, OperandType::Local, OperandType::Local, OperandType::Local) \ + F(AddParamBool, OperandType::Local, OperandType::Local) \ + F(AddParamTinyInt, OperandType::Local, OperandType::Local) \ + F(AddParamSmallInt, OperandType::Local, OperandType::Local) \ + F(AddParamInt, OperandType::Local, OperandType::Local) \ + F(AddParamBigInt, OperandType::Local, OperandType::Local) \ + F(AddParamReal, OperandType::Local, OperandType::Local) \ + F(AddParamDouble, OperandType::Local, OperandType::Local) \ + F(AddParamDateVal, OperandType::Local, OperandType::Local) \ + F(AddParamTimestampVal, OperandType::Local, OperandType::Local) \ + F(AddParamString, OperandType::Local, OperandType::Local) \ + F(StartNewParams, OperandType::Local) \ + F(FinishParams, OperandType::Local) \ \ /* FOR TESTING ONLY */ \ F(TestCatalogLookup, OperandType::Local, OperandType::Local, OperandType::StaticLocal, OperandType::UImm4, \ diff --git a/src/include/execution/vm/control_flow_builders.h b/src/include/execution/vm/control_flow_builders.h index 02479eebe5..68eca8dfff 100644 --- a/src/include/execution/vm/control_flow_builders.h +++ b/src/include/execution/vm/control_flow_builders.h @@ -79,6 +79,17 @@ class LoopBuilder : public BreakableBlockBuilder { */ explicit LoopBuilder(BytecodeGenerator *generator) : BreakableBlockBuilder(generator) {} + /** + * Construct a loop builder. + * + * TODO(Kyle): Why was this construtor removed? + * + * @param generator The generator the loop writes. + * @param prev + */ + explicit LoopBuilder(BytecodeGenerator *generator, LoopBuilder *prev = nullptr) + : BreakableBlockBuilder(generator), prev_loop_(prev) {} + /** * Destructor. */ @@ -114,6 +125,7 @@ class LoopBuilder : public BreakableBlockBuilder { private: BytecodeLabel header_label_; BytecodeLabel continue_label_; + LoopBuilder *prev_loop_; }; /** From ea585b62efc8eeea1cb2e72c7a9812cc0857ba6c Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Sun, 4 Apr 2021 18:45:57 -0400 Subject: [PATCH 016/139] fixed bug in udf codegen, but now libpg_query wont link... --- src/execution/compiler/udf/udf_codegen.cpp | 6 ++++++ src/include/catalog/postgres/pg_language.h | 8 +++++--- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/src/execution/compiler/udf/udf_codegen.cpp b/src/execution/compiler/udf/udf_codegen.cpp index c553bb46fa..0253dc0aed 100644 --- a/src/execution/compiler/udf/udf_codegen.cpp +++ b/src/execution/compiler/udf/udf_codegen.cpp @@ -312,6 +312,12 @@ void UDFCodegen::Visit(ast::udf::IsNullExprAST *ast) { } } +void UDFCodegen::Visit(ast::udf::SeqStmtAST *ast) { + for (auto &stmt : ast->stmts) { + stmt->Accept(this); + } +} + void UDFCodegen::Visit(ast::udf::WhileStmtAST *ast) { ast->cond_expr->Accept(this); auto cond = dst_; diff --git a/src/include/catalog/postgres/pg_language.h b/src/include/catalog/postgres/pg_language.h index c9ca1bddde..ce0875f920 100644 --- a/src/include/catalog/postgres/pg_language.h +++ b/src/include/catalog/postgres/pg_language.h @@ -9,6 +9,10 @@ namespace noisepage::storage { class RecoveryManager; } // namespace noisepage::storage +namespace noisepage::execution::sql { +class DDLExecutors; +} // namespace noisepage::execution::sql + namespace noisepage::catalog::postgres { class Builder; class PgLanguageImpl; @@ -17,10 +21,8 @@ class PgProcImpl; /** The OIDs used by the NoisePage version of pg_language. */ class PgLanguage { private: - friend class storage::RecoveryManager; // TODO(Kyle): How do we want to expose these constants? - // This is a friend because the DDL executor needs to access - // the OID for the PL/pgSQL language... + friend class storage::RecoveryManager; friend class execution::sql::DDLExecutors; friend class Builder; From f5f8cf7b58a7cdb8c77cf5a4ca15054943f599df Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Tue, 20 Apr 2021 15:42:39 -0400 Subject: [PATCH 017/139] able to push from dev10 From aa56dbd1fedc6a4c1a62ac58e010a4a7407b81ac Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Tue, 20 Apr 2021 16:02:00 -0400 Subject: [PATCH 018/139] fix linker error with libpg_query --- third_party/libpg_query/pg_list.h | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/third_party/libpg_query/pg_list.h b/third_party/libpg_query/pg_list.h index 21e9a1a31b..3ffda4cfb0 100644 --- a/third_party/libpg_query/pg_list.h +++ b/third_party/libpg_query/pg_list.h @@ -37,6 +37,7 @@ #ifndef PG_LIST_H #define PG_LIST_H +#include #include "nodes.h" typedef struct ListCell ListCell; @@ -76,30 +77,30 @@ struct ListCell * if supported by the compiler, or as regular functions otherwise. * See STATIC_IF_INLINE in c.h. */ -#ifndef PG_USE_INLINE -extern ListCell *list_head(const List *l); -extern ListCell *list_tail(List *l); -extern int list_length(const List *l); -#endif /* PG_USE_INLINE */ -#if defined(PG_USE_INLINE) || defined(PG_LIST_INCLUDE_DEFINITIONS) -STATIC_IF_INLINE ListCell * +//#ifndef PG_USE_INLINE +//extern ListCell *list_head(const List *l); +//extern ListCell *list_tail(List *l); +//extern int list_length(const List *l); +//#endif /* PG_USE_INLINE */ +//#if defined(PG_USE_INLINE) || defined(PG_LIST_INCLUDE_DEFINITIONS) +static inline ListCell * list_head(const List *l) { return l ? l->head : NULL; } -STATIC_IF_INLINE ListCell * +static inline ListCell * list_tail(List *l) { return l ? l->tail : NULL; } -STATIC_IF_INLINE int +static inline int list_length(const List *l) { return l ? l->length : 0; } -#endif /*-- PG_USE_INLINE || PG_LIST_INCLUDE_DEFINITIONS */ +//#endif /*-- PG_USE_INLINE || PG_LIST_INCLUDE_DEFINITIONS */ /* * NB: There is an unfortunate legacy from a previous incarnation of From 86427c6302b8d9006e816c985c7eff7e2a5cb976 Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Tue, 20 Apr 2021 16:07:26 -0400 Subject: [PATCH 019/139] remove old comments in libpg_query --- third_party/libpg_query/pg_list.h | 7 ------- 1 file changed, 7 deletions(-) diff --git a/third_party/libpg_query/pg_list.h b/third_party/libpg_query/pg_list.h index 3ffda4cfb0..25c023c6cf 100644 --- a/third_party/libpg_query/pg_list.h +++ b/third_party/libpg_query/pg_list.h @@ -77,12 +77,6 @@ struct ListCell * if supported by the compiler, or as regular functions otherwise. * See STATIC_IF_INLINE in c.h. */ -//#ifndef PG_USE_INLINE -//extern ListCell *list_head(const List *l); -//extern ListCell *list_tail(List *l); -//extern int list_length(const List *l); -//#endif /* PG_USE_INLINE */ -//#if defined(PG_USE_INLINE) || defined(PG_LIST_INCLUDE_DEFINITIONS) static inline ListCell * list_head(const List *l) { @@ -100,7 +94,6 @@ list_length(const List *l) { return l ? l->length : 0; } -//#endif /*-- PG_USE_INLINE || PG_LIST_INCLUDE_DEFINITIONS */ /* * NB: There is an unfortunate legacy from a previous incarnation of From 48112cea47bc9f055a252fcd06adcce574a4f66d Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Fri, 23 Apr 2021 08:59:36 -0400 Subject: [PATCH 020/139] remove old TODO in db_server.py --- script/testing/util/db_server.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/script/testing/util/db_server.py b/script/testing/util/db_server.py index bb500a8e62..0e3809c9f9 100644 --- a/script/testing/util/db_server.py +++ b/script/testing/util/db_server.py @@ -493,8 +493,9 @@ def handle_flags(value: str, meta: Dict) -> str: `-attribute=value` and instead want to format it as `-attribute` alone. This preprocessor encapsulates the logic for this transformation. - TODO(Kyle): Do we actually support any arguments like this? - I can't seem to come up with any actual examples... + NOTE(Kyle): At this time it doesn't appear we actually support + any arguments like this, but keeping it in anyway so I don't + inadvertently break something. Arguments --------- @@ -518,11 +519,6 @@ def apply_all(functions: List, init_obj, meta: Dict): Apply all of the functions in `functions` to object `init_obj` sequentially, supplying metadata object `meta` to each function invocation. - TODO(Kyle): Initially I wanted to implement this with function composition - in terms of functools.reduce() which makes it really beautiful, but there - we run into issues with multi-argument callbacks, and the real solution is - to use partial application, but this seemed like overkill... maybe revisit. - Arguments --------- functions : List[function] From 1e211e401509a1512c6024399f54eb20026c2f39 Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Fri, 23 Apr 2021 09:45:49 -0400 Subject: [PATCH 021/139] cleanup in pipeline, remove dead code, add documentation --- src/execution/compiler/pipeline.cpp | 10 ++--- src/execution/compiler/udf/udf_codegen.cpp | 28 ++++++------- src/include/execution/compiler/pipeline.h | 49 +++++++++++----------- src/include/execution/vm/llvm_engine.h | 2 +- 4 files changed, 43 insertions(+), 46 deletions(-) diff --git a/src/execution/compiler/pipeline.cpp b/src/execution/compiler/pipeline.cpp index 3417079ea6..a500387259 100644 --- a/src/execution/compiler/pipeline.cpp +++ b/src/execution/compiler/pipeline.cpp @@ -34,8 +34,7 @@ Pipeline::Pipeline(CompilationContext *ctx) check_parallelism_(true), nested_(false) {} -Pipeline::Pipeline(OperatorTranslator *op, Pipeline::Parallelism parallelism, bool consumer) - : Pipeline(op->GetCompilationContext()) { +Pipeline::Pipeline(OperatorTranslator *op, Pipeline::Parallelism parallelism) : Pipeline(op->GetCompilationContext()) { UpdateParallelism(parallelism); RegisterStep(op); } @@ -434,10 +433,9 @@ ast::Identifier Pipeline::GetRunPipelineFunctionName() const { return codegen_->MakeIdentifier(CreatePipelineFunctionName("Run")); } -ast::Expr *Pipeline::GetNestedInputArg(uint32_t index) const { - NOISEPAGE_ASSERT(nested_, "Asking for input arg on non-nested pipeline"); - NOISEPAGE_ASSERT(index < extra_pipeline_params_.size(), - "Asking for input arg on non-nested pipeline that doesn't exist"); +ast::Expr *Pipeline::GetNestedInputArg(const std::size_t index) const { + NOISEPAGE_ASSERT(nested_, "Requested nested input argument on non-nested pipeline"); + NOISEPAGE_ASSERT(index < extra_pipeline_params_.size(), "Requested nested index argument out of range"); return codegen_->UnaryOp(parsing::Token::Type::STAR, codegen_->MakeExpr(extra_pipeline_params_[index]->Name())); } diff --git a/src/execution/compiler/udf/udf_codegen.cpp b/src/execution/compiler/udf/udf_codegen.cpp index 0253dc0aed..4a90c62a39 100644 --- a/src/execution/compiler/udf/udf_codegen.cpp +++ b/src/execution/compiler/udf/udf_codegen.cpp @@ -335,7 +335,7 @@ void UDFCodegen::Visit(ast::udf::ForStmtAST *ast) { // const auto query = common::ManagedPointer(ast->query_); // auto exec_ctx = fb_->GetParameterByPosition(0); - // // TODO(Matt): I don't think the binder should need the database name. It's already bound in the ConnectionContext + // TODO(Matt): I don't think the binder should need the database name. It's already bound in the ConnectionContext // binder::BindNodeVisitor visitor(common::ManagedPointer(accessor_), db_oid_); // auto query_params = visitor.BindAndGetUDFParams(query, common::ManagedPointer(udf_ast_context_)); @@ -345,7 +345,7 @@ void UDFCodegen::Visit(ast::udf::ForStmtAST *ast) { // std::unique_ptr plan = trafficcop::TrafficCopUtil::Optimize( // accessor_->GetTxn(), common::ManagedPointer(accessor_), query, db_oid_, common::ManagedPointer(&stats), // std::make_unique(), 1000000); - // // make lambda that just writes into this + // make lambda that just writes into this // std::vector var_idents; // auto lam_var = codegen_->MakeFreshIdentifier("looplamb"); // execution::util::RegionVector params(codegen_->GetAstContext()->GetRegion()); @@ -356,7 +356,7 @@ void UDFCodegen::Visit(ast::udf::ForStmtAST *ast) { // for (auto var : ast->vars_) { // var_idents.push_back(str_to_ident_.find(var)->second); // auto var_ident = var_idents.back(); - // // NOISEPAGE_ASSERT(plan->GetOutputSchema()->GetColumns().size() == 1, "Can't support non scalars yet!"); + // NOISEPAGE_ASSERT(plan->GetOutputSchema()->GetColumns().size() == 1, "Can't support non scalars yet!"); // auto type = codegen_->TplType(execution::sql::GetTypeId(plan->GetOutputSchema()->GetColumn(i).GetType())); // fb_->Append(codegen_->Assign(codegen_->MakeExpr(var_ident), @@ -390,8 +390,8 @@ void UDFCodegen::Visit(ast::udf::ForStmtAST *ast) { // lambda_expr = fn.FinishLambda(std::move(captures)); // lambda_expr->SetName(lam_var); - // // want to pass something down that will materialize the lambda function for me into lambda_expr and will - // // also feed in a lambda_expr to the compiler + // want to pass something down that will materialize the lambda function for me into lambda_expr and will + // also feed in a lambda_expr to the compiler // execution::exec::ExecutionSettings exec_settings{}; // const std::string dummy_query = ""; // auto exec_query = execution::compiler::CompilationContext::Compile( @@ -406,10 +406,10 @@ void UDFCodegen::Visit(ast::udf::ForStmtAST *ast) { // codegen_->DeclareVar(lam_var, codegen_->LambdaType(lambda_expr->GetFunctionLitExpr()->TypeRepr()), // lambda_expr)); - // // make query state + // make query state // auto query_state = codegen_->MakeFreshIdentifier("query_state"); // fb_->Append(codegen_->DeclareVarNoInit(query_state, codegen_->MakeExpr(exec_query->GetQueryStateType()->Name()))); - // // set its execution context to whatever exec context was passed in here + // set its execution context to whatever exec context was passed in here // fb_->Append(codegen_->CallBuiltin(execution::ast::Builtin::StartNewParams, {exec_ctx})); // std::vector>::iterator> sorted_vec; // for (auto it = query_params.begin(); it != query_params.end(); it++) { @@ -418,10 +418,10 @@ void UDFCodegen::Visit(ast::udf::ForStmtAST *ast) { // std::sort(sorted_vec.begin(), sorted_vec.end(), [](auto x, auto y) { return x->second < y->second; }); // for (auto entry : sorted_vec) { - // // TODO(order these dudes) + // TODO(order these dudes) // type::TypeId type = type::TypeId::INVALID; // udf_ast_context_->GetVariableType(entry->first, &type); - // // NOISEPAGE_ASSERT(ret, "didn't find param in udf ast context"); + // NOISEPAGE_ASSERT(ret, "didn't find param in udf ast context"); // execution::ast::Builtin builtin; // switch (type) { @@ -457,15 +457,15 @@ void UDFCodegen::Visit(ast::udf::ForStmtAST *ast) { // } // fb_->Append(codegen_->CallBuiltin(builtin, {exec_ctx, codegen_->MakeExpr(str_to_ident_[entry->first])})); // } - // // set param 1 - // // set param 2 - // // etc etc + // set param 1 + // set param 2 + // etc etc // fb_->Append(codegen_->Assign( // codegen_->AccessStructMember(codegen_->MakeExpr(query_state), codegen_->MakeIdentifier("execCtx")), exec_ctx)); - // // set its execution context to whatever exec context was passed in here + // set its execution context to whatever exec context was passed in here // for (auto &sub_fn : fns) { - // // aux_decls_.push_back(c) + // aux_decls_.push_back(c) // if (sub_fn.find("Run") != std::string::npos) { // fb_->Append(codegen_->Call(codegen_->GetAstContext()->GetIdentifier(sub_fn), // {codegen_->AddressOf(query_state), codegen_->MakeExpr(lam_var)})); diff --git a/src/include/execution/compiler/pipeline.h b/src/include/execution/compiler/pipeline.h index 7c7ca78912..46a6affa92 100644 --- a/src/include/execution/compiler/pipeline.h +++ b/src/include/execution/compiler/pipeline.h @@ -67,9 +67,8 @@ class Pipeline { * Create a pipeline with the given operator as the root. * @param op The root operator of the pipeline. * @param parallelism The operator's requested parallelism. - * @param consumer TODO(Kyle) */ - Pipeline(OperatorTranslator *op, Parallelism parallelism, bool consumer = false); + Pipeline(OperatorTranslator *op, Parallelism parallelism); /** * Register an operator in this pipeline with a customized parallelism configuration. @@ -131,6 +130,14 @@ class Pipeline { */ void CollectDependencies(std::vector *deps); + /** + * Store in the provided output vector the set of all dependencies for this pipeline. In other + * words, store in the output vector all pipelines that must execute (in order) before this + * pipeline can begin. + * @param[out] deps The sorted list of pipelines to execute before this pipeline can begin. + */ + void CollectDependencies(std::vector *deps) const; + /** * Perform initialization logic before code generation. * @param exec_settings The execution settings used for query compilation. @@ -140,8 +147,9 @@ class Pipeline { /** * Generate all functions to execute this pipeline in the provided container. * @param builder The builder for the executable query container. - * @param query_id TODO(Kyle) - * @param output_callback TODO(Kyle) + * @param query_id The ID of the query for which this pipeline is generated. + * @param output_callback The lambda expression that represents the + * output callback for the pipeline. */ void GeneratePipeline(ExecutableQueryFragmentBuilder *builder, query_id_t query_id, ast::LambdaExpr *output_callback = nullptr) const; @@ -221,17 +229,17 @@ class Pipeline { void InjectEndResourceTracker(FunctionBuilder *builder, bool is_hook) const; /** - * @return query identifier of the query that we are codegen-ing + * @return Query identifier of the query that we are codegen-ing */ query_id_t GetQueryId() const; /** - * @return a pointer to the OUFeatureVector in the pipeline state + * @return A pointer to the OUFeatureVector in the pipeline state */ ast::Expr *OUFeatureVecPtr() const { return oufeatures_.GetPtr(codegen_); } - /** @return TODO(Kyle) */ - ast::Expr *GetNestedInputArg(uint32_t index) const; + /** @return The nested input argument at `index` */ + ast::Expr *GetNestedInputArg(std::size_t index) const; /** @return `true` if this pipeline is prepared, `false` otherwise */ bool IsPrepared() const { return prepared_; } @@ -244,7 +252,7 @@ class Pipeline { ast::Identifier GetTearDownPipelineStateFunctionName() const; ast::Identifier GetWorkFunctionName() const; - // TODO(Kyle) this + // Generate a wrapper function for the current pipeline. ast::FunctionDecl *GeneratePipelineWrapperFunction(ast::LambdaExpr *output_callback) const; // Generate the pipeline state initialization logic. @@ -265,7 +273,7 @@ class Pipeline { // Generate pipeline tear-down logic. ast::FunctionDecl *GenerateTearDownPipelineFunction(ast::LambdaExpr *output_callback) const; - /** @brief TODO(Kyle) */ + /** @brief Indicate that this pipeline is nested. */ void MarkNested() { nested_ = true; } private: @@ -276,28 +284,19 @@ class Pipeline { /** @return The vector of pipeline operators that make up the pipeline. */ const std::vector &GetTranslators() const { return steps_; } - /** @brief TODO(Kyle) */ - void InjectStartPipelineTracker(FunctionBuilder *builder) const; - - /** @brief TODO(Kyle) */ - void InjectEndResourceTracker(FunctionBuilder *builder, query_id_t query_id) const; + /** @return An identifier for the pipeline `Init` function */ + ast::Identifier GetInitPipelineFunctionName() const; - /** @brief TODO(Kyle) */ + /** @return An identifier for the pipeline `Run` function */ ast::Identifier GetRunPipelineFunctionName() const; - /** @brief TODO(Kyle) */ - void CollectDependencies(std::vector *deps) const; - - /** @brief TODO(Kyle) */ + /** @return An identifier for the pipeline `Teardown` function */ ast::Identifier GetTeardownPipelineFunctionName() const; - /** @brief TODO(Kyle) */ - ast::Identifier GetInitPipelineFunctionName() const; - - /** @brief TODO(Kyle) */ + /** @return An immutable reference to the pipeline state descriptor */ const StateDescriptor &GetPipelineStateDescriptor() const { return state_; } - /** @brief TODO(Kyle) */ + /** @return A mutable reference to the pipeline state descriptor */ StateDescriptor &GetPipelineStateDescriptor() { return state_; } private: diff --git a/src/include/execution/vm/llvm_engine.h b/src/include/execution/vm/llvm_engine.h index 269122f4eb..422d8c17d2 100644 --- a/src/include/execution/vm/llvm_engine.h +++ b/src/include/execution/vm/llvm_engine.h @@ -225,7 +225,7 @@ class LLVMEngine { /** * Process-wide LLVM engine settings. * - * TODO(Kyle): I'm not particularly happy with this setup - an inline + * NOTE(Kyle): I'm not particularly happy with this setup - an inline * static variable (essentially just a global with scoping) for managing * the settings for the LLVM engine. The ownership model should be * relatively simple - the LLVMEngine should own its settings, but From d4278aa2bdbeff05b662f108080528350321113c Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Fri, 23 Apr 2021 09:51:06 -0400 Subject: [PATCH 022/139] implement CreateFunctionExecutor in traffic_cop --- src/traffic_cop/traffic_cop.cpp | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/traffic_cop/traffic_cop.cpp b/src/traffic_cop/traffic_cop.cpp index 7033f9d657..502b8fe48a 100644 --- a/src/traffic_cop/traffic_cop.cpp +++ b/src/traffic_cop/traffic_cop.cpp @@ -279,12 +279,10 @@ TrafficCopResult TrafficCop::ExecuteCreateStatement( break; } case network::QueryType::QUERY_CREATE_FUNCTION: { - // TODO(Kyle): Port executor - // if (execution::sql::DDLExecutors::CreateFunctionExecutor( - // physical_plan.CastManagedPointerTo(), connection_ctx->Accessor())) { - // return {ResultType::COMPLETE, 0}; - // } - throw NOT_IMPLEMENTED_EXCEPTION("CREATE FUNCTION not implemented"); + if (execution::sql::DDLExecutors::CreateFunctionExecutor( + physical_plan.CastManagedPointerTo(), connection_ctx->Accessor())) { + return {ResultType::COMPLETE, 0}; + } break; } default: { From 3b29f8839de2adfc6bb79443b780cac79169af24 Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Fri, 23 Apr 2021 10:09:29 -0400 Subject: [PATCH 023/139] starting to work on cleaning up udf_parser --- src/parser/udf/udf_parser.cpp | 89 +++++++++++++++------------------ src/traffic_cop/traffic_cop.cpp | 2 +- src/util/query_exec_util.cpp | 2 +- 3 files changed, 41 insertions(+), 52 deletions(-) diff --git a/src/parser/udf/udf_parser.cpp b/src/parser/udf/udf_parser.cpp index 735d3781ae..ba484f016a 100644 --- a/src/parser/udf/udf_parser.cpp +++ b/src/parser/udf/udf_parser.cpp @@ -64,10 +64,8 @@ std::unique_ptr PLpgSQLParser::ParsePLpgSQL(std::vector PLpgSQLParser::ParseBlock(const nlohmann::json &block) std::unique_ptr PLpgSQLParser::ParseDecl(const nlohmann::json &decl) { const auto &decl_names = decl.items().begin(); - for (auto &it : decl.items()) { - std::cout << it.key() << " : " << it.value() << "\n"; - } - // NOISEPAGE_ASSERT(decl_names->size() >= 1, "Bad declaration names membership size"); PARSER_LOG_DEBUG("Declaration : {}", decl_names.key()); if (decl_names.key() == kPLpgSQL_var) { @@ -196,10 +190,7 @@ std::unique_ptr PLpgSQLParser::ParseDecl(const nlohmann::json &decl) { udf_ast_context_->SetVariableType(var_name, type::TypeId::INVALID); return std::unique_ptr(new DeclStmtAST(var_name, type::TypeId::INVALID, std::move(initial))); } else { - NOISEPAGE_ASSERT(false, "Unsupported "); - // udf_ast_context_->SetVariableType(var_name, type::TypeId::INVALID); - // return std::unique_ptr( - // new DeclStmtAST(var_name, type::TypeId::INVALID)); + NOISEPAGE_ASSERT(false, "Unsupported"); } } else if (decl_names.key() == kPLpgSQL_row) { auto var_name = decl[kPLpgSQL_row][kRefname].get(); @@ -250,45 +241,43 @@ std::unique_ptr PLpgSQLParser::ParseFor(const nlohmann::json &loop) { // TODO(Kyle): Implement std::unique_ptr PLpgSQLParser::ParseSQL(const nlohmann::json &sql_stmt) { - throw NOT_IMPLEMENTED_EXCEPTION("ParseSQL Not Implemented"); - // auto sql_query = sql_stmt[kSqlstmt][kPLpgSQL_expr][kQuery].get(); - // auto var_name = sql_stmt[kRow][kPLpgSQL_row][kFields][0][kName].get(); - // auto parse_result = PostgresParser::BuildParseTree(sql_query.c_str()); - // if (parse_result == nullptr) { - // PARSER_LOG_DEBUG("Bad SQL statement"); - // return nullptr; - // } - // binder::BindNodeVisitor visitor(accessor_, db_oid_); - - // std::unordered_map> query_params; - - // try { - // // TODO(Matt): I don't think the binder should need the database name. It's already bound in the - // ConnectionContext - // // binder::BindNodeVisitor visitor(accessor_, db_oid_); - // query_params = visitor.BindAndGetUDFParams(common::ManagedPointer(parse_result), udf_ast_context_); - // } catch (BinderException &b) { - // PARSER_LOG_DEBUG("Bad SQL statement"); - // return nullptr; - // } - - // // check to see if a record type can be bound to this - // type::TypeId type; - // auto ret = udf_ast_context_->GetVariableType(var_name, &type); - // if (!ret) { - // throw PARSER_EXCEPTION("PL/pgSQL parser : Didn't declare variable"); - // } - // if (type == type::TypeId::INVALID) { - // std::vector> elems; - // auto sel = parse_result->GetStatement(0).CastManagedPointerTo()->GetSelectColumns(); - // for (auto col : sel) { - // elems.emplace_back(col->GetAliasName(), col->GetReturnValueType()); - // } - // udf_ast_context_->SetRecordType(var_name, std::move(elems)); - // } - - // return std::unique_ptr( - // new SQLStmtAST(std::move(parse_result), std::move(var_name), std::move(query_params))); + auto sql_query = sql_stmt[kSqlstmt][kPLpgSQL_expr][kQuery].get(); + auto var_name = sql_stmt[kRow][kPLpgSQL_row][kFields][0][kName].get(); + + // TODO(Kyle): Should probably do something else on malformed SQL + auto parse_result = PostgresParser::BuildParseTree(sql_query.c_str()); + NOISEPAGE_ASSERT(parse_result != nullptr, "Malformed SQL Statement"); + + binder::BindNodeVisitor visitor{accessor_, db_oid_}; + std::unordered_map> query_params{}; + try { + // TODO(Matt): I don't think the binder should need the database name. + // It's already bound in the ConnectionContext binder::BindNodeVisitor visitor(accessor_, db_oid_); + query_params = visitor.BindAndGetUDFParams(common::ManagedPointer{parse_result}, udf_ast_context_); + } catch (BinderException &b) { + // TODO(Kyle): Same here + NOISEPAGE_ASSERT(false, "Malformed SQL Statement"); + } + + // Check to see if a record type can be bound to this + type::TypeId type{}; + auto ret = udf_ast_context_->GetVariableType(var_name, &type); + if (!ret) { + throw PARSER_EXCEPTION("PL/pgSQL parser: variable was not declared"); + } + + if (type == type::TypeId::INVALID) { + std::vector> elems{}; + const auto &select_columns = + parse_result->GetStatement(0).CastManagedPointerTo()->GetSelectColumns(); + elems.reserve(select_columns.size()); + for (const auto& col : select_columns) { + elems.emplace_back(col->GetAlias(), col->GetReturnValueType()); + } + udf_ast_context_->SetRecordType(var_name, std::move(elems)); + } + + return std::make_unique(std::move(parse_result), std::move(var_name), std::move(query_params)); } std::unique_ptr PLpgSQLParser::ParseDynamicSQL(const nlohmann::json &sql_stmt) { diff --git a/src/traffic_cop/traffic_cop.cpp b/src/traffic_cop/traffic_cop.cpp index 502b8fe48a..fb120b2d73 100644 --- a/src/traffic_cop/traffic_cop.cpp +++ b/src/traffic_cop/traffic_cop.cpp @@ -280,7 +280,7 @@ TrafficCopResult TrafficCop::ExecuteCreateStatement( } case network::QueryType::QUERY_CREATE_FUNCTION: { if (execution::sql::DDLExecutors::CreateFunctionExecutor( - physical_plan.CastManagedPointerTo(), connection_ctx->Accessor())) { + physical_plan.CastManagedPointerTo(), connection_ctx->Accessor())) { return {ResultType::COMPLETE, 0}; } break; diff --git a/src/util/query_exec_util.cpp b/src/util/query_exec_util.cpp index 7396fefdbd..1eacb0aaae 100644 --- a/src/util/query_exec_util.cpp +++ b/src/util/query_exec_util.cpp @@ -251,7 +251,7 @@ bool QueryExecUtil::ExecuteQuery(const std::string &statement, TupleFunction tup auto exec_ctx = std::make_unique( db_oid_, txn, callback, schema, common::ManagedPointer(accessor), exec_settings, metrics, DISABLED, DISABLED); - // TODO(Kyle): Should probably write a helper for this functionality + // Must translate the ConstantValueExpressions to opaque sql::Val std::vector> value_params{}; value_params.reserve(params->size()); std::transform(params->cbegin(), params->cend(), std::back_inserter(value_params), From e7dd0d020edb5fcfa48c0f6e958a31a1d04f8f27 Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Fri, 23 Apr 2021 10:22:20 -0400 Subject: [PATCH 024/139] documentation for FunctionContext --- .../execution/functions/function_context.h | 36 ++++++++++++------- src/parser/udf/udf_parser.cpp | 3 +- 2 files changed, 25 insertions(+), 14 deletions(-) diff --git a/src/include/execution/functions/function_context.h b/src/include/execution/functions/function_context.h index 39018160c9..2928366a7b 100644 --- a/src/include/execution/functions/function_context.h +++ b/src/include/execution/functions/function_context.h @@ -15,7 +15,7 @@ namespace noisepage::execution::functions { /** - * @brief Stores execution and type information about a stored procedure + * @brief Stores execution and type information about a stored procedure. */ class FunctionContext { public: @@ -31,6 +31,7 @@ class FunctionContext { args_type_(std::move(args_type)), is_builtin_{false}, is_exec_ctx_required_{false} {} + /** * Creates a FunctionContext object for a builtin function * @param func_name Name of function @@ -48,6 +49,17 @@ class FunctionContext { builtin_{builtin}, is_exec_ctx_required_{is_exec_ctx_required} {} + /** + * Creates a FunctionContext object for a non-builtin function. + * @param func_name Name of function= + * @param func_ret_type Return type of function + * @param arg_types Vector of argument types + * @param ast_region The region associated with the AST context + * @param ast_context The AST context for the function + * @param file The AST file + * @param is_exec_ctx_required Flag indicating whether an + * execution context is required for this function + */ FunctionContext(std::string func_name, type::TypeId func_ret_type, std::vector &&args_type, std::unique_ptr ast_region, std::unique_ptr ast_context, ast::File *file, bool is_exec_ctx_required = true) @@ -61,28 +73,28 @@ class FunctionContext { file_{file} {} /** - * @return The name of the function represented by this context object + * @return The name of the function represented by this context object. */ const std::string &GetFunctionName() const { return func_name_; } /** - * @return The vector of type arguments of the function represented by this context object + * @return The vector of type arguments of the function represented by this context object. */ const std::vector &GetFunctionArgsType() const { return args_type_; } /** - * Gets the return type of the function represented by this object - * @return return type of this function + * Gets the return type of the function represented by this object. + * @return The return type of this function. */ type::TypeId GetFunctionReturnType() const { return func_ret_type_; } /** - * @return true iff this represents a builtin function + * @return `true` if this represents a builtin function, `false` otherwise. */ bool IsBuiltin() const { return is_builtin_; } /** - * @return returns what builtin function this represents + * @return The builtin function this procedure represents. */ ast::Builtin GetBuiltin() const { NOISEPAGE_ASSERT(IsBuiltin(), "Getting a builtin from a non-builtin function"); @@ -90,7 +102,7 @@ class FunctionContext { } /** - * @return returns if this function requires an execution context + * @return `true` if this function requires an execution context, `false` otherwise. */ bool IsExecCtxRequired() const { NOISEPAGE_ASSERT(IsBuiltin(), "IsExecCtxRequired is only valid or a builtin function"); @@ -98,7 +110,7 @@ class FunctionContext { } /** - * @return returns the main functiondecl of this udf (to be used only if not builtin) + * @return The main functiondecl of this UDF. */ common::ManagedPointer GetMainFunctionDecl() const { NOISEPAGE_ASSERT(!IsBuiltin(), "Getting a non-builtin from a builtin function"); @@ -107,7 +119,7 @@ class FunctionContext { } /** - * @return returns the file with the functiondecl and supporting decls (to be used only if not builtin) + * @return The file with the functiondecl and supporting decls. */ ast::File *GetFile() const { NOISEPAGE_ASSERT(!IsBuiltin(), "Getting a non-builtin from a builtin function"); @@ -115,10 +127,10 @@ class FunctionContext { } /** - * TODO(Kyle): Document. + * @return The AST context for this procedure. */ ast::Context *GetASTContext() const { - NOISEPAGE_ASSERT(!IsBuiltin(), "Getting a non-builtin from a builtin function"); + NOISEPAGE_ASSERT(!IsBuiltin(), "No AST Context associated with builtin function"); return ast_context_.get(); } diff --git a/src/parser/udf/udf_parser.cpp b/src/parser/udf/udf_parser.cpp index ba484f016a..877b33bd08 100644 --- a/src/parser/udf/udf_parser.cpp +++ b/src/parser/udf/udf_parser.cpp @@ -239,7 +239,6 @@ std::unique_ptr PLpgSQLParser::ParseFor(const nlohmann::json &loop) { return std::unique_ptr(new ForStmtAST(std::move(var_vec), std::move(parse_result), std::move(body_stmt))); } -// TODO(Kyle): Implement std::unique_ptr PLpgSQLParser::ParseSQL(const nlohmann::json &sql_stmt) { auto sql_query = sql_stmt[kSqlstmt][kPLpgSQL_expr][kQuery].get(); auto var_name = sql_stmt[kRow][kPLpgSQL_row][kFields][0][kName].get(); @@ -271,7 +270,7 @@ std::unique_ptr PLpgSQLParser::ParseSQL(const nlohmann::json &sql_stmt) const auto &select_columns = parse_result->GetStatement(0).CastManagedPointerTo()->GetSelectColumns(); elems.reserve(select_columns.size()); - for (const auto& col : select_columns) { + for (const auto &col : select_columns) { elems.emplace_back(col->GetAlias(), col->GetReturnValueType()); } udf_ast_context_->SetRecordType(var_name, std::move(elems)); From eb898d8ef06ebde485acc8b5860c37c66dbffcbb Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Fri, 23 Apr 2021 10:44:34 -0400 Subject: [PATCH 025/139] address some TODO in udf parser and AST context --- .../execution/ast/udf/udf_ast_context.h | 2 +- src/parser/udf/udf_parser.cpp | 111 +++++++----------- 2 files changed, 46 insertions(+), 67 deletions(-) diff --git a/src/include/execution/ast/udf/udf_ast_context.h b/src/include/execution/ast/udf/udf_ast_context.h index 05714d43e2..def6df89d7 100644 --- a/src/include/execution/ast/udf/udf_ast_context.h +++ b/src/include/execution/ast/udf/udf_ast_context.h @@ -13,7 +13,7 @@ class UDFASTContext { public: UDFASTContext() {} - void SetVariableType(std::string &var, type::TypeId type) { symbol_table_[var] = type; } + void SetVariableType(const std::string &var, type::TypeId type) { symbol_table_[var] = type; } bool GetVariableType(const std::string &var, type::TypeId *type) { auto it = symbol_table_.find(var); diff --git a/src/parser/udf/udf_parser.cpp b/src/parser/udf/udf_parser.cpp index 877b33bd08..7022ab2726 100644 --- a/src/parser/udf/udf_parser.cpp +++ b/src/parser/udf/udf_parser.cpp @@ -2,15 +2,11 @@ #include "binder/bind_node_visitor.h" #include "execution/ast/udf/udf_ast_nodes.h" -#include "loggers/parser_logger.h" #include "parser/udf/udf_parser.h" #include "libpg_query/pg_query.h" #include "nlohmann/json.hpp" -// TODO(Kyle): This whole file needs documentation... - -// TODO(Kyle): Do we want to put UDF parsing in its own namespace? namespace noisepage { namespace parser { namespace udf { @@ -18,39 +14,40 @@ namespace udf { using namespace nlohmann; using namespace execution::ast::udf; -// TODO(Kyle): constexpr -// TODO(Kyle): Define elsewhere? -const std::string kFunctionList = "FunctionList"; -const std::string kDatums = "datums"; -const std::string kPLpgSQL_var = "PLpgSQL_var"; -const std::string kRefname = "refname"; -const std::string kDatatype = "datatype"; -const std::string kDefaultVal = "default_val"; -const std::string kPLpgSQL_type = "PLpgSQL_type"; -const std::string kTypname = "typname"; -const std::string kAction = "action"; -const std::string kPLpgSQL_function = "PLpgSQL_function"; -const std::string kBody = "body"; -const std::string kPLpgSQL_stmt_block = "PLpgSQL_stmt_block"; -const std::string kPLpgSQL_stmt_return = "PLpgSQL_stmt_return"; -const std::string kPLpgSQL_stmt_if = "PLpgSQL_stmt_if"; -const std::string kPLpgSQL_stmt_while = "PLpgSQL_stmt_while"; -const std::string kPLpgSQL_stmt_fors = "PLpgSQL_stmt_fors"; -const std::string kCond = "cond"; -const std::string kThenBody = "then_body"; -const std::string kElseBody = "else_body"; -const std::string kExpr = "expr"; -const std::string kQuery = "query"; -const std::string kPLpgSQL_expr = "PLpgSQL_expr"; -const std::string kPLpgSQL_stmt_assign = "PLpgSQL_stmt_assign"; -const std::string kVarno = "varno"; -const std::string kPLpgSQL_stmt_execsql = "PLpgSQL_stmt_execsql"; -const std::string kSqlstmt = "sqlstmt"; -const std::string kRow = "row"; -const std::string kFields = "fields"; -const std::string kName = "name"; -const std::string kPLpgSQL_row = "PLpgSQL_row"; -const std::string kPLpgSQL_stmt_dynexecute = "PLpgSQL_stmt_dynexecute"; +/** + * @brief The identifiers used as keys in the parse tree. + */ +static const std::string kFunctionList = "FunctionList"; +static const std::string kDatums = "datums"; +static const std::string kPLpgSQL_var = "PLpgSQL_var"; +static const std::string kRefname = "refname"; +static const std::string kDatatype = "datatype"; +static const std::string kDefaultVal = "default_val"; +static const std::string kPLpgSQL_type = "PLpgSQL_type"; +static const std::string kTypname = "typname"; +static const std::string kAction = "action"; +static const std::string kPLpgSQL_function = "PLpgSQL_function"; +static const std::string kBody = "body"; +static const std::string kPLpgSQL_stmt_block = "PLpgSQL_stmt_block"; +static const std::string kPLpgSQL_stmt_return = "PLpgSQL_stmt_return"; +static const std::string kPLpgSQL_stmt_if = "PLpgSQL_stmt_if"; +static const std::string kPLpgSQL_stmt_while = "PLpgSQL_stmt_while"; +static const std::string kPLpgSQL_stmt_fors = "PLpgSQL_stmt_fors"; +static const std::string kCond = "cond"; +static const std::string kThenBody = "then_body"; +static const std::string kElseBody = "else_body"; +static const std::string kExpr = "expr"; +static const std::string kQuery = "query"; +static const std::string kPLpgSQL_expr = "PLpgSQL_expr"; +static const std::string kPLpgSQL_stmt_assign = "PLpgSQL_stmt_assign"; +static const std::string kVarno = "varno"; +static const std::string kPLpgSQL_stmt_execsql = "PLpgSQL_stmt_execsql"; +static const std::string kSqlstmt = "sqlstmt"; +static const std::string kRow = "row"; +static const std::string kFields = "fields"; +static const std::string kName = "name"; +static const std::string kPLpgSQL_row = "PLpgSQL_row"; +static const std::string kPLpgSQL_stmt_dynexecute = "PLpgSQL_stmt_dynexecute"; std::unique_ptr PLpgSQLParser::ParsePLpgSQL(std::vector &¶m_names, std::vector &¶m_types, @@ -58,28 +55,25 @@ std::unique_ptr PLpgSQLParser::ParsePLpgSQL(std::vector ast_context) { auto result = pg_query_parse_plpgsql(func_body.c_str()); if (result.error) { - PARSER_LOG_INFO("PL/pgSQL parse error : {}", result.error->message); pg_query_free_plpgsql_parse_result(result); throw PARSER_EXCEPTION("PL/pgSQL parsing error"); } // The result is a list, we need to wrap it - std::string ast_json_str = "{ \"" + kFunctionList + "\" : " + std::string(result.plpgsql_funcs) + " }"; + const auto ast_json_str = "{ \"" + kFunctionList + "\" : " + std::string{result.plpgsql_funcs} + " }"; pg_query_free_plpgsql_parse_result(result); - std::istringstream ss(ast_json_str); - json ast_json; + std::istringstream ss{ast_json_str}; + json ast_json{}; ss >> ast_json; const auto function_list = ast_json[kFunctionList]; NOISEPAGE_ASSERT(function_list.is_array(), "Function list is not an array"); if (function_list.size() != 1) { - PARSER_LOG_DEBUG("PL/pgSQL error : Function list size %u", function_list.size()); throw PARSER_EXCEPTION("Function list has size other than 1"); } - size_t i = 0; - for (auto udf_name : param_names) { - // udf_ast_context_->AddVariable(udf_name); + std::size_t i{0}; + for (const auto &udf_name : param_names) { udf_ast_context_->SetVariableType(udf_name, param_types[i++]); } const auto function = function_list[0][kPLpgSQL_function]; @@ -94,7 +88,6 @@ std::unique_ptr PLpgSQLParser::ParseFunction(const nlohmann::json &bloc std::vector> stmts; - PARSER_LOG_DEBUG("Parsing Declarations"); NOISEPAGE_ASSERT(decl_list.is_array(), "Declaration list is not an array"); for (uint32_t i = 1; i < decl_list.size(); i++) { stmts.push_back(ParseDecl(decl_list[i])); @@ -118,8 +111,6 @@ std::unique_ptr PLpgSQLParser::ParseBlock(const nlohmann::json &block) for (uint32_t i = 0; i < block.size(); i++) { const auto stmt = block[i]; const auto stmt_names = stmt.items().begin(); - // NOISEPAGE_ASSERT(stmt_names->size() == 1, "Bad statement size"); - PARSER_LOG_DEBUG("Statement : {}", stmt_names.key()); if (stmt_names.key() == kPLpgSQL_stmt_return) { auto expr = ParseExprSQL(stmt[kPLpgSQL_stmt_return][kExpr][kPLpgSQL_expr][kQuery].get()); @@ -147,16 +138,13 @@ std::unique_ptr PLpgSQLParser::ParseBlock(const nlohmann::json &block) } else { throw PARSER_EXCEPTION("Statement type not supported"); } - NOISEPAGE_ASSERT(stmts.back() != nullptr, "It broke"); } - std::unique_ptr seq_stmt_ast(new SeqStmtAST(std::move(stmts))); - return std::move(seq_stmt_ast); + return std::make_unique(std::move(stmts)); } std::unique_ptr PLpgSQLParser::ParseDecl(const nlohmann::json &decl) { const auto &decl_names = decl.items().begin(); - PARSER_LOG_DEBUG("Declaration : {}", decl_names.key()); if (decl_names.key() == kPLpgSQL_var) { auto var_name = decl[kPLpgSQL_var][kRefname].get(); @@ -167,9 +155,7 @@ std::unique_ptr PLpgSQLParser::ParseDecl(const nlohmann::json &decl) { initial = ParseExprSQL(decl[kPLpgSQL_var][kDefaultVal][kPLpgSQL_expr][kQuery].get()); } - PARSER_LOG_INFO("Registering type {0}: {1}", var_name.c_str(), type.c_str()); - - type::TypeId temp_type; + type::TypeId temp_type{}; if (udf_ast_context_->GetVariableType(var_name, &temp_type)) { return std::unique_ptr(new DeclStmtAST(var_name, temp_type, std::move(initial))); } @@ -204,7 +190,6 @@ std::unique_ptr PLpgSQLParser::ParseDecl(const nlohmann::json &decl) { } std::unique_ptr PLpgSQLParser::ParseIf(const nlohmann::json &branch) { - PARSER_LOG_DEBUG("ParseIf"); auto cond_expr = ParseExprSQL(branch[kCond][kPLpgSQL_expr][kQuery].get()); auto then_stmt = ParseBlock(branch[kThenBody]); std::unique_ptr else_stmt = nullptr; @@ -215,19 +200,15 @@ std::unique_ptr PLpgSQLParser::ParseIf(const nlohmann::json &branch) { } std::unique_ptr PLpgSQLParser::ParseWhile(const nlohmann::json &loop) { - PARSER_LOG_DEBUG("ParseWhile"); auto cond_expr = ParseExprSQL(loop[kCond][kPLpgSQL_expr][kQuery].get()); auto body_stmt = ParseBlock(loop[kBody]); return std::unique_ptr(new WhileStmtAST(std::move(cond_expr), std::move(body_stmt))); } std::unique_ptr PLpgSQLParser::ParseFor(const nlohmann::json &loop) { - PARSER_LOG_DEBUG("ParseFor"); auto sql_query = loop[kQuery][kPLpgSQL_expr][kQuery].get(); - ; auto parse_result = PostgresParser::BuildParseTree(sql_query.c_str()); if (parse_result == nullptr) { - PARSER_LOG_DEBUG("Bad SQL statement"); return nullptr; } auto body_stmt = ParseBlock(loop[kBody]); @@ -243,9 +224,10 @@ std::unique_ptr PLpgSQLParser::ParseSQL(const nlohmann::json &sql_stmt) auto sql_query = sql_stmt[kSqlstmt][kPLpgSQL_expr][kQuery].get(); auto var_name = sql_stmt[kRow][kPLpgSQL_row][kFields][0][kName].get(); - // TODO(Kyle): Should probably do something else on malformed SQL auto parse_result = PostgresParser::BuildParseTree(sql_query.c_str()); - NOISEPAGE_ASSERT(parse_result != nullptr, "Malformed SQL Statement"); + if (parse_result == nullptr) { + return nullptr; + } binder::BindNodeVisitor visitor{accessor_, db_oid_}; std::unordered_map> query_params{}; @@ -254,8 +236,7 @@ std::unique_ptr PLpgSQLParser::ParseSQL(const nlohmann::json &sql_stmt) // It's already bound in the ConnectionContext binder::BindNodeVisitor visitor(accessor_, db_oid_); query_params = visitor.BindAndGetUDFParams(common::ManagedPointer{parse_result}, udf_ast_context_); } catch (BinderException &b) { - // TODO(Kyle): Same here - NOISEPAGE_ASSERT(false, "Malformed SQL Statement"); + return nullptr; } // Check to see if a record type can be bound to this @@ -280,14 +261,12 @@ std::unique_ptr PLpgSQLParser::ParseSQL(const nlohmann::json &sql_stmt) } std::unique_ptr PLpgSQLParser::ParseDynamicSQL(const nlohmann::json &sql_stmt) { - PARSER_LOG_DEBUG("ParseDynamicSQL"); auto sql_expr = ParseExprSQL(sql_stmt[kQuery][kPLpgSQL_expr][kQuery].get()); auto var_name = sql_stmt[kRow][kPLpgSQL_row][kFields][0][kName].get(); return std::unique_ptr(new DynamicSQLStmtAST(std::move(sql_expr), std::move(var_name))); } std::unique_ptr PLpgSQLParser::ParseExprSQL(const std::string expr_sql_str) { - PARSER_LOG_DEBUG("Parsing Expr SQL : {}", expr_sql_str.c_str()); auto stmt_list = PostgresParser::BuildParseTree(expr_sql_str.c_str()); if (stmt_list == nullptr) { return nullptr; From 6ece761b6a4578bb0f1d605926d8cee2af5ceea7 Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Fri, 23 Apr 2021 13:46:08 -0400 Subject: [PATCH 026/139] wip on bytecode generation --- src/execution/vm/bytecode_generator.cpp | 106 ++++++++---------- .../execution/vm/bytecode_function_info.h | 17 ++- 2 files changed, 58 insertions(+), 65 deletions(-) diff --git a/src/execution/vm/bytecode_generator.cpp b/src/execution/vm/bytecode_generator.cpp index 678e447d12..741bc249f9 100644 --- a/src/execution/vm/bytecode_generator.cpp +++ b/src/execution/vm/bytecode_generator.cpp @@ -221,72 +221,58 @@ void BytecodeGenerator::VisitFunctionDecl(ast::FunctionDecl *node) { Visit(node->Function()); } - // TODO(Kyle): what is this doing? - for (auto f : func_info->actions_) { + // Execute the deferred actions for the function + for (auto &f : func_info->actions_) { f(); } } void BytecodeGenerator::VisitLambdaExpr(ast::LambdaExpr *node) { // TODO(Kyle): Implement. - throw NOT_IMPLEMENTED_EXCEPTION("VisitLambdaExpr Not Implemented"); - // // The function's TPL type - // auto *func_type = node->GetFunctionLitExpr()->GetType()->As(); - - // // Allocate the function - // // func_type->RegisterCapture(); - // if(!GetExecutionResult()->HasDestination()){ - // return; - // } - // auto captures = GetCurrentFunction()->NewLocal(node->GetCaptureStructType(), node->GetName().GetString() + - // "captures"); auto fields = node->GetCaptureStructType()->As()->GetFieldsWithoutPadding(); - // // auto &locals = GetCurrentFunction()->GetLocals(); - // for(size_t i = 0;i < fields.size() - 1;i++){ - // auto field = fields[i]; - // ast::IdentifierExpr ident(node->Position(), field.name_); - // ident.SetType(field.type_->GetPointeeType()); - // auto local = VisitExpressionForLValue(&ident); - // // auto local_it = std::find_if(locals.begin(), locals.end(), [=](const auto &loc){ return loc.GetName() == - // field.name_.GetString();}); - // // bool is_capture = false; - // // LocalVar local; - // // if(local_it == locals.end()){ - // // // should be inside captures - // // NOISEPAGE_ASSERT(GetCurrentFunction()->IsLambda(), "not lambda and local to capture not found"); - // // is_capture = true; - // // auto caller_captures = GetCurrentFunction()->GetFuncType()->GetCapturesType()->GetFieldsWithoutPadding(); - // // - // // auto cap_it = std::find_if(caller_captures.begin(), caller_captures.end(), - // // [=](const auto &loc){ return loc.GetName() == field.name_.GetString();}); - // // NOISEPAGE_ASSERT(cap_it != caller_captures.end(), "local to capture straight up not found"); - // // GetEmitter()-> - // // } - // LocalVar fieldvar = GetCurrentFunction()->NewLocal( - // fields[i].type_->PointerTo(), ""); - // GetEmitter()->EmitLea(fieldvar, captures.AddressOf(), - // node->GetCaptureStructType() - // ->As()->GetOffsetOfFieldByName(fields[i].name_)); - // GetEmitter()->EmitAssign(Bytecode::Assign8, fieldvar.ValueOf(), local); - // } - - // GetEmitter()->EmitAssign(Bytecode::Assign8, GetExecutionResult()->GetDestination(), captures.AddressOf()); - // FunctionInfo *func_info = AllocateFunc(node->GetName().GetString(), func_type); - // GetCurrentFunction()->DeferAction([=](){ - // func_info->captures_ = captures; - // func_info->is_lambda_ = true; - // { - // // Visit the body of the function. We use this handy scope object to track - // // the start and end position of this function's bytecode in the module's - // // bytecode array. Upon destruction, the scoped class will set the bytecode - // // range in the function. - // EnterFunction(func_info->GetId()); - // BytecodePositionScope position_scope(this, func_info); - // Visit(node->GetFunctionLitExpr()->Body()); - // } - // for(auto f : func_info->actions_){ - // f(); - // } - // }); + // The function's TPL type + auto *func_type = node->GetFunctionLitExpr()->GetType()->As(); + + // Allocate the function + if (!GetExecutionResult()->HasDestination()) { + return; + } + auto captures = + GetCurrentFunction()->NewLocal(node->GetCaptureStructType(), node->GetName().GetString() + "captures"); + auto fields = node->GetCaptureStructType()->As()->GetFieldsWithoutPadding(); + for (size_t i = 0; i < fields.size() - 1; i++) { + auto field = fields[i]; + ast::IdentifierExpr ident(node->Position(), field.name_); + ident.SetType(field.type_->GetPointeeType()); + auto local = VisitExpressionForLValue(&ident); + + LocalVar fieldvar = GetCurrentFunction()->NewLocal(fields[i].type_->PointerTo(), ""); + GetEmitter()->EmitLea(fieldvar, captures.AddressOf(), + node->GetCaptureStructType()->As()->GetOffsetOfFieldByName(fields[i].name_)); + GetEmitter()->EmitAssign(Bytecode::Assign8, fieldvar.ValueOf(), local); + } + + GetEmitter()->EmitAssign(Bytecode::Assign8, GetExecutionResult()->GetDestination(), captures.AddressOf()); + FunctionInfo *func_info = AllocateFunc(node->GetName().GetString(), func_type); + + // Create a new deferred action for the current function + // that visits the body of the lambda; this actions is subsequently + // executed when the function declaration itself is visited + GetCurrentFunction()->DeferAction([=]() { + func_info->captures_ = captures; + func_info->is_lambda_ = true; + { + // Visit the body of the function. We use this handy scope object to track + // the start and end position of this function's bytecode in the module's + // bytecode array. Upon destruction, the scoped class will set the bytecode + // range in the function. + EnterFunction(func_info->GetId()); + BytecodePositionScope position_scope(this, func_info); + Visit(node->GetFunctionLitExpr()->Body()); + } + for (auto &f : func_info->actions_) { + f(); + } + }); } // TODO(Kyle): Do we need a VisitLambdaDecl()? diff --git a/src/include/execution/vm/bytecode_function_info.h b/src/include/execution/vm/bytecode_function_info.h index 19cbd577e8..9eeb1f2f22 100644 --- a/src/include/execution/vm/bytecode_function_info.h +++ b/src/include/execution/vm/bytecode_function_info.h @@ -289,12 +289,19 @@ class FunctionInfo { uint32_t GetParamsCount() const noexcept { return num_params_; } /** - * TODO(Kyle): this. + * @brief Defer an action for the current function. + * + * This functionality is used for TPL lambda expressions. + * When we visit a lambda expression in the nody of the + * current function, we defer an action that in turn visits + * the body of the lambda. This action is evaluated when we + * later visit the declaration for the function itself. */ void DeferAction(const std::function action) { actions_.push_back(action); } /** - * TODO(Kyle): this. + * @return `true` if the TBC function represented by this object + * is generated by a TPL lambda, `false` otherwise. */ bool IsLambda() const { return is_lambda_; } @@ -313,13 +320,13 @@ class FunctionInfo { // Allocate a new local variable in the function. LocalVar NewLocal(ast::Type *type, const std::string &name, LocalInfo::Kind kind); - // TODO(Kyle): this + // The captures in the event this function is a TPL lambda. LocalVar captures_; - // TODO(Kyle): this + // Indicates whether this TBC function is generated by a TPL lambda. bool is_lambda_{false}; - // TODO(Kyle): this + // The collection of deferred actions if this function is a TPL lambda. std::vector> actions_; private: From 684ffbc8cc9fd646092f29363d7154a08a3a8de3 Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Fri, 23 Apr 2021 14:00:21 -0400 Subject: [PATCH 027/139] fix up loop builder in bytecode generation --- src/execution/vm/bytecode_generator.cpp | 8 +++----- src/execution/vm/control_flow_builders.cpp | 5 +++++ src/include/execution/vm/control_flow_builders.h | 11 +++++++---- 3 files changed, 15 insertions(+), 9 deletions(-) diff --git a/src/execution/vm/bytecode_generator.cpp b/src/execution/vm/bytecode_generator.cpp index 741bc249f9..18ceceb258 100644 --- a/src/execution/vm/bytecode_generator.cpp +++ b/src/execution/vm/bytecode_generator.cpp @@ -191,11 +191,9 @@ void BytecodeGenerator::VisitForStmt(ast::ForStmt *node) { } void BytecodeGenerator::VisitBreakStmt(ast::BreakStmt *node) { - // TODO(Kyle): Implement. - throw NOT_IMPLEMENTED_EXCEPTION("VisitBreakStmt Not Implemented"); - // if(current_loop_ != nullptr && current_loop_->GetPrev() != nullptr) { - // current_loop_->GetPrev()->Break(); - // } + if (current_loop_ != nullptr && current_loop_->GetPrevLoop() != nullptr) { + current_loop_->GetPrevLoop()->Break(); + } } void BytecodeGenerator::VisitForInStmt(UNUSED_ATTRIBUTE ast::ForInStmt *node) { diff --git a/src/execution/vm/control_flow_builders.cpp b/src/execution/vm/control_flow_builders.cpp index 340747ffa5..835f03abd0 100644 --- a/src/execution/vm/control_flow_builders.cpp +++ b/src/execution/vm/control_flow_builders.cpp @@ -40,6 +40,11 @@ void LoopBuilder::BindContinueTarget() { GetGenerator()->GetEmitter()->Bind(GetContinueLabel()); } +LoopBuilder *LoopBuilder::GetPrevLoop() const { + NOISEPAGE_ASSERT(prev_loop_ != nullptr, "Attempt to access a non-existent outer loop"); + return prev_loop_; +} + // --------------------------------------------------------- // If-Then-Else Builders // --------------------------------------------------------- diff --git a/src/include/execution/vm/control_flow_builders.h b/src/include/execution/vm/control_flow_builders.h index 68eca8dfff..ccc68a0e82 100644 --- a/src/include/execution/vm/control_flow_builders.h +++ b/src/include/execution/vm/control_flow_builders.h @@ -81,11 +81,9 @@ class LoopBuilder : public BreakableBlockBuilder { /** * Construct a loop builder. - * - * TODO(Kyle): Why was this construtor removed? - * * @param generator The generator the loop writes. - * @param prev + * @param prev The previous (outer) loop in the current + * code generation context */ explicit LoopBuilder(BytecodeGenerator *generator, LoopBuilder *prev = nullptr) : BreakableBlockBuilder(generator), prev_loop_(prev) {} @@ -115,6 +113,11 @@ class LoopBuilder : public BreakableBlockBuilder { */ void BindContinueTarget(); + /** + * Get the previous (outer) loop. + */ + LoopBuilder *GetPrevLoop() const; + private: /** @return The label associated with the header of the loop. */ BytecodeLabel *GetHeaderLabel() { return &header_label_; } From c63256d135d21355a06aae40ced0e8c7224fb095 Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Fri, 23 Apr 2021 14:47:42 -0400 Subject: [PATCH 028/139] misc fixes and cleanup, looking primarily at semantic analysis --- src/execution/ast/ast_clone.cpp | 3 -- src/execution/compiler/executable_query.cpp | 8 ++-- src/execution/sema/sema_expr.cpp | 2 +- src/execution/sema/sema_stmt.cpp | 27 ++++++----- src/execution/vm/bytecode_emitter.cpp | 2 +- src/execution/vm/bytecode_generator.cpp | 5 +-- src/include/execution/ast/ast.h | 45 +++++++++++++++---- .../execution/compiler/executable_query.h | 15 +------ .../expression/expression_translator.h | 14 +++++- src/include/execution/vm/bytecode_emitter.h | 27 +++++++---- 10 files changed, 88 insertions(+), 60 deletions(-) diff --git a/src/execution/ast/ast_clone.cpp b/src/execution/ast/ast_clone.cpp index 48579d2b82..858f584441 100644 --- a/src/execution/ast/ast_clone.cpp +++ b/src/execution/ast/ast_clone.cpp @@ -14,9 +14,6 @@ namespace noisepage::execution::ast { -/** - * TODO(Kyle): Document. - */ class AstCloneImpl : public AstVisitor { public: explicit AstCloneImpl(AstNode *root, AstNodeFactory *factory, Context *old_context, Context *new_context, diff --git a/src/execution/compiler/executable_query.cpp b/src/execution/compiler/executable_query.cpp index cf32ac204d..353c97d268 100644 --- a/src/execution/compiler/executable_query.cpp +++ b/src/execution/compiler/executable_query.cpp @@ -189,10 +189,10 @@ void ExecutableQuery::Run(common::ManagedPointer exec_ct } std::vector ExecutableQuery::GetDecls() const { - std::vector decls; - for (auto &f : fragments_) { - auto frag_decls = f->GetFile()->Declarations(); - decls.insert(decls.end(), frag_decls.begin(), frag_decls.end()); + std::vector decls{}; + for (const auto &f : fragments_) { + const auto &frag_decls = f->GetFile()->Declarations(); + decls.insert(decls.end(), frag_decls.cbegin(), frag_decls.cend()); } return decls; } diff --git a/src/execution/sema/sema_expr.cpp b/src/execution/sema/sema_expr.cpp index bb5284298f..fd6a85c2fc 100644 --- a/src/execution/sema/sema_expr.cpp +++ b/src/execution/sema/sema_expr.cpp @@ -249,7 +249,7 @@ void Sema::VisitLambdaExpr(ast::LambdaExpr *node) { SourcePosition(), GetContext()->GetIdentifier("lambda" + std::to_string(node->Position().line_)), struct_type_repr); VisitStructDecl(struct_decl); - node->capture_type_ = Resolve(struct_type_repr); + node->SetCaptureStructType(Resolve(struct_type_repr)); node->SetType(ast::LambdaType::Get(Resolve(node->GetFunctionLitExpr()->TypeRepr())->As())); // GetCurrentScope()->Declare(struct_decl->Name(), node->capture_type_); diff --git a/src/execution/sema/sema_stmt.cpp b/src/execution/sema/sema_stmt.cpp index c7790fef4a..d3a7ec607b 100644 --- a/src/execution/sema/sema_stmt.cpp +++ b/src/execution/sema/sema_stmt.cpp @@ -85,21 +85,20 @@ void Sema::VisitForStmt(ast::ForStmt *node) { Visit(node->Body()); } -// TODO(Kyle): Implement. void Sema::VisitBreakStmt(ast::BreakStmt *node) { - // look for a loop in my scope stack - // auto scope = GetCurrentScope(); - // bool found_loop = false; - // while(scope != nullptr){ - // found_loop |= scope->GetKind() == Scope::Kind::Loop; - // if(found_loop){ - // break; - // } - // scope = scope->Outer(); - // } - // if(!found_loop){ - // error_reporter_->Report(node->Position(), ErrorMessages::kNoScopeToBreak); - // } + // Look for a loop in my scope stack + auto scope = GetCurrentScope(); + bool found_loop = false; + while (scope != nullptr) { + found_loop |= scope->GetKind() == Scope::Kind::Loop; + if (found_loop) { + break; + } + scope = scope->Outer(); + } + if (!found_loop) { + error_reporter_->Report(node->Position(), ErrorMessages::kNoScopeToBreak); + } } void Sema::VisitForInStmt(ast::ForInStmt *node) { NOISEPAGE_ASSERT(false, "Not supported"); } diff --git a/src/execution/vm/bytecode_emitter.cpp b/src/execution/vm/bytecode_emitter.cpp index c8bc2f3d03..a75257b31a 100644 --- a/src/execution/vm/bytecode_emitter.cpp +++ b/src/execution/vm/bytecode_emitter.cpp @@ -66,7 +66,7 @@ void BytecodeEmitter::EmitCall(FunctionId func_id, const std::vector & } } -std::function BytecodeEmitter::DeferedEmitCall(const std::vector ¶ms) { +std::function BytecodeEmitter::DeferredEmitCall(const std::vector ¶ms) { NOISEPAGE_ASSERT(Bytecodes::GetNthOperandSize(Bytecode::Call, 1) == OperandSize::Short, "Expected argument count to be 2-byte short"); NOISEPAGE_ASSERT(params.size() < std::numeric_limits::max(), "Too many parameters!"); diff --git a/src/execution/vm/bytecode_generator.cpp b/src/execution/vm/bytecode_generator.cpp index 18ceceb258..89fdebeff5 100644 --- a/src/execution/vm/bytecode_generator.cpp +++ b/src/execution/vm/bytecode_generator.cpp @@ -226,7 +226,6 @@ void BytecodeGenerator::VisitFunctionDecl(ast::FunctionDecl *node) { } void BytecodeGenerator::VisitLambdaExpr(ast::LambdaExpr *node) { - // TODO(Kyle): Implement. // The function's TPL type auto *func_type = node->GetFunctionLitExpr()->GetType()->As(); @@ -273,8 +272,6 @@ void BytecodeGenerator::VisitLambdaExpr(ast::LambdaExpr *node) { }); } -// TODO(Kyle): Do we need a VisitLambdaDecl()? - void BytecodeGenerator::VisitLambdaTypeRepr(ast::LambdaTypeRepr *node) { UNREACHABLE("Should not visit type-representation nodes!"); } @@ -3463,7 +3460,7 @@ void BytecodeGenerator::VisitRegularCallExpr(ast::CallExpr *call) { // Emit call const auto func_id = LookupFuncIdByName(call->GetFuncName().GetData()); if (func_id == FunctionInfo::K_INVALID_FUNC_ID) { - auto action = GetEmitter()->DeferedEmitCall(params); + auto action = GetEmitter()->DeferredEmitCall(params); deferred_function_create_actions_[call->GetFuncName().GetString()].push_back(action); return; } diff --git a/src/include/execution/ast/ast.h b/src/include/execution/ast/ast.h index 6a0ded3e21..1486884199 100644 --- a/src/include/execution/ast/ast.h +++ b/src/include/execution/ast/ast.h @@ -1080,32 +1080,59 @@ class BinaryOpExpr : public Expr { /** * A lambda expression. - * TODO(Kyle): Document. */ class LambdaExpr : public Expr { public: + /** + * Construct + * @param pos source position + * @param func the associated function literal expression + * @param captures a collection of lambda captures + */ LambdaExpr(const SourcePosition &pos, FunctionLitExpr *func, util::RegionVector &&captures) - : Expr(Kind::LambdaExpr, pos), captures_{nullptr}, func_lit_(func), capture_idents_{std::move(captures)} {} + : Expr{Kind::LambdaExpr, pos}, func_lit_{func}, capture_idents_{std::move(captures)} {} - FunctionLitExpr *GetFunctionLitExpr() const { return func_lit_; } + /** + * @return The identifier for this lambda expression. + */ + const Identifier &GetName() const { return name_; } - ast::StructTypeRepr *GetCaptureStruct() const { return captures_; } + /** + * Set the name of this lambda expression. + * @param name The desired name. + */ + void SetName(Identifier name) { name_ = name; } + /** + * @return Get the capture struct type for this lambda expression. + */ ast::Type *GetCaptureStructType() const { return capture_type_; } - const Identifier &GetName() const { return name_; } + /** + * Set the capture struct type for this lambda expression. + * @param capture_type The desired type. + */ + void SetCaptureStructType(ast::Type *capture_type) { capture_type_ = capture_type; } - const util::RegionVector &GetCaptureIdents() const { return capture_idents_; } + /** + * @return The function literal expression associated with this lambda. + */ + FunctionLitExpr *GetFunctionLitExpr() const { return func_lit_; } - void SetName(Identifier name) { name_ = name; } + /** + * @return The identifiers for the captures of this lambda expression. + */ + const util::RegionVector &GetCaptureIdents() const { return capture_idents_; } private: friend class sema::Sema; - + // The identifier for the lambda expression. Identifier name_; - ast::StructTypeRepr *captures_; + // The type of the lambda captures struct. ast::Type *capture_type_; + // The associated function literal expression. FunctionLitExpr *func_lit_; + // The collection of identifers for lambda captures. util::RegionVector capture_idents_; }; diff --git a/src/include/execution/compiler/executable_query.h b/src/include/execution/compiler/executable_query.h index dfcdb4083d..899bda96dd 100644 --- a/src/include/execution/compiler/executable_query.h +++ b/src/include/execution/compiler/executable_query.h @@ -186,20 +186,7 @@ class ExecutableQuery { common::ManagedPointer GetQueryText() { return query_text_; } /** - * @return The functions. - */ - const std::vector GetFunctions() const { - // TODO(Kyle): string copying, figure out something better - std::vector ret{}; - for (auto &f : fragments_) { - auto fns = f->GetFunctions(); - ret.insert(ret.end(), fns.begin(), fns.end()); - } - return ret; - } - - /** - * @return TODO(Kyle): this. + * @return All of the declarations in the executable query. */ std::vector GetDecls() const; diff --git a/src/include/execution/compiler/expression/expression_translator.h b/src/include/execution/compiler/expression/expression_translator.h index d8130c4c0f..a91e5074cc 100644 --- a/src/include/execution/compiler/expression/expression_translator.h +++ b/src/include/execution/compiler/expression/expression_translator.h @@ -49,12 +49,22 @@ class ExpressionTranslator { virtual ast::Expr *DeriveValue(WorkContext *ctx, const ColumnValueProvider *provider) const = 0; /** - * TODO(Kyle): this + * Define all of the helper functions for this expression translator. + * + * The default implementation simply invokes the DefineHelperFunctions() + * method for each child of the current expression translator. + * + * @param decls The collection of function declarations. */ virtual void DefineHelperFunctions(util::RegionVector *decls); /** - * TODO(Kyle): this + * Define all of the helper structs for this expression translator. + * + * The default implementation simply invokes the DefineHelperStructs() + * method for each child of the current expression translator. + * + * @param decls The collection of struct declarations. */ virtual void DefineHelperStructs(util::RegionVector *decls); diff --git a/src/include/execution/vm/bytecode_emitter.h b/src/include/execution/vm/bytecode_emitter.h index 164b487864..c5e8ea2228 100644 --- a/src/include/execution/vm/bytecode_emitter.h +++ b/src/include/execution/vm/bytecode_emitter.h @@ -59,7 +59,7 @@ class BytecodeEmitter { // ------------------------------------------------------- /** - * Emit arbitrary assignment code + * Emit arbitrary assignment code. * @param bytecode assignment bytecode * @param dest destination variable * @param src source variable @@ -67,7 +67,10 @@ class BytecodeEmitter { void EmitAssign(Bytecode bytecode, LocalVar dest, LocalVar src); /** - * TODO(Kyle): this. + * Emit arbitrary assignment code. + * @param dest destination variable + * @param src source variable + * @param len length */ void EmitAssignN(LocalVar dest, LocalVar src, uint32_t len); @@ -167,15 +170,18 @@ class BytecodeEmitter { /** * Emit a function call - * @param func_id id of the function to call - * @param params parameters of the function + * @param func_id The ID of the function to call. + * @param params The parameters of the function. */ void EmitCall(FunctionId func_id, const std::vector ¶ms); /** - * TODO(Kyle): this. + * Create a function that emits a function call. + * @param params The parameters of the function. + * @return A new callable that, when invoked with a FunctionID, + * emits a fuinction call into the bytecode stream. */ - std::function DeferedEmitCall(const std::vector ¶ms); + std::function DeferredEmitCall(const std::vector ¶ms); /** * Emit a return bytecode @@ -427,7 +433,10 @@ class BytecodeEmitter { void EmitConcat(LocalVar ret, LocalVar exec_ctx, LocalVar inputs, uint32_t num_inputs); private: - /** Copy a scalar immediate value into the bytecode stream */ + /** + * Copy a scalar immediate value into the bytecode stream. + * @param val The scalar value to emit into the stream. + */ template auto EmitScalarValue(const T val) -> std::enable_if_t> { bytecode_->insert(bytecode_->end(), sizeof(T), 0); @@ -435,7 +444,9 @@ class BytecodeEmitter { } /** - * TODO(Kyle): this. + * Copy a scalar immediate value into the bytecode stream at specified index. + * @param val The scalar value to emit into the stream. + * @param index The index in the stream at which to emit the value. */ template auto EmitScalarValue(const T val, std::size_t index) -> std::enable_if_t> { From 59e5d8ad2173014e7460c55e8ea7569f686a4913 Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Fri, 23 Apr 2021 15:06:56 -0400 Subject: [PATCH 029/139] address some issues in binder and sema --- src/binder/bind_node_visitor.cpp | 1 - src/execution/sema/sema_expr.cpp | 10 +++++----- src/execution/vm/module.cpp | 5 +---- src/include/binder/bind_node_visitor.h | 6 +++++- 4 files changed, 11 insertions(+), 11 deletions(-) diff --git a/src/binder/bind_node_visitor.cpp b/src/binder/bind_node_visitor.cpp index 7a39f37f32..0e0d0868d3 100644 --- a/src/binder/bind_node_visitor.cpp +++ b/src/binder/bind_node_visitor.cpp @@ -566,7 +566,6 @@ void BindNodeVisitor::Visit(common::ManagedPointerGetVariableType(expr->GetColumnName(), &the_type)) { diff --git a/src/execution/sema/sema_expr.cpp b/src/execution/sema/sema_expr.cpp index fd6a85c2fc..b03dcab0d2 100644 --- a/src/execution/sema/sema_expr.cpp +++ b/src/execution/sema/sema_expr.cpp @@ -83,7 +83,7 @@ void Sema::VisitCallExpr(ast::CallExpr *node) { return; } - // TODO(Kyle): This seems weird + // Type checking already performed if (node->GetType() != nullptr) { return; } @@ -101,7 +101,7 @@ void Sema::VisitCallExpr(ast::CallExpr *node) { if (func_type == nullptr) { if (struct_type != nullptr) { func_type = struct_type->GetFunctionType(); - // TODO(Kyle): find a better way to see if sema has processed this already + // TODO(Kyle): Find a better way to see if sema has processed this already ast::IdentifierExpr *last_arg = nullptr; if (!node->Arguments().empty()) { last_arg = node->Arguments().back()->SafeAs(); @@ -117,9 +117,9 @@ void Sema::VisitCallExpr(ast::CallExpr *node) { } // Check argument count matches - // TODO(Kyle): Refactor this, gross. - if (!CheckArgCount( - node, struct_type != nullptr ? func_type->GetNumParams() - lambda_adjustment : func_type->GetNumParams())) { + const auto arg_count = + (struct_type != nullptr) ? func_type->GetNumParams() - lambda_adjustment : func_type->GetNumParams(); + if (!CheckArgCount(node, arg_count)) { return; } diff --git a/src/execution/vm/module.cpp b/src/execution/vm/module.cpp index 1aea32b09e..8eee118e5e 100644 --- a/src/execution/vm/module.cpp +++ b/src/execution/vm/module.cpp @@ -279,10 +279,7 @@ void Module::CompileToMachineCode() { // previous implementation. for (const auto &func_info : bytecode_module_->GetFunctionsInfo()) { auto *jit_function = jit_module_->GetFunctionPointer(func_info.GetName()); - // TODO(Kyle): Why is this OK now? - if (jit_function == nullptr) { - continue; - } + NOISEPAGE_ASSERT(jit_function != nullptr, "Function not found!"); functions_[func_info.GetId()].store(jit_function, std::memory_order_relaxed); } }); diff --git a/src/include/binder/bind_node_visitor.h b/src/include/binder/bind_node_visitor.h index 66315bec21..599aa7304a 100644 --- a/src/include/binder/bind_node_visitor.h +++ b/src/include/binder/bind_node_visitor.h @@ -51,7 +51,11 @@ class BindNodeVisitor final : public SqlNodeVisitor { ~BindNodeVisitor() final; /** - * TODO(Kyle): Document. + * Perform binding for a UDF. + * @param parse_result The result of parsing the UDF. + * @param udf_ast_context The AST context for the UDF. + * @return The map of UDF parameters: + * Column Name -> (Parameter Name, Parameter Index) */ std::unordered_map> BindAndGetUDFParams( common::ManagedPointer parse_result, From b540f436ba6ca5430d1dd567789cfda743105ae3 Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Fri, 23 Apr 2021 16:58:50 -0400 Subject: [PATCH 030/139] updates to udf codegen and executable query --- src/execution/compiler/executable_query.cpp | 9 + src/execution/compiler/udf/udf_codegen.cpp | 317 +++++++++--------- .../execution/compiler/executable_query.h | 7 +- 3 files changed, 172 insertions(+), 161 deletions(-) diff --git a/src/execution/compiler/executable_query.cpp b/src/execution/compiler/executable_query.cpp index 353c97d268..569a0bb190 100644 --- a/src/execution/compiler/executable_query.cpp +++ b/src/execution/compiler/executable_query.cpp @@ -188,6 +188,15 @@ void ExecutableQuery::Run(common::ManagedPointer exec_ct } } +std::vector ExecutableQuery::GetFunctionNames() const { + std::vector function_names{}; + for (const auto &f : fragments_) { + const auto &frag_functions = f->GetFunctions(); + function_names.insert(function_names.end(), frag_functions.cbegin(), frag_functions.cend()); + } + return function_names; +} + std::vector ExecutableQuery::GetDecls() const { std::vector decls{}; for (const auto &f : fragments_) { diff --git a/src/execution/compiler/udf/udf_codegen.cpp b/src/execution/compiler/udf/udf_codegen.cpp index 4a90c62a39..0e2418d008 100644 --- a/src/execution/compiler/udf/udf_codegen.cpp +++ b/src/execution/compiler/udf/udf_codegen.cpp @@ -5,6 +5,7 @@ #include "execution/ast/ast.h" #include "execution/ast/ast_clone.h" #include "execution/ast/context.h" +#include "planner/plannodes/output_schema.h" #include "execution/compiler/compilation_context.h" #include "execution/compiler/executable_query.h" @@ -488,180 +489,176 @@ void UDFCodegen::Visit(ast::udf::RetStmtAST *ast) { // TODO(Kyle): Implement void UDFCodegen::Visit(ast::udf::SQLStmtAST *ast) { - throw NOT_IMPLEMENTED_EXCEPTION("Visit(SQLStmtAST*) Not Implemented"); - // needs_exec_ctx_ = true; - // auto exec_ctx = fb_->GetParameterByPosition(0); - // const auto query = common::ManagedPointer(ast->query); - - // // TODO(Matt): I don't think the binder should need the database name. It's already bound in the ConnectionContext - // binder::BindNodeVisitor visitor(common::ManagedPointer(accessor_), db_oid_); - - // TODO(Kyle): Implement - // // auto query_params = visitor.BindAndGetUDFParams(query, common::ManagedPointer(udf_ast_context_)); - // auto query_params = ast->udf_params; - // auto stats = optimizer::StatsStorage(); - - // std::unique_ptr plan = trafficcop::TrafficCopUtil::Optimize( - // accessor_->GetTxn(), common::ManagedPointer(accessor_), query, db_oid_, common::ManagedPointer(&stats), - // std::make_unique(), 1000000); - // // make lambda that just writes into this + // As soon as we encounter an embedded SQL statement, + // we know we need an execution context + needs_exec_ctx_ = true; + auto exec_ctx = fb_->GetParameterByPosition(0); + const auto query = common::ManagedPointer(ast->query); + + binder::BindNodeVisitor visitor(common::ManagedPointer(accessor_), db_oid_); + + auto &query_params = ast->udf_params; + + // NOTE(Kyle): Assumptions: + // - This is a valid optimizer timeout + // - No parameters are required for the call to Optimize() + + auto stats = optimizer::StatsStorage(); + const std::uint64_t optimizer_timeout = 1000000; + auto optimize_result = trafficcop::TrafficCopUtil::Optimize( + accessor_->GetTxn(), common::ManagedPointer(accessor_), query, db_oid_, common::ManagedPointer(&stats), + std::make_unique(), optimizer_timeout, nullptr); + + // Make a lambda that just writes into this + auto lam_var = codegen_->MakeFreshIdentifier("lamb"); + + auto plan = optimize_result->GetPlanNode(); + auto &cols = plan->GetOutputSchema()->GetColumns(); + + execution::util::RegionVector params{codegen_->GetAstContext()->GetRegion()}; + params.push_back(codegen_->MakeField( + exec_ctx->As()->Name(), + codegen_->PointerType(codegen_->BuiltinType(execution::ast::BuiltinType::Kind::ExecutionContext)))); + + std::size_t i{0}; + std::vector assignees{}; + execution::util::RegionVector captures{codegen_->GetAstContext()->GetRegion()}; + for (auto &col : cols) { + execution::ast::Expr *capture_var = codegen_->MakeExpr(str_to_ident_.find(ast->var_name)->second); + type::TypeId udf_type{}; + udf_ast_context_->GetVariableType(ast->var_name, &udf_type); + if (udf_type == type::TypeId::INVALID) { + // Record type + auto &struct_vars = udf_ast_context_->GetRecordType(ast->var_name); + if (captures.empty()) { + captures.push_back(capture_var); + } + capture_var = codegen_->AccessStructMember(capture_var, codegen_->MakeIdentifier(struct_vars[i].first)); + assignees.push_back(capture_var); + } else { + assignees.push_back(capture_var); + captures.push_back(capture_var); + } + auto *type = codegen_->TplType(execution::sql::GetTypeId(col.GetType())); - // auto lam_var = codegen_->MakeFreshIdentifier("lamb"); - // // NOISEPAGE_ASSERT(plan->GetOutputSchema()->GetColumns().size() == 1, "Can't support non scalars yet!"); - // auto &cols = plan->GetOutputSchema()->GetColumns(); - // // auto &col = cols[0]; - // execution::util::RegionVector params(codegen_->GetAstContext()->GetRegion()); - // std::vector assignees; - // execution::util::RegionVector captures(codegen_->GetAstContext()->GetRegion()); - // size_t i = 0; - // params.push_back(codegen_->MakeField( - // exec_ctx->As()->Name(), - // codegen_->PointerType(codegen_->BuiltinType(execution::ast::BuiltinType::Kind::ExecutionContext)))); - // for (auto &col : cols) { - // execution::ast::Expr *capture_var = codegen_->MakeExpr(str_to_ident_.find(ast->var_name)->second); - // type::TypeId udf_type; - // udf_ast_context_->GetVariableType(ast->var_name, &udf_type); - // if (udf_type == type::TypeId::INVALID) { - // // record type - // auto &struct_vars = udf_ast_context_->GetRecordType(ast->var_name); - // if (captures.empty()) { - // captures.push_back(capture_var); - // } - // capture_var = codegen_->AccessStructMember(capture_var, codegen_->MakeIdentifier(struct_vars[i].first)); - // assignees.push_back(capture_var); - // } else { - // assignees.push_back(capture_var); - // captures.push_back(capture_var); - // } - // // auto capture_var = str_to_ident_.find(ast->var_name)->second; - // auto type = codegen_->TplType(execution::sql::GetTypeId(col.GetType())); + auto input_param = codegen_->MakeFreshIdentifier("input"); + params.push_back(codegen_->MakeField(input_param, type)); + i++; + } - // auto input_param = codegen_->MakeFreshIdentifier("input"); - // params.push_back(codegen_->MakeField(input_param, type)); - // i++; - // } + execution::ast::LambdaExpr *lambda_expr{}; + FunctionBuilder fn{codegen_, std::move(params), codegen_->BuiltinType(execution::ast::BuiltinType::Nil)}; + { + for (auto j = 0UL; j < assignees.size(); ++j) { + auto capture_var = assignees[j]; + auto input_param = fn.GetParameterByPosition(j + 1); + fn.Append(codegen_->Assign(capture_var, input_param)); + } + } - // execution::ast::LambdaExpr *lambda_expr; - // FunctionBuilder fn(codegen_, std::move(params), codegen_->BuiltinType(execution::ast::BuiltinType::Nil)); - // { - // for (size_t j = 0; j < assignees.size(); j++) { - // auto capture_var = assignees[j]; - // auto input_param = fn.GetParameterByPosition(j + 1); - // fn.Append(codegen_->Assign(capture_var, input_param)); - // } - // } + lambda_expr = fn.FinishLambda(std::move(captures)); + lambda_expr->SetName(lam_var); - // lambda_expr = fn.FinishLambda(std::move(captures)); - // lambda_expr->SetName(lam_var); + // We want to pass something down that will materialize the lambda function + // into lambda_expr and will also feed in a lambda_expr to the compiler + execution::exec::ExecutionSettings exec_settings{}; + const std::string dummy_query = ""; + auto exec_query = execution::compiler::CompilationContext::Compile( + *plan, exec_settings, accessor_, execution::compiler::CompilationMode::OneShot, std::nullopt, + common::ManagedPointer(&dummy_query), lambda_expr, codegen_->GetAstContext()); - // // want to pass something down that will materialize the lambda function for me into lambda_expr and will - // // also feed in a lambda_expr to the compiler - // execution::exec::ExecutionSettings exec_settings{}; - // const std::string dummy_query = ""; - // auto exec_query = execution::compiler::CompilationContext::Compile( - // *plan, exec_settings, accessor_, execution::compiler::CompilationMode::OneShot, - // common::ManagedPointer(&dummy_query), lambda_expr, codegen_->GetAstContext()); - // auto fns = exec_query->GetFunctions(); - // auto decls = exec_query->GetDecls(); + auto decls = exec_query->GetDecls(); + aux_decls_.insert(aux_decls_.end(), decls.begin(), decls.end()); - // aux_decls_.insert(aux_decls_.end(), decls.begin(), decls.end()); + fb_->Append( + codegen_->DeclareVar(lam_var, codegen_->LambdaType(lambda_expr->GetFunctionLitExpr()->TypeRepr()), lambda_expr)); - // fb_->Append( - // codegen_->DeclareVar(lam_var, codegen_->LambdaType(lambda_expr->GetFunctionLitExpr()->TypeRepr()), - // lambda_expr)); + // Make query state + auto query_state = codegen_->MakeFreshIdentifier("query_state"); + fb_->Append(codegen_->DeclareVarNoInit(query_state, codegen_->MakeExpr(exec_query->GetQueryStateType()->Name()))); - // // make query state - // auto query_state = codegen_->MakeFreshIdentifier("query_state"); - // fb_->Append(codegen_->DeclareVarNoInit(query_state, codegen_->MakeExpr(exec_query->GetQueryStateType()->Name()))); - // // set its execution context to whatever exec context was passed in here - // fb_->Append(codegen_->CallBuiltin(execution::ast::Builtin::StartNewParams, {exec_ctx})); - // std::vector>::iterator> sorted_vec; - // for (auto it = query_params.begin(); it != query_params.end(); it++) { - // sorted_vec.push_back(it); - // } + // Set its execution context to whatever exec context was passed in here + fb_->Append(codegen_->CallBuiltin(execution::ast::Builtin::StartNewParams, {exec_ctx})); + std::vector>::iterator> sorted_vec{}; + for (auto it = query_params.begin(); it != query_params.end(); it++) { + sorted_vec.push_back(it); + } - // std::sort(sorted_vec.begin(), sorted_vec.end(), [](auto x, auto y) { return x->second.second < y->second.second; - // }); for (auto entry : sorted_vec) { - // // TODO(order these dudes) - // type::TypeId type = type::TypeId::INVALID; - // execution::ast::Expr *expr = nullptr; - // if (entry->second.first.length() > 0) { - // auto &fields = udf_ast_context_->GetRecordType(entry->second.first); - // auto it = std::find_if(fields.begin(), fields.end(), [=](auto p) { return p.first == entry->first; }); - // type = it->second; - // expr = codegen_->AccessStructMember(codegen_->MakeExpr(str_to_ident_[entry->second.first]), - // codegen_->MakeIdentifier(entry->first)); - // } else { - // udf_ast_context_->GetVariableType(entry->first, &type); - // expr = codegen_->MakeExpr(str_to_ident_[entry->first]); - // } + std::sort(sorted_vec.begin(), sorted_vec.end(), [](auto x, auto y) { return x->second.second < y->second.second; }); + for (auto entry : sorted_vec) { + // TODO(Kyle): Order these + type::TypeId type = type::TypeId::INVALID; + execution::ast::Expr *expr = nullptr; + if (entry->second.first.length() > 0) { + auto &fields = udf_ast_context_->GetRecordType(entry->second.first); + auto it = std::find_if(fields.begin(), fields.end(), [=](auto p) { return p.first == entry->first; }); + type = it->second; + expr = codegen_->AccessStructMember(codegen_->MakeExpr(str_to_ident_[entry->second.first]), + codegen_->MakeIdentifier(entry->first)); + } else { + udf_ast_context_->GetVariableType(entry->first, &type); + expr = codegen_->MakeExpr(str_to_ident_[entry->first]); + } - // // NOISEPAGE_ASSERT(ret, "didn't find param in udf ast context"); - // execution::ast::Builtin builtin; - // switch (type) { - // case type::TypeId::BOOLEAN: - // builtin = execution::ast::Builtin::AddParamBool; - // break; - // case type::TypeId::TINYINT: - // builtin = execution::ast::Builtin::AddParamTinyInt; - // break; - // case type::TypeId::SMALLINT: - // builtin = execution::ast::Builtin::AddParamSmallInt; - // break; - // case type::TypeId::INTEGER: - // builtin = execution::ast::Builtin::AddParamInt; - // break; - // case type::TypeId::BIGINT: - // builtin = execution::ast::Builtin::AddParamBigInt; - // break; - // case type::TypeId::DECIMAL: - // builtin = execution::ast::Builtin::AddParamDouble; - // break; - // case type::TypeId::DATE: - // builtin = execution::ast::Builtin::AddParamDate; - // break; - // case type::TypeId::TIMESTAMP: - // builtin = execution::ast::Builtin::AddParamTimestamp; - // break; - // case type::TypeId::VARCHAR: - // builtin = execution::ast::Builtin::AddParamString; - // break; - // default: - // UNREACHABLE("Unsupported parameter type"); - // } - // fb_->Append(codegen_->CallBuiltin(builtin, {exec_ctx, expr})); - // } - // // set param 1 - // // set param 2 - // // etc etc - // fb_->Append(codegen_->Assign( - // codegen_->AccessStructMember(codegen_->MakeExpr(query_state), codegen_->MakeIdentifier("execCtx")), exec_ctx)); + execution::ast::Builtin builtin{}; + switch (type) { + case type::TypeId::BOOLEAN: + builtin = execution::ast::Builtin::AddParamBool; + break; + case type::TypeId::TINYINT: + builtin = execution::ast::Builtin::AddParamTinyInt; + break; + case type::TypeId::SMALLINT: + builtin = execution::ast::Builtin::AddParamSmallInt; + break; + case type::TypeId::INTEGER: + builtin = execution::ast::Builtin::AddParamInt; + break; + case type::TypeId::BIGINT: + builtin = execution::ast::Builtin::AddParamBigInt; + break; + case type::TypeId::DECIMAL: + builtin = execution::ast::Builtin::AddParamDouble; + break; + case type::TypeId::DATE: + builtin = execution::ast::Builtin::AddParamDate; + break; + case type::TypeId::TIMESTAMP: + builtin = execution::ast::Builtin::AddParamTimestamp; + break; + case type::TypeId::VARCHAR: + builtin = execution::ast::Builtin::AddParamString; + break; + default: + UNREACHABLE("Unsupported parameter type"); + } + fb_->Append(codegen_->CallBuiltin(builtin, {exec_ctx, expr})); + } - // for (auto &col : cols) { - // execution::ast::Expr *capture_var = codegen_->MakeExpr(str_to_ident_.find(ast->var_name)->second); - // auto lhs = capture_var; - // if (cols.size() > 1) { - // // record struct type - // lhs = codegen_->AccessStructMember(capture_var, codegen_->MakeIdentifier(col.GetName())); - // } - // fb_->Append(codegen_->Assign(lhs, codegen_->ConstNull(col.GetType()))); - // } - // // set its execution context to whatever exec context was passed in here + fb_->Append(codegen_->Assign( + codegen_->AccessStructMember(codegen_->MakeExpr(query_state), codegen_->MakeIdentifier("execCtx")), exec_ctx)); - // for (auto &sub_fn : fns) { - // // aux_decls_.push_back(c) - // if (sub_fn.find("Run") != std::string::npos) { - // fb_->Append(codegen_->Call(codegen_->GetAstContext()->GetIdentifier(sub_fn), - // {codegen_->AddressOf(query_state), codegen_->MakeExpr(lam_var)})); - // } else { - // fb_->Append(codegen_->Call(codegen_->GetAstContext()->GetIdentifier(sub_fn), - // {codegen_->AddressOf(query_state)})); - // } - // } + for (auto &col : cols) { + execution::ast::Expr *capture_var = codegen_->MakeExpr(str_to_ident_.find(ast->var_name)->second); + auto lhs = capture_var; + if (cols.size() > 1) { + // Record struct type + lhs = codegen_->AccessStructMember(capture_var, codegen_->MakeIdentifier(col.GetName())); + } + fb_->Append(codegen_->Assign(lhs, codegen_->ConstNull(col.GetType()))); + } - // fb_->Append(codegen_->CallBuiltin(execution::ast::Builtin::FinishNewParams, {exec_ctx})); + auto fns = exec_query->GetFunctionNames(); + for (auto &sub_fn : fns) { + if (sub_fn.find("Run") != std::string::npos) { + fb_->Append(codegen_->Call(codegen_->GetAstContext()->GetIdentifier(sub_fn), + {codegen_->AddressOf(query_state), codegen_->MakeExpr(lam_var)})); + } else { + fb_->Append(codegen_->Call(codegen_->GetAstContext()->GetIdentifier(sub_fn), {codegen_->AddressOf(query_state)})); + } + } - // return; + fb_->Append(codegen_->CallBuiltin(execution::ast::Builtin::FinishNewParams, {exec_ctx})); } void UDFCodegen::Visit(ast::udf::MemberExprAST *ast) { diff --git a/src/include/execution/compiler/executable_query.h b/src/include/execution/compiler/executable_query.h index 899bda96dd..af5e6d8b09 100644 --- a/src/include/execution/compiler/executable_query.h +++ b/src/include/execution/compiler/executable_query.h @@ -118,7 +118,7 @@ class ExecutableQuery { * Create a query object. * @param plan The physical plan. * @param exec_settings The execution settings used for this query. - * @param context TODO(Kyle): this + * @param context The AST context for the executable query; may be nullptr */ ExecutableQuery(const planner::AbstractPlanNode &plan, const exec::ExecutionSettings &exec_settings, ast::Context *context = nullptr); @@ -185,6 +185,11 @@ class ExecutableQuery { /** @return The SQL query string */ common::ManagedPointer GetQueryText() { return query_text_; } + /** + * @return All of the function names in the executable query. + */ + std::vector GetFunctionNames() const; + /** * @return All of the declarations in the executable query. */ From 307393e65181c0b1851125cfd7da99e140618035 Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Fri, 23 Apr 2021 17:14:07 -0400 Subject: [PATCH 031/139] last big visitor in udf codegen --- src/execution/compiler/udf/udf_codegen.cpp | 288 ++++++++++----------- 1 file changed, 140 insertions(+), 148 deletions(-) diff --git a/src/execution/compiler/udf/udf_codegen.cpp b/src/execution/compiler/udf/udf_codegen.cpp index 0e2418d008..d030cb7f1e 100644 --- a/src/execution/compiler/udf/udf_codegen.cpp +++ b/src/execution/compiler/udf/udf_codegen.cpp @@ -331,154 +331,146 @@ void UDFCodegen::Visit(ast::udf::WhileStmtAST *ast) { // TODO(Kyle): Implement void UDFCodegen::Visit(ast::udf::ForStmtAST *ast) { - throw NOT_IMPLEMENTED_EXCEPTION("Visit(ForStmtAst*) Not Implemented"); - // needs_exec_ctx_ = true; - // const auto query = common::ManagedPointer(ast->query_); - // auto exec_ctx = fb_->GetParameterByPosition(0); - - // TODO(Matt): I don't think the binder should need the database name. It's already bound in the ConnectionContext - // binder::BindNodeVisitor visitor(common::ManagedPointer(accessor_), db_oid_); - - // auto query_params = visitor.BindAndGetUDFParams(query, common::ManagedPointer(udf_ast_context_)); - - // auto stats = optimizer::StatsStorage(); - - // std::unique_ptr plan = trafficcop::TrafficCopUtil::Optimize( - // accessor_->GetTxn(), common::ManagedPointer(accessor_), query, db_oid_, common::ManagedPointer(&stats), - // std::make_unique(), 1000000); - // make lambda that just writes into this - // std::vector var_idents; - // auto lam_var = codegen_->MakeFreshIdentifier("looplamb"); - // execution::util::RegionVector params(codegen_->GetAstContext()->GetRegion()); - // params.push_back(codegen_->MakeField( - // exec_ctx->As()->Name(), - // codegen_->PointerType(codegen_->BuiltinType(execution::ast::BuiltinType::Kind::ExecutionContext)))); - // size_t i = 0; - // for (auto var : ast->vars_) { - // var_idents.push_back(str_to_ident_.find(var)->second); - // auto var_ident = var_idents.back(); - // NOISEPAGE_ASSERT(plan->GetOutputSchema()->GetColumns().size() == 1, "Can't support non scalars yet!"); - // auto type = codegen_->TplType(execution::sql::GetTypeId(plan->GetOutputSchema()->GetColumn(i).GetType())); - - // fb_->Append(codegen_->Assign(codegen_->MakeExpr(var_ident), - // codegen_->ConstNull(plan->GetOutputSchema()->GetColumn(i).GetType()))); - // auto input = codegen_->MakeFreshIdentifier(var); - // params.push_back(codegen_->MakeField(input, type)); - // i++; - // } - // execution::ast::LambdaExpr *lambda_expr; - // FunctionBuilder fn(codegen_, std::move(params), codegen_->BuiltinType(execution::ast::BuiltinType::Nil)); - // { - // size_t j = 1; - // for (auto var : var_idents) { - // fn.Append(codegen_->Assign(codegen_->MakeExpr(var), fn.GetParameterByPosition(j))); - // j++; - // } - // auto prev_fb = fb_; - // fb_ = &fn; - // ast->body_stmt_->Accept(this); - // fb_ = prev_fb; - // } - - // execution::util::RegionVector captures(codegen_->GetAstContext()->GetRegion()); - // for (auto it : str_to_ident_) { - // if (it.first == "executionCtx") { - // continue; - // } - // captures.push_back(codegen_->MakeExpr(it.second)); - // } - - // lambda_expr = fn.FinishLambda(std::move(captures)); - // lambda_expr->SetName(lam_var); - - // want to pass something down that will materialize the lambda function for me into lambda_expr and will - // also feed in a lambda_expr to the compiler - // execution::exec::ExecutionSettings exec_settings{}; - // const std::string dummy_query = ""; - // auto exec_query = execution::compiler::CompilationContext::Compile( - // *plan, exec_settings, accessor_, execution::compiler::CompilationMode::OneShot, - // common::ManagedPointer(&dummy_query), lambda_expr, codegen_->GetAstContext()); - // auto fns = exec_query->GetFunctions(); - // auto decls = exec_query->GetDecls(); - - // aux_decls_.insert(aux_decls_.end(), decls.begin(), decls.end()); - - // fb_->Append( - // codegen_->DeclareVar(lam_var, codegen_->LambdaType(lambda_expr->GetFunctionLitExpr()->TypeRepr()), - // lambda_expr)); - - // make query state - // auto query_state = codegen_->MakeFreshIdentifier("query_state"); - // fb_->Append(codegen_->DeclareVarNoInit(query_state, codegen_->MakeExpr(exec_query->GetQueryStateType()->Name()))); - // set its execution context to whatever exec context was passed in here - // fb_->Append(codegen_->CallBuiltin(execution::ast::Builtin::StartNewParams, {exec_ctx})); - // std::vector>::iterator> sorted_vec; - // for (auto it = query_params.begin(); it != query_params.end(); it++) { - // sorted_vec.push_back(it); - // } - - // std::sort(sorted_vec.begin(), sorted_vec.end(), [](auto x, auto y) { return x->second < y->second; }); - // for (auto entry : sorted_vec) { - // TODO(order these dudes) - // type::TypeId type = type::TypeId::INVALID; - // udf_ast_context_->GetVariableType(entry->first, &type); - // NOISEPAGE_ASSERT(ret, "didn't find param in udf ast context"); - - // execution::ast::Builtin builtin; - // switch (type) { - // case type::TypeId::BOOLEAN: - // builtin = execution::ast::Builtin::AddParamBool; - // break; - // case type::TypeId::TINYINT: - // builtin = execution::ast::Builtin::AddParamTinyInt; - // break; - // case type::TypeId::SMALLINT: - // builtin = execution::ast::Builtin::AddParamSmallInt; - // break; - // case type::TypeId::INTEGER: - // builtin = execution::ast::Builtin::AddParamInt; - // break; - // case type::TypeId::BIGINT: - // builtin = execution::ast::Builtin::AddParamBigInt; - // break; - // case type::TypeId::DECIMAL: - // builtin = execution::ast::Builtin::AddParamDouble; - // break; - // case type::TypeId::DATE: - // builtin = execution::ast::Builtin::AddParamDate; - // break; - // case type::TypeId::TIMESTAMP: - // builtin = execution::ast::Builtin::AddParamTimestamp; - // break; - // case type::TypeId::VARCHAR: - // builtin = execution::ast::Builtin::AddParamString; - // break; - // default: - // UNREACHABLE("Unsupported parameter type"); - // } - // fb_->Append(codegen_->CallBuiltin(builtin, {exec_ctx, codegen_->MakeExpr(str_to_ident_[entry->first])})); - // } - // set param 1 - // set param 2 - // etc etc - // fb_->Append(codegen_->Assign( - // codegen_->AccessStructMember(codegen_->MakeExpr(query_state), codegen_->MakeIdentifier("execCtx")), exec_ctx)); - // set its execution context to whatever exec context was passed in here - - // for (auto &sub_fn : fns) { - // aux_decls_.push_back(c) - // if (sub_fn.find("Run") != std::string::npos) { - // fb_->Append(codegen_->Call(codegen_->GetAstContext()->GetIdentifier(sub_fn), - // {codegen_->AddressOf(query_state), codegen_->MakeExpr(lam_var)})); - // } else { - // fb_->Append(codegen_->Call(codegen_->GetAstContext()->GetIdentifier(sub_fn), - // {codegen_->AddressOf(query_state)})); - // } - // } - - // fb_->Append(codegen_->CallBuiltin(execution::ast::Builtin::FinishNewParams, {exec_ctx})); - - // return; + // Once we encounter a For-statement we know we need an execution context + needs_exec_ctx_ = true; + + const auto query = common::ManagedPointer(ast->query_); + auto exec_ctx = fb_->GetParameterByPosition(0); + + binder::BindNodeVisitor visitor{common::ManagedPointer(accessor_), db_oid_}; + auto query_params = visitor.BindAndGetUDFParams(query, common::ManagedPointer(udf_ast_context_)); + + auto stats = optimizer::StatsStorage(); + const uint64_t optimizer_timeout = 1000000; + auto optimizer_result = trafficcop::TrafficCopUtil::Optimize( + accessor_->GetTxn(), common::ManagedPointer(accessor_), query, db_oid_, common::ManagedPointer(&stats), + std::make_unique(), optimizer_timeout, nullptr); + auto plan = optimizer_result->GetPlanNode(); + + // Make a lambda that just writes into this + std::vector var_idents; + auto lam_var = codegen_->MakeFreshIdentifier("looplamb"); + execution::util::RegionVector params(codegen_->GetAstContext()->GetRegion()); + params.push_back(codegen_->MakeField( + exec_ctx->As()->Name(), + codegen_->PointerType(codegen_->BuiltinType(execution::ast::BuiltinType::Kind::ExecutionContext)))); + std::size_t i{0}; + for (const auto &var : ast->vars_) { + var_idents.push_back(str_to_ident_.find(var)->second); + auto var_ident = var_idents.back(); + NOISEPAGE_ASSERT(plan->GetOutputSchema()->GetColumns().size() == 1, "Can't support non scalars yet!"); + auto type = codegen_->TplType(execution::sql::GetTypeId(plan->GetOutputSchema()->GetColumn(i).GetType())); + + fb_->Append(codegen_->Assign(codegen_->MakeExpr(var_ident), + codegen_->ConstNull(plan->GetOutputSchema()->GetColumn(i).GetType()))); + auto input = codegen_->MakeFreshIdentifier(var); + params.push_back(codegen_->MakeField(input, type)); + i++; + } + + execution::ast::LambdaExpr *lambda_expr{}; + FunctionBuilder fn(codegen_, std::move(params), codegen_->BuiltinType(execution::ast::BuiltinType::Nil)); + { + std::size_t j{1}; + for (auto var : var_idents) { + fn.Append(codegen_->Assign(codegen_->MakeExpr(var), fn.GetParameterByPosition(j))); + j++; + } + auto prev_fb = fb_; + fb_ = &fn; + ast->body_stmt_->Accept(this); + fb_ = prev_fb; + } + + execution::util::RegionVector captures{codegen_->GetAstContext()->GetRegion()}; + for (auto it : str_to_ident_) { + if (it.first == "executionCtx") { + continue; + } + captures.push_back(codegen_->MakeExpr(it.second)); + } + + lambda_expr = fn.FinishLambda(std::move(captures)); + lambda_expr->SetName(lam_var); + + // We want to pass something down that will materialize the lambda + // function into lambda_expr and will also feed in a lambda_expr to the compiler + execution::exec::ExecutionSettings exec_settings{}; + const std::string dummy_query = ""; + auto exec_query = execution::compiler::CompilationContext::Compile( + *plan, exec_settings, accessor_, execution::compiler::CompilationMode::OneShot, std::nullopt, + common::ManagedPointer(&dummy_query), lambda_expr, codegen_->GetAstContext()); + + auto decls = exec_query->GetDecls(); + aux_decls_.insert(aux_decls_.end(), decls.begin(), decls.end()); + + fb_->Append( + codegen_->DeclareVar(lam_var, codegen_->LambdaType(lambda_expr->GetFunctionLitExpr()->TypeRepr()), lambda_expr)); + + auto query_state = codegen_->MakeFreshIdentifier("query_state"); + fb_->Append(codegen_->DeclareVarNoInit(query_state, codegen_->MakeExpr(exec_query->GetQueryStateType()->Name()))); + + // Set its execution context to whatever exec context was passed in here + fb_->Append(codegen_->CallBuiltin(execution::ast::Builtin::StartNewParams, {exec_ctx})); + std::vector>::iterator> sorted_vec; + for (auto it = query_params.begin(); it != query_params.end(); it++) { + sorted_vec.push_back(it); + } + + std::sort(sorted_vec.begin(), sorted_vec.end(), [](auto x, auto y) { return x->second < y->second; }); + for (auto entry : sorted_vec) { + // TODO(Kyle): order these + type::TypeId type = type::TypeId::INVALID; + udf_ast_context_->GetVariableType(entry->first, &type); + execution::ast::Builtin builtin{}; + switch (type) { + case type::TypeId::BOOLEAN: + builtin = execution::ast::Builtin::AddParamBool; + break; + case type::TypeId::TINYINT: + builtin = execution::ast::Builtin::AddParamTinyInt; + break; + case type::TypeId::SMALLINT: + builtin = execution::ast::Builtin::AddParamSmallInt; + break; + case type::TypeId::INTEGER: + builtin = execution::ast::Builtin::AddParamInt; + break; + case type::TypeId::BIGINT: + builtin = execution::ast::Builtin::AddParamBigInt; + break; + case type::TypeId::DECIMAL: + builtin = execution::ast::Builtin::AddParamDouble; + break; + case type::TypeId::DATE: + builtin = execution::ast::Builtin::AddParamDate; + break; + case type::TypeId::TIMESTAMP: + builtin = execution::ast::Builtin::AddParamTimestamp; + break; + case type::TypeId::VARCHAR: + builtin = execution::ast::Builtin::AddParamString; + break; + default: + UNREACHABLE("Unsupported parameter type"); + } + fb_->Append(codegen_->CallBuiltin(builtin, {exec_ctx, codegen_->MakeExpr(str_to_ident_[entry->first])})); + } + + fb_->Append(codegen_->Assign( + codegen_->AccessStructMember(codegen_->MakeExpr(query_state), codegen_->MakeIdentifier("execCtx")), exec_ctx)); + + auto fns = exec_query->GetFunctionNames(); + for (auto &sub_fn : fns) { + if (sub_fn.find("Run") != std::string::npos) { + fb_->Append(codegen_->Call(codegen_->GetAstContext()->GetIdentifier(sub_fn), + {codegen_->AddressOf(query_state), codegen_->MakeExpr(lam_var)})); + } else { + fb_->Append(codegen_->Call(codegen_->GetAstContext()->GetIdentifier(sub_fn), {codegen_->AddressOf(query_state)})); + } + } + + fb_->Append(codegen_->CallBuiltin(execution::ast::Builtin::FinishNewParams, {exec_ctx})); } void UDFCodegen::Visit(ast::udf::RetStmtAST *ast) { From 9d67d8b9ea36f31a9627ada30e52cbc0985415ef Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Fri, 23 Apr 2021 21:44:50 -0400 Subject: [PATCH 032/139] slowly working down the list --- src/binder/bind_node_visitor.cpp | 2 +- src/catalog/database_catalog.cpp | 1 - src/catalog/postgres/pg_proc_impl.cpp | 1 - src/execution/ast/ast_clone.cpp | 23 ++++---- src/execution/ast/context.cpp | 28 ++++++++- src/execution/ast/type.cpp | 4 +- .../compiler/compilation_context.cpp | 4 +- src/execution/compiler/executable_query.cpp | 40 ++++++------- .../expression/function_translator.cpp | 4 +- src/execution/compiler/function_builder.cpp | 2 +- src/execution/compiler/udf/udf_codegen.cpp | 20 +++---- src/execution/sema/sema_builtin.cpp | 5 -- src/execution/sema/sema_expr.cpp | 59 +++---------------- src/execution/sema/sema_type.cpp | 3 +- src/execution/vm/bytecode_generator.cpp | 13 ++-- src/execution/vm/llvm_engine.cpp | 1 - src/include/catalog/postgres/pg_language.h | 5 +- src/include/common/strong_typedef.h | 1 - src/include/execution/ast/ast_clone.h | 13 ++-- src/include/execution/ast/type.h | 51 ++++++++++++++-- .../execution/compiler/executable_query.h | 25 ++++++-- .../execution/compiler/function_builder.h | 2 +- 22 files changed, 164 insertions(+), 143 deletions(-) diff --git a/src/binder/bind_node_visitor.cpp b/src/binder/bind_node_visitor.cpp index 0e0d0868d3..075f703661 100644 --- a/src/binder/bind_node_visitor.cpp +++ b/src/binder/bind_node_visitor.cpp @@ -67,13 +67,13 @@ BindNodeVisitor::~BindNodeVisitor() = default; std::unordered_map> BindNodeVisitor::BindAndGetUDFParams( common::ManagedPointer parse_result, common::ManagedPointer udf_ast_context) { - // TODO(Kyle): Revisit this. NOISEPAGE_ASSERT(parse_result != nullptr, "We shouldn't be trying to bind something without a ParseResult."); sherpa_ = std::make_unique(parse_result, nullptr, nullptr); NOISEPAGE_ASSERT(sherpa_->GetParseResult()->GetStatements().size() == 1, "Binder can only bind one at a time."); udf_ast_context_ = udf_ast_context; sherpa_->GetParseResult()->GetStatement(0)->Accept( common::ManagedPointer(this).CastManagedPointerTo()); + // TODO(Kyle): This is strange, why are we returning this member by value? return udf_params_; } diff --git a/src/catalog/database_catalog.cpp b/src/catalog/database_catalog.cpp index 8dfcf05205..50435f5763 100644 --- a/src/catalog/database_catalog.cpp +++ b/src/catalog/database_catalog.cpp @@ -442,7 +442,6 @@ proc_oid_t DatabaseCatalog::CreateProcedure(common::ManagedPointer { public: - explicit AstCloneImpl(AstNode *root, AstNodeFactory *factory, Context *old_context, Context *new_context, - std::string prefix) - : root_(root), factory_{factory}, old_context_{old_context}, new_context_{new_context}, prefix_{prefix} {} + explicit AstCloneImpl(AstNode *root, AstNodeFactory *factory, Context *old_context, Context *new_context) + : root_(root), factory_{factory}, old_context_{old_context}, new_context_{new_context} {} AstNode *Run() { return Visit(root_); } @@ -35,14 +34,16 @@ class AstCloneImpl : public AstVisitor { } private: + // The root of the AST to clone. AstNode *root_; + // The AST node factory used to allocate new nodes. AstNodeFactory *factory_; - + // The AST context of the source AST. Context *old_context_; + // The AST context of the destination AST. Context *new_context_; - std::string prefix_; - llvm::DenseMap allocated_strings_; + // llvm::DenseMap allocated_strings_; }; AstNode *AstCloneImpl::VisitFile(File *node) { @@ -155,10 +156,7 @@ AstNode *AstCloneImpl::VisitIdentifierExpr(IdentifierExpr *node) { return factory_->NewIdentifierExpr(node->Position(), CloneIdentifier(node->Name())); } -AstNode *AstCloneImpl::VisitImplicitCastExpr(ImplicitCastExpr *node) { - // TODO(Kyle): The type might have to be cloned - return Visit(node->Input()); -} +AstNode *AstCloneImpl::VisitImplicitCastExpr(ImplicitCastExpr *node) { return Visit(node->Input()); } AstNode *AstCloneImpl::VisitIndexExpr(IndexExpr *node) { return factory_->NewIndexExpr(node->Position(), reinterpret_cast(Visit(node->Object())), @@ -253,9 +251,8 @@ AstNode *AstCloneImpl::VisitLambdaTypeRepr(LambdaTypeRepr *node) { return factory_->NewLambdaType(node->Position(), reinterpret_cast(Visit(node->FunctionType()))); } -AstNode *AstClone::Clone(AstNode *node, AstNodeFactory *factory, std::string prefix, Context *old_context, - Context *new_context) { - AstCloneImpl cloner(node, factory, old_context, new_context, prefix); +AstNode *AstClone::Clone(AstNode *node, AstNodeFactory *factory, Context *old_context, Context *new_context) { + AstCloneImpl cloner{node, factory, old_context, new_context}; return cloner.Run(); } diff --git a/src/execution/ast/context.cpp b/src/execution/ast/context.cpp index 86dab369db..5633f8b0d6 100644 --- a/src/execution/ast/context.cpp +++ b/src/execution/ast/context.cpp @@ -396,7 +396,7 @@ StructType *StructType::Get(util::RegionVector &&fields) { } // static -FunctionType *FunctionType::Get(util::RegionVector &¶ms, Type *ret, bool is_lambda = false) { +FunctionType *FunctionType::Get(util::RegionVector &¶ms, Type *ret) { Context *ctx = ret->GetContext(); const FunctionTypeKeyInfo::KeyTy key(ret, params); @@ -410,7 +410,31 @@ FunctionType *FunctionType::Get(util::RegionVector &¶ms, Type *ret, b if (inserted) { // The function type was not in the cache, create the type now and insert it // into the cache - func_type = new (ctx->GetRegion()) FunctionType(std::move(params), ret, is_lambda); + func_type = new (ctx->GetRegion()) FunctionType(std::move(params), ret, false); + *iter = func_type; + } else { + func_type = *iter; + } + + return func_type; +} + +// static +FunctionType *FunctionType::GetLambda(util::RegionVector &¶ms, Type *ret) { + Context *ctx = ret->GetContext(); + + const FunctionTypeKeyInfo::KeyTy key(ret, params); + + auto insert_res = ctx->Impl()->func_types_.insert_as(nullptr, key); + auto iter = insert_res.first; + auto inserted = insert_res.second; + + FunctionType *func_type = nullptr; + + if (inserted) { + // The function type was not in the cache, create the type now and insert it + // into the cache + func_type = new (ctx->GetRegion()) FunctionType(std::move(params), ret, true); *iter = func_type; } else { func_type = *iter; diff --git a/src/execution/ast/type.cpp b/src/execution/ast/type.cpp index 4b88f11a88..912c0ebdf5 100644 --- a/src/execution/ast/type.cpp +++ b/src/execution/ast/type.cpp @@ -97,7 +97,7 @@ bool FunctionType::IsEqual(const FunctionType *other) { return false; } - for (size_t i = 0; i < params_.size(); i++) { + for (auto i = 0UL; i < params_.size(); i++) { if (params_[i].type_ != other->params_[i].type_) { return false; } @@ -107,7 +107,7 @@ bool FunctionType::IsEqual(const FunctionType *other) { } void FunctionType::RegisterCapture() { - NOISEPAGE_ASSERT(captures_ != nullptr, "no capture given?"); + NOISEPAGE_ASSERT(captures_ != nullptr, "No capture given"); params_.emplace_back(GetContext()->GetIdentifier("captures"), captures_); } diff --git a/src/execution/compiler/compilation_context.cpp b/src/execution/compiler/compilation_context.cpp index d6d57e36f7..81186ee0da 100644 --- a/src/execution/compiler/compilation_context.cpp +++ b/src/execution/compiler/compilation_context.cpp @@ -213,10 +213,8 @@ std::unique_ptr CompilationContext::Compile(const planner::Abst query->SetQueryText(query_text); // Generate the plan for the query - CompilationContext ctx(query.get(), accessor, mode, exec_settings, output_callback); + CompilationContext ctx{query.get(), accessor, mode, exec_settings, output_callback}; ctx.GeneratePlan(plan); - - // TODO(Kyle): hacking query->SetQueryStateType(ctx.query_state_.GetType()); // Done diff --git a/src/execution/compiler/executable_query.cpp b/src/execution/compiler/executable_query.cpp index 569a0bb190..40e73ec3fb 100644 --- a/src/execution/compiler/executable_query.cpp +++ b/src/execution/compiler/executable_query.cpp @@ -23,9 +23,16 @@ namespace noisepage::execution::compiler { // //===----------------------------------------------------------------------===// -ExecutableQuery::Fragment::Fragment(std::vector &&functions, std::vector &&teardown_fn, +ExecutableQuery::Fragment::Fragment(std::vector &&functions, std::vector &&teardown_fns, + std::unique_ptr module) + : functions_{std::move(functions)}, teardown_fns_{std::move(teardown_fns)}, module_{std::move(module)} {} + +ExecutableQuery::Fragment::Fragment(std::vector &&functions, std::vector &&teardown_fns, std::unique_ptr module, ast::File *file) - : functions_(std::move(functions)), teardown_fn_(std::move(teardown_fn)), module_(std::move(module)), file_(file) {} + : functions_{std::move(functions)}, + teardown_fns_{std::move(teardown_fns)}, + module_{std::move(module)}, + file_{file} {} ExecutableQuery::Fragment::~Fragment() = default; @@ -45,7 +52,7 @@ void ExecutableQuery::Fragment::Run(byte query_state[], vm::ExecutionMode mode) try { func(query_state); } catch (const AbortException &e) { - for (const auto &teardown_name : teardown_fn_) { + for (const auto &teardown_name : teardown_fns_) { if (!module_->GetFunction(teardown_name, mode, &func)) { throw EXECUTION_EXCEPTION(fmt::format("Could not find teardown function '{}' in query fragment.", func_name), common::ErrorCode::ERRCODE_INTERNAL_ERROR); @@ -89,11 +96,9 @@ ExecutableQuery::ExecutableQuery(const planner::AbstractPlanNode &plan, const ex query_state_size_{0}, pipeline_operating_units_{nullptr}, query_id_{query_identifier++} { - if (ast_context_ == nullptr) { + owns_ast_context_ = (ast_context_ == nullptr); + if (owns_ast_context_) { ast_context_ = new ast::Context(context_region_.get(), errors_.get()); - owned = true; - } else { - owned = false; } } @@ -106,13 +111,11 @@ ExecutableQuery::ExecutableQuery(const std::string &contents, exec_settings_{exec_settings}, context_region_{std::make_unique("context_region")}, errors_region_{std::make_unique("error_region")}, - errors_{std::make_unique(errors_region_.get())} { - if (context) { - ast_context_ = context; - owned = false; - } else { + errors_{std::make_unique(errors_region_.get())}, + ast_context_{context} { + owns_ast_context_ = (ast_context_ == nullptr); + if (owns_ast_context_) { ast_context_ = new ast::Context(context_region_.get(), errors_.get()); - owned = true; } // Let's scan the source std::string source; @@ -132,13 +135,11 @@ ExecutableQuery::ExecutableQuery(const std::string &contents, auto module = compiler::Compiler::RunCompilationSimple(input); std::vector functions{"main"}; - std::vector teardown_functions; + std::vector teardown_functions{}; - // TODO(Kyle): bad API - auto fragment = - std::make_unique(std::move(functions), std::move(teardown_functions), std::move(module), nullptr); + auto fragment = std::make_unique(std::move(functions), std::move(teardown_functions), std::move(module)); - std::vector> fragments; + std::vector> fragments{}; fragments.emplace_back(std::move(fragment)); Setup(std::move(fragments), query_state_size, nullptr); @@ -151,8 +152,7 @@ ExecutableQuery::ExecutableQuery(const std::string &contents, // Needed because we forward-declare classes used as template types to std::unique_ptr<> ExecutableQuery::~ExecutableQuery() { - // TODO(Kyle): This is a bad ownership model, revisit - if (owned) { + if (owns_ast_context_) { delete ast_context_; } } diff --git a/src/execution/compiler/expression/function_translator.cpp b/src/execution/compiler/expression/function_translator.cpp index c2cc4d172d..ec1fb819f4 100644 --- a/src/execution/compiler/expression/function_translator.cpp +++ b/src/execution/compiler/expression/function_translator.cpp @@ -56,7 +56,7 @@ void FunctionTranslator::DefineHelperFunctions(util::RegionVector( - ast::AstClone::Clone(func_context->GetFile(), GetCodeGen()->GetAstContext()->GetNodeFactory(), "", nullptr, + ast::AstClone::Clone(func_context->GetFile(), GetCodeGen()->GetAstContext()->GetNodeFactory(), nullptr, GetCodeGen()->GetAstContext().Get())); auto udf_decls = file->Declarations(); main_fn_ = udf_decls.back()->Name(); @@ -77,7 +77,7 @@ void FunctionTranslator::DefineHelperStructs(util::RegionVector( - ast::AstClone::Clone(func_context->GetFile(), GetCodeGen()->GetAstContext()->GetNodeFactory(), "", nullptr, + ast::AstClone::Clone(func_context->GetFile(), GetCodeGen()->GetAstContext()->GetNodeFactory(), nullptr, GetCodeGen()->GetAstContext().Get())); auto udf_decls = file->Declarations(); size_t num_added = 0; diff --git a/src/execution/compiler/function_builder.cpp b/src/execution/compiler/function_builder.cpp index 5413dbbbe5..51edad919a 100644 --- a/src/execution/compiler/function_builder.cpp +++ b/src/execution/compiler/function_builder.cpp @@ -25,7 +25,7 @@ FunctionBuilder::FunctionBuilder(CodeGen *codegen, util::RegionVectorMakeExpr(params_[param_idx]->Name()); } diff --git a/src/execution/compiler/udf/udf_codegen.cpp b/src/execution/compiler/udf/udf_codegen.cpp index d030cb7f1e..97f0f1827e 100644 --- a/src/execution/compiler/udf/udf_codegen.cpp +++ b/src/execution/compiler/udf/udf_codegen.cpp @@ -27,8 +27,6 @@ #include "planner/plannodes/abstract_plan_node.h" -// TODO(Kyle): Documentation. - namespace noisepage { namespace execution { namespace compiler { @@ -42,7 +40,7 @@ UDFCodegen::UDFCodegen(catalog::CatalogAccessor *accessor, FunctionBuilder *fb, codegen_{codegen}, aux_decls_(codegen->GetAstContext()->GetRegion()), db_oid_{db_oid} { - for (size_t i = 0; fb->GetParameterByPosition(i) != nullptr; i++) { + for (auto i = 0UL; fb->GetParameterByPosition(i) != nullptr; ++i) { auto param = fb->GetParameterByPosition(i); const auto &name = param->As()->Name(); str_to_ident_.emplace(name.GetString(), name); @@ -53,9 +51,13 @@ const char *UDFCodegen::GetReturnParamString() { return "return_val"; } void UDFCodegen::GenerateUDF(ast::udf::AbstractAST *ast) { ast->Accept(this); } -void UDFCodegen::Visit(ast::udf::AbstractAST *ast) { UNREACHABLE("Not implemented"); } +void UDFCodegen::Visit(ast::udf::AbstractAST *ast) { + throw NOT_IMPLEMENTED_EXCEPTION("UDFCodegen::Visit(AbstractAST*)"); +} -void UDFCodegen::Visit(ast::udf::DynamicSQLStmtAST *ast) { UNREACHABLE("Not implemented"); } +void UDFCodegen::Visit(ast::udf::DynamicSQLStmtAST *ast) { + throw NOT_IMPLEMENTED_EXCEPTION("UDFCodegen::Visit(DynamicSQLStmtAST*)"); +} catalog::type_oid_t UDFCodegen::GetCatalogTypeOidFromSQLType(execution::ast::BuiltinType::Kind type) { switch (type) { @@ -111,7 +113,7 @@ void UDFCodegen::Visit(ast::udf::CallExprAST *ast) { ident_expr = it->second; } else { auto file = reinterpret_cast( - execution::ast::AstClone::Clone(context->GetFile(), codegen_->GetAstContext()->GetNodeFactory(), "", + execution::ast::AstClone::Clone(context->GetFile(), codegen_->GetAstContext()->GetNodeFactory(), context->GetASTContext(), codegen_->GetAstContext().Get())); for (auto decl : file->Declarations()) { aux_decls_.push_back(decl); @@ -322,14 +324,11 @@ void UDFCodegen::Visit(ast::udf::SeqStmtAST *ast) { void UDFCodegen::Visit(ast::udf::WhileStmtAST *ast) { ast->cond_expr->Accept(this); auto cond = dst_; - // cond = codegen_->Compare(execution::parsing::Token::Type::EQUAL_EQUAL, cond, ) - // cond = codegen_->CallBuiltin(execution::ast::Builtin::SqlToBool, {cond}); Loop loop(fb_, cond); ast->body_stmt->Accept(this); loop.EndLoop(); } -// TODO(Kyle): Implement void UDFCodegen::Visit(ast::udf::ForStmtAST *ast) { // Once we encounter a For-statement we know we need an execution context needs_exec_ctx_ = true; @@ -419,7 +418,7 @@ void UDFCodegen::Visit(ast::udf::ForStmtAST *ast) { std::sort(sorted_vec.begin(), sorted_vec.end(), [](auto x, auto y) { return x->second < y->second; }); for (auto entry : sorted_vec) { - // TODO(Kyle): order these + // TODO(Kyle): Order these type::TypeId type = type::TypeId::INVALID; udf_ast_context_->GetVariableType(entry->first, &type); execution::ast::Builtin builtin{}; @@ -479,7 +478,6 @@ void UDFCodegen::Visit(ast::udf::RetStmtAST *ast) { fb_->Append(codegen_->Return(ret_expr)); } -// TODO(Kyle): Implement void UDFCodegen::Visit(ast::udf::SQLStmtAST *ast) { // As soon as we encounter an embedded SQL statement, // we know we need an execution context diff --git a/src/execution/sema/sema_builtin.cpp b/src/execution/sema/sema_builtin.cpp index 4f0e2013a1..1c9c0d524b 100644 --- a/src/execution/sema/sema_builtin.cpp +++ b/src/execution/sema/sema_builtin.cpp @@ -2841,11 +2841,6 @@ void Sema::CheckBuiltinAbortCall(ast::CallExpr *call) { } void Sema::CheckBuiltinParamCall(ast::CallExpr *call, ast::Builtin builtin) { - // TODO(Kyle): Revisit. - // if (!CheckArgCount(call, 1)) { - // return; - // } - // first argument is an exec ctx auto exec_ctx_kind = ast::BuiltinType::ExecutionContext; if (!IsPointerToSpecificBuiltin(call->Arguments()[0]->GetType(), exec_ctx_kind)) { diff --git a/src/execution/sema/sema_expr.cpp b/src/execution/sema/sema_expr.cpp index b03dcab0d2..ab0b347088 100644 --- a/src/execution/sema/sema_expr.cpp +++ b/src/execution/sema/sema_expr.cpp @@ -167,49 +167,9 @@ void Sema::VisitCallExpr(ast::CallExpr *node) { node->SetType(func_type->GetReturnType()); } -// TODO(Kyle): Implement this void Sema::VisitLambdaExpr(ast::LambdaExpr *node) { - // make struct type - // node->SetType(Resolve(node->GetFunctionLitExpr()->TypeRepr())); - // const auto &locals = GetCurrentScope()->GetLocals(); auto factory = GetContext()->GetNodeFactory(); util::RegionVector fields(GetContext()->GetRegion()); - // std::unordered_set used_idents; - // TODO support more than just assignment statements - // for(auto s : node->GetFunctionLitExpr()->Body()->Statements()){ - // if(s->IsAssignmentStmt()) { - // auto expr = s->As()->Destination()->As(); - // used_idents.insert(expr->Name()); - // auto s_expr = s->As()->Source()->SafeAs(); - // if(s_expr != nullptr){ - // used_idents.insert(s_expr->Name()); - // } - // } - // } - // for(auto local : used_idents){ - // auto name = local; - // auto iter = std::find_if(locals.begin(), locals.end(), [=](auto p){ return p.first == name; }); - // if(iter == locals.end()){ - // continue; - // } - // auto type = iter->second; - // ast::Expr *type_repr = nullptr; - // if(type->IsBuiltinType()) { - // type_repr = factory->NewPointerType(SourcePosition(), - // factory->NewIdentifierExpr(SourcePosition(), - // GetContext()->GetIdentifier(ast::BuiltinType::Get(GetContext(), - // type->As()->GetKind()) - // ->GetTplName()))); - // }else{ - // if(type->IsLambdaType()){ - // continue; - // } - // NOISEPAGE_ASSERT(false, "UNSUPPORTED CAPTURED TYPE"); - // } - // type_repr->SetType(type->PointerTo()); - // ast::FieldDecl *field = factory->NewFieldDecl(SourcePosition(), name, type_repr); - // fields.push_back(field); - // } for (auto expr : node->GetCaptureIdents()) { auto ident = expr->As(); Resolve(ident); @@ -244,22 +204,21 @@ void Sema::VisitLambdaExpr(ast::LambdaExpr *node) { factory->NewPointerType(SourcePosition(), node->GetFunctionLitExpr()->TypeRepr()))); ast::StructTypeRepr *struct_type_repr = factory->NewStructType(SourcePosition(), std::move(fields)); - // TODO(tanujnay112) Find a better name + // TODO(Kyle): Find a better name for this identifier ast::StructDecl *struct_decl = factory->NewStructDecl( SourcePosition(), GetContext()->GetIdentifier("lambda" + std::to_string(node->Position().line_)), struct_type_repr); VisitStructDecl(struct_decl); node->SetCaptureStructType(Resolve(struct_type_repr)); node->SetType(ast::LambdaType::Get(Resolve(node->GetFunctionLitExpr()->TypeRepr())->As())); - // GetCurrentScope()->Declare(struct_decl->Name(), node->capture_type_); - - // TODO(Kyle): Why do we need to modify internals? - // auto type = Resolve(node->GetFunctionLitExpr()->TypeRepr()); - // auto fn_type = type->As(); - // fn_type->GetParams().emplace_back(GetContext()->GetIdentifier("captures"), - // GetBuiltinType(ast::BuiltinType::Kind::Int32)->PointerTo()); - // fn_type->is_lambda_ = true; - // fn_type->captures_ = node->GetCaptureStructType()->As(); + + // TODO(Kyle): Why are we performing so much mutation in semantic analysis? + auto type = Resolve(node->GetFunctionLitExpr()->TypeRepr()); + auto fn_type = type->As(); + fn_type->GetParams().emplace_back(GetContext()->GetIdentifier("captures"), + GetBuiltinType(ast::BuiltinType::Kind::Int32)->PointerTo()); + fn_type->SetIsLambda(true); + fn_type->SetCapturesType(node->GetCaptureStructType()->As()); VisitFunctionLitExpr(node->GetFunctionLitExpr()); } diff --git a/src/execution/sema/sema_type.cpp b/src/execution/sema/sema_type.cpp index 9e744d2439..e467ce31a6 100644 --- a/src/execution/sema/sema_type.cpp +++ b/src/execution/sema/sema_type.cpp @@ -51,8 +51,7 @@ void Sema::VisitFunctionTypeRepr(ast::FunctionTypeRepr *node) { } // Create type - // TODO(Kyle): this is a bad API - ast::FunctionType *func_type = ast::FunctionType::Get(std::move(param_types), ret, false); + ast::FunctionType *func_type = ast::FunctionType::Get(std::move(param_types), ret); node->SetType(func_type); } diff --git a/src/execution/vm/bytecode_generator.cpp b/src/execution/vm/bytecode_generator.cpp index 89fdebeff5..8020a7d8b3 100644 --- a/src/execution/vm/bytecode_generator.cpp +++ b/src/execution/vm/bytecode_generator.cpp @@ -135,7 +135,6 @@ class BytecodeGenerator::BytecodePositionScope { // Bytecode Generator begins // --------------------------------------------------------- -// TODO(Kyle): reserve here on functions? BytecodeGenerator::BytecodeGenerator() noexcept : emitter_(&code_) {} void BytecodeGenerator::VisitIfStmt(ast::IfStmt *node) { @@ -300,20 +299,21 @@ void BytecodeGenerator::VisitIdentifierExpr(ast::IdentifierExpr *node) { if (local.IsInvalid()) { NOISEPAGE_ASSERT(GetCurrentFunction()->is_lambda_, "Not a lambda and variable not found"); - - // TODO(Kyle): modularize this fetch of capture struct auto params = GetCurrentFunction()->func_type_->GetParams(); auto captures = GetCurrentFunction()->func_type_->GetCapturesType(); for (auto field : captures->GetFieldsWithoutPadding()) { - // TODO(Kyle): cache these if (field.name_.GetString() == local_name) { auto captures_local = GetCurrentFunction()->LookupLocal("captures"); + auto local_ptr = GetCurrentFunction()->NewLocal(field.type_->PointerTo()); GetEmitter()->EmitLea(local_ptr, captures_local.ValueOf(), captures->GetOffsetOfFieldByName(field.name_)); + auto local_ptr_2 = GetCurrentFunction()->NewLocal(field.type_, local_name + "ptr"); GetEmitter()->EmitDerefN(local_ptr_2, local_ptr.ValueOf(), field.type_->GetSize()); + local = local_ptr_2; suffix = "ptr"; + if (GetExecutionResult()->IsRValue()) { local = GetCurrentFunction()->NewLocal(field.type_->GetPointeeType(), ""); GetEmitter()->EmitDerefN(local, local_ptr_2.ValueOf(), field.type_->GetPointeeType()->GetSize()); @@ -328,9 +328,8 @@ void BytecodeGenerator::VisitIdentifierExpr(ast::IdentifierExpr *node) { NOISEPAGE_ASSERT(!local.IsInvalid(), "Local not found"); if (GetExecutionResult()->IsLValue()) { - // TODO(Kyle): crappy names - auto *local_info_2 = GetCurrentFunction()->LookupLocalInfoByOffset(local.GetOffset()); - if (local_info_2->GetType()->IsPointerType() && local_info_2->GetType()->GetPointeeType()->IsSqlValueType()) { + auto *local_info = GetCurrentFunction()->LookupLocalInfoByOffset(local.GetOffset()); + if (local_info->GetType()->IsPointerType() && local_info->GetType()->GetPointeeType()->IsSqlValueType()) { GetExecutionResult()->SetDestination(local.ValueOf()); } else { GetExecutionResult()->SetDestination(local); diff --git a/src/execution/vm/llvm_engine.cpp b/src/execution/vm/llvm_engine.cpp index 7b3bcefd29..aeb9565849 100644 --- a/src/execution/vm/llvm_engine.cpp +++ b/src/execution/vm/llvm_engine.cpp @@ -287,7 +287,6 @@ llvm::FunctionType *LLVMEngine::TypeMap::GetLLVMFunctionType(const ast::Function // for (const auto ¶m_info : func_type->GetParams()) { - // TODO(Kyle): make this read from bytecode stuff instead to avoid this if (param_info.type_->IsSqlValueType()) { param_types.push_back(GetLLVMType(param_info.type_->PointerTo())); } else { diff --git a/src/include/catalog/postgres/pg_language.h b/src/include/catalog/postgres/pg_language.h index ce0875f920..14cce1adb8 100644 --- a/src/include/catalog/postgres/pg_language.h +++ b/src/include/catalog/postgres/pg_language.h @@ -21,7 +21,9 @@ class PgProcImpl; /** The OIDs used by the NoisePage version of pg_language. */ class PgLanguage { private: - // TODO(Kyle): How do we want to expose these constants? + // TODO(Kyle): Should we come up with a better way of exposting + // these constants rather than simply adding friends for each + // class that needs to access them? This is not scalable. friend class storage::RecoveryManager; friend class execution::sql::DDLExecutors; @@ -45,7 +47,6 @@ class PgLanguage { static constexpr CatalogColumnDef LANISPL{col_oid_t{3}}; // BOOLEAN (skey) static constexpr CatalogColumnDef LANPLTRUSTED{col_oid_t{4}}; // BOOLEAN (skey) - // TODO(Kyle): Make these foreign keys when we implement pg_proc static constexpr CatalogColumnDef LANPLCALLFOID{ col_oid_t{5}}; // INTEGER (skey) (fkey: pg_proc) static constexpr CatalogColumnDef LANINLINE{col_oid_t{6}}; // INTEGER (skey) (fkey: pg_proc) diff --git a/src/include/common/strong_typedef.h b/src/include/common/strong_typedef.h index 50d831393e..64f93cab51 100644 --- a/src/include/common/strong_typedef.h +++ b/src/include/common/strong_typedef.h @@ -102,7 +102,6 @@ class StrongTypeAlias { */ constexpr const IntType &UnderlyingValue() const { return val_; } - // TODO(Kyle): perhaps remove ability to static_cast to underlying value altogether. /** * * @return the underlying value diff --git a/src/include/execution/ast/ast_clone.h b/src/include/execution/ast/ast_clone.h index 2d8b6396f0..2e345cf9e2 100644 --- a/src/include/execution/ast/ast_clone.h +++ b/src/include/execution/ast/ast_clone.h @@ -9,17 +9,18 @@ namespace noisepage::execution::ast { class AstNode; -/** - * TODO(Kyle): Document. - */ class AstClone { public: /** * Clones an ASTNode and its descendants. - * TODO(Kyle): Document. + * @param node The root of the AST to clone. + * @param factory The AstNodeFactory instance from which AST nodes are allocated. + * @param prefix The + * @param old_context + * @param new_context + * @return */ - static AstNode *Clone(AstNode *node, AstNodeFactory *factory, std::string prefix, Context *old_context, - Context *new_context); + static AstNode *Clone(AstNode *node, AstNodeFactory *factory, Context *old_context, Context *new_context); }; } // namespace noisepage::execution::ast diff --git a/src/include/execution/ast/type.h b/src/include/execution/ast/type.h index 03c80735fa..250dd2eb4a 100644 --- a/src/include/execution/ast/type.h +++ b/src/include/execution/ast/type.h @@ -645,25 +645,59 @@ class FunctionType : public Type { */ Type *GetReturnType() const { return ret_; } + /** + * Determine if this function is equivalent to `other`. + * @param other The other function of interest + * @return `true` if the functions are equivalent, `false` otherwise. + */ bool IsEqual(const FunctionType *other); + /** + * @return `true` if this function is a lambda, `false` otherwise. + */ bool IsLambda() const { return is_lambda_; } + /** + * Set the lambda disposition for this function. + * @param `true` if this function is a lambda, `false` otherwise. + */ + void SetIsLambda(bool is_lambda) { is_lambda_ = is_lambda; } + + /** + * Get the type of the lambda captures struct. + * @return The struct type for lambda captures. + */ ast::StructType *GetCapturesType() const { NOISEPAGE_ASSERT(is_lambda_, "Getting capture type from not lambda"); return captures_; } + /** + * Set the type of the lambda captures struct. + * @param captures The struct type for lambda captures. + */ + void SetCapturesType(ast::StructType *captures) { captures_ = captures; } + + /** + * Register lambda captures as a parameter to this function. + */ void RegisterCapture(); /** - * Create a function with parameters @em params and returning types of type @em ret. + * Create a function with parameters `params` and returning types of type `ret`. + * @param params The parameters to the function. + * @param ret The type of the object the function returns. + * @return The function type. + */ + static FunctionType *Get(util::RegionVector &¶ms, Type *ret); + + /** + * Create a lambda function with params `params` and returning types of type `ret`. * @param params The parameters to the function. * @param ret The type of the object the function returns. - * @param is_lambda `true` if this function is a lambda, `false` otherwise. * @return The function type. */ - static FunctionType *Get(util::RegionVector &¶ms, Type *ret, bool is_lambda); + static FunctionType *GetLambda(util::RegionVector &¶ms, Type *ret); /** * @param type type to compare with @@ -677,8 +711,8 @@ class FunctionType : public Type { private: util::RegionVector params_; Type *ret_; - const bool is_lambda_; - ast::StructType *captures_{}; + bool is_lambda_; + ast::StructType *captures_; }; /** @@ -720,12 +754,17 @@ class MapType : public Type { /** * Lambda type. - * TODO(Kyle): Document. */ class LambdaType : public Type { public: + /** + * @return The function type representation. + */ FunctionType *GetFunctionType() const { return fn_type_; } + /** + * @return A newly-constructed lambda type. + */ static LambdaType *Get(FunctionType *fn_type); static bool classof(const Type *type) { return type->GetTypeId() == TypeId::LambdaType; } // NOLINT diff --git a/src/include/execution/compiler/executable_query.h b/src/include/execution/compiler/executable_query.h index af5e6d8b09..316ed22b04 100644 --- a/src/include/execution/compiler/executable_query.h +++ b/src/include/execution/compiler/executable_query.h @@ -65,10 +65,22 @@ class ExecutableQuery { public: /** * Construct a fragment composed of the given functions from the given module. + * + * This constructor assumes that no file is present for the fragment. + * * @param functions The name of the functions to execute, in order. * @param teardown_fns The name of the teardown functions in the module, in order. * @param module The module that contains the functions. - * @param file TODO(Kyle): this + */ + Fragment(std::vector &&functions, std::vector &&teardown_fns, + std::unique_ptr module); + + /** + * Construct a fragment composed of the given functions from the given module. + * @param functions The name of the functions to execute, in order. + * @param teardown_fns The name of the teardown functions in the module, in order. + * @param module The module that contains the functions. + * @param file The file associated with the fragment */ Fragment(std::vector &&functions, std::vector &&teardown_fns, std::unique_ptr module, ast::File *file); @@ -101,11 +113,13 @@ class ExecutableQuery { ast::File *GetFile() { return file_; }; private: - // The functions that must be run (in the provided order) to execute this - // query fragment. + // The functions that must be run (in the provided order) + // to execute this query fragment. std::vector functions_; - std::vector teardown_fn_; + // The functions that must be run (in the provided order) + // to tear down this query fragment. + std::vector teardown_fns_; // The module. std::unique_ptr module_; @@ -209,7 +223,8 @@ class ExecutableQuery { // The AST context used to generate the TPL AST. ast::Context *ast_context_; - bool owned{true}; + // Denotes whether or not the ExecutableQuery owns the AST context. + bool owns_ast_context_; // The compiled query fragments that make up the query. std::vector> fragments_; diff --git a/src/include/execution/compiler/function_builder.h b/src/include/execution/compiler/function_builder.h index 448aa8563c..23568d5b3e 100644 --- a/src/include/execution/compiler/function_builder.h +++ b/src/include/execution/compiler/function_builder.h @@ -45,7 +45,7 @@ class FunctionBuilder { /** * @return A reference to a function parameter by its ordinal position. */ - ast::Expr *GetParameterByPosition(uint32_t param_idx); + ast::Expr *GetParameterByPosition(std::size_t param_idx); /** * Append a statement to the list of statements in this function. From 5ea15c7b4f4912226b9d867b063dd1e04346d6c6 Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Fri, 23 Apr 2021 22:21:51 -0400 Subject: [PATCH 033/139] wip, editor crashed --- src/execution/compiler/function_builder.cpp | 55 ++++++++++--------- src/execution/sema/sema_stmt.cpp | 2 +- src/include/execution/ast/ast.h | 21 +++++-- .../execution/ast/udf/udf_ast_context.h | 7 ++- .../execution/ast/udf/udf_ast_node_visitor.h | 2 - src/include/execution/compiler/codegen.h | 2 - .../execution/compiler/compilation_context.h | 4 +- .../compiler/expression/function_translator.h | 6 +- .../execution/compiler/function_builder.h | 21 +++---- 9 files changed, 65 insertions(+), 55 deletions(-) diff --git a/src/execution/compiler/function_builder.cpp b/src/execution/compiler/function_builder.cpp index 51edad919a..ef968322c5 100644 --- a/src/execution/compiler/function_builder.cpp +++ b/src/execution/compiler/function_builder.cpp @@ -7,21 +7,23 @@ namespace noisepage::execution::compiler { FunctionBuilder::FunctionBuilder(CodeGen *codegen, ast::Identifier name, util::RegionVector &¶ms, ast::Expr *ret_type) - : codegen_(codegen), - name_(name), - params_(std::move(params)), - ret_type_(ret_type), - start_(codegen->GetPosition()), - statements_(codegen->MakeEmptyBlock()), - is_lambda_(false) {} + : codegen_{codegen}, + name_{name}, + params_{std::move(params)}, + ret_type_{ret_type}, + start_{codegen->GetPosition()}, + statements_{codegen->MakeEmptyBlock()}, + is_lambda_{false}, + decl_{std::in_place_type, nullptr} {} FunctionBuilder::FunctionBuilder(CodeGen *codegen, util::RegionVector &¶ms, ast::Expr *ret_type) - : codegen_(codegen), - params_(std::move(params)), - ret_type_(ret_type), - start_(codegen->GetPosition()), - statements_(codegen->MakeEmptyBlock()), - is_lambda_(true) {} + : codegen_{codegen}, + params_{std::move(params)}, + ret_type_{ret_type}, + start_{codegen->GetPosition()}, + statements_{codegen->MakeEmptyBlock()}, + is_lambda_{true}, + decl_{std::in_place_type, nullptr} {} FunctionBuilder::~FunctionBuilder() { Finish(); } @@ -44,8 +46,11 @@ void FunctionBuilder::Append(ast::Expr *expr) { Append(codegen_->GetFactory()->N void FunctionBuilder::Append(ast::VariableDecl *decl) { Append(codegen_->GetFactory()->NewDeclStmt(decl)); } ast::FunctionDecl *FunctionBuilder::Finish(ast::Expr *ret) { - if (decl_.fn_decl_ != nullptr) { - return decl_.fn_decl_; + NOISEPAGE_ASSERT(!is_lambda_, "Attempt to call Finish() on a FunctionDecl that is a lambda"); + NOISEPAGE_ASSERT(std::holds_alternative(decl_), "Broken invariant"); + auto *declaration = std::get(decl_); + if (declaration != nullptr) { + return declaration; } NOISEPAGE_ASSERT(ret == nullptr || statements_->IsEmpty() || !statements_->GetLast()->IsReturnStmt(), @@ -66,17 +71,17 @@ ast::FunctionDecl *FunctionBuilder::Finish(ast::Expr *ret) { // Create the declaration. auto func_lit = codegen_->GetFactory()->NewFunctionLitExpr(func_type, statements_); - decl_.fn_decl_ = codegen_->GetFactory()->NewFunctionDecl(start_, name_, func_lit); - - // Done - return decl_.fn_decl_; + decl_ = codegen_->GetFactory()->NewFunctionDecl(start_, name_, func_lit); + return std::get(decl_); } noisepage::execution::ast::LambdaExpr *FunctionBuilder::FinishLambda(util::RegionVector &&captures, ast::Expr *ret) { - NOISEPAGE_ASSERT(is_lambda_, "Asking to finish a lambda function that's not actually a lambda function"); - if (decl_.lambda_expr_ != nullptr) { - return decl_.lambda_expr_; + NOISEPAGE_ASSERT(is_lambda_, "Attempt to call FinishLambda() on a FunctionDecl that is not a lambda"); + NOISEPAGE_ASSERT(std::holds_alternative(decl_), "Broken invariant"); + auto *declaration = std::get(decl_); + if (declaration != nullptr) { + return declaration; } NOISEPAGE_ASSERT(ret == nullptr || statements_->IsEmpty() || !statements_->GetLast()->IsReturnStmt(), @@ -94,10 +99,8 @@ noisepage::execution::ast::LambdaExpr *FunctionBuilder::FinishLambda(util::Regio // Create the declaration. auto func_lit = codegen_->GetFactory()->NewFunctionLitExpr(func_type, statements_); - decl_.lambda_expr_ = codegen_->GetFactory()->NewLambdaExpr(start_, func_lit, std::move(captures)); - - // Done - return decl_.lambda_expr_; + decl_ = codegen_->GetFactory()->NewLambdaExpr(start_, func_lit, std::move(captures)); + return std::get(decl_); } } // namespace noisepage::execution::compiler diff --git a/src/execution/sema/sema_stmt.cpp b/src/execution/sema/sema_stmt.cpp index d3a7ec607b..788207de09 100644 --- a/src/execution/sema/sema_stmt.cpp +++ b/src/execution/sema/sema_stmt.cpp @@ -66,7 +66,7 @@ void Sema::VisitForStmt(ast::ForStmt *node) { auto context = GetContext(); auto factory = context->GetNodeFactory(); auto args = util::RegionVector({node->Condition()}, context->GetRegion()); - node->SetCond(factory->NewBuiltinCallExpr( + node->SetCondition(factory->NewBuiltinCallExpr( factory->NewIdentifierExpr(node->Position(), GetContext()->GetBuiltinFunction(execution::ast::Builtin::SqlToBool)), std::move(args))); diff --git a/src/include/execution/ast/ast.h b/src/include/execution/ast/ast.h index 1486884199..34cc7de336 100644 --- a/src/include/execution/ast/ast.h +++ b/src/include/execution/ast/ast.h @@ -766,9 +766,12 @@ class ForStmt : public IterationStmt { friend class sema::Sema; /** - * TODO(Kyle): Why? + * Set the condition for the for-loop. */ - void SetCond(Expr *cond) { cond_ = cond; } + void SetCondition(Expr *cond) { + NOISEPAGE_ASSERT(cond != nullptr, "Cannot set null condition"); + cond_ = cond; + } private: Stmt *init_; @@ -870,6 +873,9 @@ class IfStmt : public Stmt { private: friend class sema::Sema; + /** + * Set the condition for the if-statement. + */ void SetCondition(Expr *cond) { NOISEPAGE_ASSERT(cond != nullptr, "Cannot set null condition"); cond_ = cond; @@ -1183,7 +1189,7 @@ class CallExpr : public Expr { uint32_t NumArgs() const { return static_cast(args_.size()); } /** - * TODO(Kyle): Document. + * Add an argument to the call. */ void PushArgument(Expr *expr) { args_.push_back(expr); } @@ -1909,12 +1915,19 @@ class MapTypeRepr : public Expr { /** * Lambda type. - * TODO(Kyle): Document. */ class LambdaTypeRepr : public Expr { public: + /** + * Constructor + * @param pos source position + * @param fn_type function type + */ LambdaTypeRepr(const SourcePosition &pos, Expr *fn_type) : Expr(Kind::LambdaTypeRepr, pos), fn_type_(fn_type) {} + /** + * @return The expression for the type. + */ Expr *FunctionType() const { return fn_type_; } static bool classof(const AstNode *node) { // NOLINT diff --git a/src/include/execution/ast/udf/udf_ast_context.h b/src/include/execution/ast/udf/udf_ast_context.h index def6df89d7..1fe01b2610 100644 --- a/src/include/execution/ast/udf/udf_ast_context.h +++ b/src/include/execution/ast/udf/udf_ast_context.h @@ -2,8 +2,6 @@ #include "type/type_id.h" -// TODO(Kyle): Documentation. - namespace noisepage { namespace execution { namespace ast { @@ -29,7 +27,7 @@ class UDFASTContext { void AddVariable(const std::string &name) { local_variables_.push_back(name); } const std::string &GetVariableAtIndex(const std::size_t index) { - NOISEPAGE_ASSERT(local_variables_.size() >= index, "Bad var"); + NOISEPAGE_ASSERT(local_variables_.size() >= index, "Bad variable"); // TODO(Kyle): Why did this originally have index - 1? return local_variables_.at(index); } @@ -43,8 +41,11 @@ class UDFASTContext { } private: + // The symbol table for the UDF. std::unordered_map symbol_table_; + // Collection of local variable names for the UDF. std::vector local_variables_; + // Collection of record types for the UDF. std::unordered_map>> record_types_; }; diff --git a/src/include/execution/ast/udf/udf_ast_node_visitor.h b/src/include/execution/ast/udf/udf_ast_node_visitor.h index 178a7ae65b..2ca9c6a17c 100644 --- a/src/include/execution/ast/udf/udf_ast_node_visitor.h +++ b/src/include/execution/ast/udf/udf_ast_node_visitor.h @@ -1,7 +1,5 @@ #pragma once -// TODO(Kyle): This whole file needs documentation. - namespace noisepage { namespace execution { namespace ast { diff --git a/src/include/execution/compiler/codegen.h b/src/include/execution/compiler/codegen.h index 8fdb10bb6c..6732793a76 100644 --- a/src/include/execution/compiler/codegen.h +++ b/src/include/execution/compiler/codegen.h @@ -416,8 +416,6 @@ class CodeGen { */ [[nodiscard]] ast::Expr *AccessStructMember(ast::Expr *object, ast::Identifier member); - // TODO(Kyle): These should be in a different section? - /** * Create a break statement. * @return The statement. diff --git a/src/include/execution/compiler/compilation_context.h b/src/include/execution/compiler/compilation_context.h index 761698d61d..8200a08664 100644 --- a/src/include/execution/compiler/compilation_context.h +++ b/src/include/execution/compiler/compilation_context.h @@ -50,8 +50,8 @@ class CompilationContext { * @param mode The compilation mode. * @param override_qid Optional indicating how to override the plan's query id * @param query_text The SQL query string (temporary) - * @param output_callback TODO(Kyle) - * @param context TODO(Kyle) + * @param output_callback The lambda utilized as the output callback for the query + * @param context The AST context for the query */ static std::unique_ptr Compile(const planner::AbstractPlanNode &plan, const exec::ExecutionSettings &exec_settings, diff --git a/src/include/execution/compiler/expression/function_translator.h b/src/include/execution/compiler/expression/function_translator.h index f892509b7c..f6e48b2099 100644 --- a/src/include/execution/compiler/expression/function_translator.h +++ b/src/include/execution/compiler/expression/function_translator.h @@ -34,12 +34,14 @@ class FunctionTranslator : public ExpressionTranslator { ast::Expr *DeriveValue(WorkContext *ctx, const ColumnValueProvider *provider) const override; /** - * TODO(Kyle): this. + * Define the helper functions for this function translator. + * @param decls The collection of helper function declarations */ void DefineHelperFunctions(util::RegionVector *decls) override; /** - * TODO(Kyle): this. + * Define the helper structs for this function translator. + * @param decls The collection of helper struct declarations */ void DefineHelperStructs(util::RegionVector *decls) override; diff --git a/src/include/execution/compiler/function_builder.h b/src/include/execution/compiler/function_builder.h index 23568d5b3e..5ae3f20513 100644 --- a/src/include/execution/compiler/function_builder.h +++ b/src/include/execution/compiler/function_builder.h @@ -2,6 +2,7 @@ #include #include +#include #include "execution/ast/identifier.h" #include "execution/compiler/ast_fwd.h" @@ -82,16 +83,16 @@ class FunctionBuilder { ast::Expr *ret = nullptr); /** - * @return The final constructed function; null if the builder hasn't been constructed through - * FunctionBuilder::Finish(). + * @return The final constructed function, or nullptr if the builder + * hasn't been constructed through FunctionBuilder::Finish(). */ - ast::FunctionDecl *GetConstructedFunction() const { return decl_.fn_decl_; } + ast::FunctionDecl *GetConstructedFunction() const { return std::get(decl_); } /** - * @return The final constructed lambda; null if the builder hasn't been constructed through - * FunctionBuilder::FinishLambda(). + * @return The final constructed lambda, or nullptr if the builder + * hasn't been constructed through FunctionBuilder::FinishLambda(). */ - ast::LambdaExpr *GetConstructedLambda() const { return decl_.lambda_expr_; } + ast::LambdaExpr *GetConstructedLambda() const { return std::get(decl_); } /** * @return The code generator instance. @@ -111,16 +112,10 @@ class FunctionBuilder { SourcePosition start_; // The list of generated statements making up the function. ast::BlockStmt *statements_; - // `true` if this function is a lambda, `false` otherwise. bool is_lambda_; - // The cached function declaration. Constructed once in Finish(). - // TODO(Kyle): This needs to be a variant... - union { - ast::FunctionDecl *fn_decl_{nullptr}; - ast::LambdaExpr *lambda_expr_; - } decl_; + std::variant decl_; }; } // namespace noisepage::execution::compiler From f997bb55eb9340b39e6507083ee1d8ee03235b5e Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Sat, 24 Apr 2021 08:45:08 -0400 Subject: [PATCH 034/139] all major todos addressed, now time to start running tests --- src/binder/bind_node_visitor.cpp | 4 +- src/execution/compiler/udf/udf_codegen.cpp | 10 +- src/execution/sql/ddl_executors.cpp | 74 ++++------ .../execution/ast/udf/udf_ast_context.h | 49 +++++-- .../compiler/operator/operator_translator.h | 16 +-- .../execution/compiler/udf/udf_codegen.h | 131 ++++++++++++++++-- .../execution/exec/execution_context.h | 21 ++- src/include/execution/sema/scope.h | 6 +- src/include/execution/vm/bytecode_generator.h | 1 - .../expression/column_value_expression.h | 4 +- .../expression/constant_value_expression.h | 8 +- src/include/parser/udf/udf_parser.h | 1 - .../expression/constant_value_expression.cpp | 4 +- src/parser/postgresparser.cpp | 1 - src/parser/udf/udf_parser.cpp | 4 +- 15 files changed, 230 insertions(+), 104 deletions(-) diff --git a/src/binder/bind_node_visitor.cpp b/src/binder/bind_node_visitor.cpp index 075f703661..fe543cefad 100644 --- a/src/binder/bind_node_visitor.cpp +++ b/src/binder/bind_node_visitor.cpp @@ -570,7 +570,7 @@ void BindNodeVisitor::Visit(common::ManagedPointerGetVariableType(expr->GetColumnName(), &the_type)) { expr->SetReturnValueType(the_type); - auto idx = 0; + std::size_t idx = 0; if (udf_params_.count(expr->GetColumnName()) == 0) { udf_params_[expr->GetColumnName()] = std::make_pair("", udf_params_.size()); idx = udf_params_.size() - 1; @@ -593,7 +593,7 @@ void BindNodeVisitor::Visit(common::ManagedPointerGetRecordType(expr->GetTableName()); auto it = std::find_if(fields.begin(), fields.end(), [=](auto p) { return p.first == expr->GetColumnName(); }); - auto idx = 0; + std::size_t idx = 0; if (it != fields.end()) { if (udf_params_.count(expr->GetColumnName()) == 0) { udf_params_[expr->GetColumnName()] = std::make_pair(expr->GetTableName(), udf_params_.size()); diff --git a/src/execution/compiler/udf/udf_codegen.cpp b/src/execution/compiler/udf/udf_codegen.cpp index 97f0f1827e..11b1a9c167 100644 --- a/src/execution/compiler/udf/udf_codegen.cpp +++ b/src/execution/compiler/udf/udf_codegen.cpp @@ -38,8 +38,9 @@ UDFCodegen::UDFCodegen(catalog::CatalogAccessor *accessor, FunctionBuilder *fb, fb_{fb}, udf_ast_context_{udf_ast_context}, codegen_{codegen}, + db_oid_{db_oid}, aux_decls_(codegen->GetAstContext()->GetRegion()), - db_oid_{db_oid} { + needs_exec_ctx_{false} { for (auto i = 0UL; fb->GetParameterByPosition(i) != nullptr; ++i) { auto param = fb->GetParameterByPosition(i); const auto &name = param->As()->Name(); @@ -47,6 +48,7 @@ UDFCodegen::UDFCodegen(catalog::CatalogAccessor *accessor, FunctionBuilder *fb, } } +// Static const char *UDFCodegen::GetReturnParamString() { return "return_val"; } void UDFCodegen::GenerateUDF(ast::udf::AbstractAST *ast) { ast->Accept(this); } @@ -75,11 +77,7 @@ catalog::type_oid_t UDFCodegen::GetCatalogTypeOidFromSQLType(execution::ast::Bui execution::ast::File *UDFCodegen::Finish() { auto fn = fb_->Finish(); - // util::RegionVector decls_reg_vec{decls->begin(), decls->end(), codegen.Region()}; - execution::util::RegionVector decls({fn}, codegen_->GetAstContext()->GetRegion()); - // for(auto decl : aux_decls_){ - // decls.push_back(decl); - // } + execution::util::RegionVector decls{{fn}, codegen_->GetAstContext()->GetRegion()}; decls.insert(decls.begin(), aux_decls_.begin(), aux_decls_.end()); auto file = codegen_->GetAstContext()->GetNodeFactory()->NewFile({0, 0}, std::move(decls)); return file; diff --git a/src/execution/sql/ddl_executors.cpp b/src/execution/sql/ddl_executors.cpp index d696d0879d..27ec09ee15 100644 --- a/src/execution/sql/ddl_executors.cpp +++ b/src/execution/sql/ddl_executors.cpp @@ -48,11 +48,11 @@ bool DDLExecutors::CreateFunctionExecutor(const common::ManagedPointer accessor) { // Request permission from the Catalog to see if this a valid namespace name NOISEPAGE_ASSERT(node->GetUDFLanguage() == parser::PLType::PL_PGSQL, "Unsupported language"); - NOISEPAGE_ASSERT(node->GetFunctionBody().size() >= 1, "Unsupported function body?"); + NOISEPAGE_ASSERT(node->GetFunctionBody().size() >= 1, "Unsupported function body contents"); // I don't like how we have to separate the two here - std::vector param_type_ids; - std::vector param_types; + std::vector param_type_ids{}; + std::vector param_types{}; for (auto t : node->GetFunctionParameterTypes()) { param_type_ids.push_back(parser::FuncParameter::DataTypeToTypeId(t)); param_types.push_back(accessor->GetTypeOidFromTypeId(parser::FuncParameter::DataTypeToTypeId(t))); @@ -66,10 +66,9 @@ bool DDLExecutors::CreateFunctionExecutor(const common::ManagedPointerGetDatabaseOid()); + parser::udf::PLpgSQLParser udf_parser{(common::ManagedPointer(&udf_ast_context)), accessor, node->GetDatabaseOid()}; std::unique_ptr ast{}; try { ast = udf_parser.ParsePLpgSQL(node->GetFunctionParameterNames(), std::move(param_type_ids), body, @@ -79,66 +78,53 @@ bool DDLExecutors::CreateFunctionExecutor(const common::ManagedPointerGetFunctionName()); - sema::ErrorReporter error_reporter(region); - auto ast_context = new ast::Context(region, &error_reporter); + sema::ErrorReporter error_reporter{region}; - compiler::CodeGen codegen(ast_context, accessor.Get()); + auto ast_context = std::make_unique(region, &error_reporter); + + compiler::CodeGen codegen{ast_context.get(), accessor.Get()}; util::RegionVector fn_params{codegen.GetAstContext()->GetRegion()}; - // auto ret_name = parser::udf::UDFCodegen::GetReturnParamString(); - // auto ret_type = parser::ReturnType::DataTypeToTypeId(node->GetReturnType()); - // fn_params.emplace_back(codegen.MakeField(ast::Identifier{ret_name}, - // codegen.PointerType(codegen.TplType(ret_type)))); fn_params.emplace_back( codegen.MakeField(codegen.MakeFreshIdentifier("executionCtx"), codegen.PointerType(codegen.BuiltinType(ast::BuiltinType::ExecutionContext)))); - for (size_t i = 0; i < node->GetFunctionParameterNames().size(); i++) { - auto name = node->GetFunctionParameterNames()[i]; - auto type = parser::ReturnType::DataTypeToTypeId(node->GetFunctionParameterTypes()[i]); - // auto name_alloc = reinterpret_cast(codegen.GetAstContext()->GetRegion()->Allocate(name.length()+1)); - // std::memcpy(name_alloc, name.c_str(), name.length() + 1); - fn_params.emplace_back(codegen.MakeField(ast_context->GetIdentifier(name), - // codegen.PointerType( - codegen.TplType(execution::sql::GetTypeId(type)) - // ) - )); + for (auto i = 0UL; i < node->GetFunctionParameterNames().size(); i++) { + const auto &name = node->GetFunctionParameterNames()[i]; + const auto &type = parser::ReturnType::DataTypeToTypeId(node->GetFunctionParameterTypes()[i]); + fn_params.emplace_back( + codegen.MakeField(ast_context->GetIdentifier(name), codegen.TplType(execution::sql::GetTypeId(type)))); } auto name = node->GetFunctionName(); - // char *name_alloc = reinterpret_cast(codegen.GetAstContext()->GetRegion()->Allocate(name.length() + 1)); - // std::memcpy(name_alloc, name.c_str(), name.length() + 1); - compiler::FunctionBuilder fb{ &codegen, codegen.MakeFreshIdentifier(name), std::move(fn_params), - // codegen.PointerType( - codegen.TplType(execution::sql::GetTypeId(parser::ReturnType::DataTypeToTypeId(node->GetReturnType()))) - // ) - }; + codegen.TplType(execution::sql::GetTypeId(parser::ReturnType::DataTypeToTypeId(node->GetReturnType())))}; + compiler::udf::UDFCodegen udf_codegen{accessor.Get(), &fb, &udf_ast_context, &codegen, node->GetDatabaseOid()}; udf_codegen.GenerateUDF(ast->body.get()); - auto fn = fb.Finish(); - //// util::RegionVector decls_reg_vec{decls->begin(), decls->end(), codegen.Region()}; - util::RegionVector decls({fn}, codegen.GetAstContext()->GetRegion()); - auto file = udf_codegen.Finish(); + auto *file = udf_codegen.Finish(); { - sema::Sema type_check(codegen.GetAstContext().Get()); + sema::Sema type_check{codegen.GetAstContext().Get()}; type_check.GetErrorReporter()->Reset(); - type_check.Run(file); - EXECUTION_LOG_ERROR("Errors: \n {}", type_check.GetErrorReporter()->SerializeErrors()); - execution::ast::AstPrettyPrint::Dump(std::cout, file); - // NOISEPAGE_ASSERT(!bad, "bad function"); + if (type_check.Run(file)) { + EXECUTION_LOG_ERROR("Errors: \n {}", type_check.GetErrorReporter()->SerializeErrors()); + execution::ast::AstPrettyPrint::Dump(std::cout, file); + return false; + } } - auto udf_context = new functions::FunctionContext( + auto udf_context = std::make_unique( node->GetFunctionName(), parser::ReturnType::DataTypeToTypeId(node->GetReturnType()), std::move(param_type_ids), - std::unique_ptr(region), std::unique_ptr(ast_context), file); - if (!accessor->SetFunctionContextPointer(proc_id, udf_context)) { - delete udf_context; + std::unique_ptr(region), std::move(ast_context), file); + if (!accessor->SetFunctionContextPointer(proc_id, udf_context.get())) { return false; } - accessor->GetTxn()->RegisterAbortAction([=]() { delete udf_context; }); + // TODO(Kyle): Not quite sure how abort actions work, but is + // the implication here that we leak in the event that we do + // not abort and the associated transaction completes? + accessor->GetTxn()->RegisterAbortAction([udf_context = udf_context.release()]() { delete udf_context; }); return true; } diff --git a/src/include/execution/ast/udf/udf_ast_context.h b/src/include/execution/ast/udf/udf_ast_context.h index 1fe01b2610..9c34b756a1 100644 --- a/src/include/execution/ast/udf/udf_ast_context.h +++ b/src/include/execution/ast/udf/udf_ast_context.h @@ -9,12 +9,24 @@ namespace udf { class UDFASTContext { public: - UDFASTContext() {} + UDFASTContext() = default; - void SetVariableType(const std::string &var, type::TypeId type) { symbol_table_[var] = type; } + /** + * Set the type of the variabel identifed by `name`. + * @param name The name of the variable + * @param type The type to which the variable should be set + */ + void SetVariableType(const std::string &name, type::TypeId type) { symbol_table_[name] = type; } - bool GetVariableType(const std::string &var, type::TypeId *type) { - auto it = symbol_table_.find(var); + /** + * Get the type of the variable identified by `name`. + * @param name The name of the variable + * @param type The out-parameter used to store the result + * @return `true` if the variable is present in the symbol + * table and the Get() succeeds, `false` otherwise + */ + bool GetVariableType(const std::string &name, type::TypeId *type) { + auto it = symbol_table_.find(name); if (it == symbol_table_.end()) { return false; } @@ -24,20 +36,39 @@ class UDFASTContext { return true; } + /** + * Add a new variable to the symbol table. + * @param name The name of the variable + */ void AddVariable(const std::string &name) { local_variables_.push_back(name); } - const std::string &GetVariableAtIndex(const std::size_t index) { + /** + * Get the local variable at index `index`. + * @param index The index of interest + * @return The name of the variable at the specified index + */ + const std::string &GetLocalVariableAtIndex(const std::size_t index) { NOISEPAGE_ASSERT(local_variables_.size() >= index, "Bad variable"); // TODO(Kyle): Why did this originally have index - 1? return local_variables_.at(index); } - void SetRecordType(std::string var, std::vector> &&elems) { - record_types_[var] = std::move(elems); + /** + * Get the record type for the specified variable. + * @param name The name of the variable + * @return The record + */ + const std::vector> &GetRecordType(const std::string &name) const { + return record_types_.find(name)->second; } - const std::vector> &GetRecordType(const std::string &var) { - return record_types_.find(var)->second; + /** + * Set the record type for the specified variable. + * @param name The name of the variable + * @param elems The record + */ + void SetRecordType(const std::string &name, std::vector> &&elems) { + record_types_[name] = std::move(elems); } private: diff --git a/src/include/execution/compiler/operator/operator_translator.h b/src/include/execution/compiler/operator/operator_translator.h index 18b8114088..f35ce91468 100644 --- a/src/include/execution/compiler/operator/operator_translator.h +++ b/src/include/execution/compiler/operator/operator_translator.h @@ -254,17 +254,6 @@ class OperatorTranslator : public ColumnValueProvider { /** @return The address of the current tuple slot, if any. */ virtual ast::Expr *GetSlotAddress() const { UNREACHABLE("This translator does not deal with tupleslots."); } - /** - * TODO(Kyle): This. - */ - virtual void RegisterNeedValue(const OperatorTranslator *requester, uint32_t child_idx, uint32_t attr_idx) { - UNREACHABLE("not implemented"); - } - - /** @return The pipeline this translator is a part of. */ - // TODO(Kyle): Why did we change visibility of this? Protected to public - Pipeline *GetPipeline() const { return pipeline_; } - protected: /** Get the code generator instance. */ CodeGen *GetCodeGen() const; @@ -281,6 +270,9 @@ class OperatorTranslator : public ColumnValueProvider { /** Get the memory pool pointer from the execution context stored in the query state. */ ast::Expr *GetMemoryPool() const; + /** @return The pipeline this translator is a part of. */ + Pipeline *GetPipeline() const { return pipeline_; } + /** The plan node for this translator as its concrete type. */ template const T &GetPlanAs() const { @@ -369,7 +361,7 @@ class OperatorTranslator : public ColumnValueProvider { const planner::AbstractPlanNode &plan_; // The compilation context. CompilationContext *compilation_context_; - // The pipeline the operator belongs to. + // The pipeline to which the operator belongs. Pipeline *pipeline_; /** The child operator translator. */ diff --git a/src/include/execution/compiler/udf/udf_codegen.h b/src/include/execution/compiler/udf/udf_codegen.h index c0f10ee5f2..20034a0d92 100644 --- a/src/include/execution/compiler/udf/udf_codegen.h +++ b/src/include/execution/compiler/udf/udf_codegen.h @@ -6,8 +6,6 @@ #include "execution/compiler/function_builder.h" #include "execution/functions/function_context.h" -// TODO(Kyle): Documentation. - namespace noisepage::catalog { class CatalogAccessor; } @@ -43,55 +41,164 @@ class ForStmtAST; namespace compiler { namespace udf { -// TODO(Kyle): Is distinguishing the standard codegen -// namespace stuff from the UDF stuff here going to be -// an issue (i.e. disambiguation)? - class UDFCodegen : ast::udf::ASTNodeVisitor { public: + /** + * Construct a new UDFCodegen instance. + * @param accessor The catalog accessor used in code generation + * @param fb The function builder instance used for the UDF + * @param udf_ast_context The AST context for the UDF + * @param codegen The codegen instance + * @param db_oid The OID for the relevant database + */ UDFCodegen(catalog::CatalogAccessor *accessor, FunctionBuilder *fb, ast::udf::UDFASTContext *udf_ast_context, CodeGen *codegen, catalog::db_oid_t db_oid); - ~UDFCodegen(){}; - catalog::type_oid_t GetCatalogTypeOidFromSQLType(execution::ast::BuiltinType::Kind type); + ~UDFCodegen() = default; - void GenerateUDF(ast::udf::AbstractAST *); + /** + * Generate a UDF from the given abstract syntax tree. + * @param ast The AST from which to generate the UDF + */ + void GenerateUDF(ast::udf::AbstractAST *ast); + /** + * Visit an AbstractAST node. + */ void Visit(ast::udf::AbstractAST *) override; + + /** + * Visit a FunctionAST node. + */ void Visit(ast::udf::FunctionAST *) override; + + /** + * Visit a StmtAST node. + */ void Visit(ast::udf::StmtAST *) override; + + /** + * Visit an ExprAST node. + */ void Visit(ast::udf::ExprAST *) override; + + /** + * Visit a ValueExprAST node. + */ void Visit(ast::udf::ValueExprAST *) override; + + /** + * Visit a VariableExprAST node. + */ void Visit(ast::udf::VariableExprAST *) override; + + /** + * Visit a BinaryExprAST node. + */ void Visit(ast::udf::BinaryExprAST *) override; + + /** + * Visit a CallExprAST node. + */ void Visit(ast::udf::CallExprAST *) override; + + /** + * Visit an IsNullExprAST node. + */ void Visit(ast::udf::IsNullExprAST *) override; + + /** + * Visit a SeqStmtAST node. + */ void Visit(ast::udf::SeqStmtAST *) override; + + /** + * Visit a DeclStmtNode node. + */ void Visit(ast::udf::DeclStmtAST *) override; + + /** + * Visit a IfStmtAST node. + */ void Visit(ast::udf::IfStmtAST *) override; + + /** + * Visit a WhileStmtAST node. + */ void Visit(ast::udf::WhileStmtAST *) override; + + /** + * Visit a RetStmtAST node. + */ void Visit(ast::udf::RetStmtAST *) override; + + /** + * Visit an AssignStmtAST node. + */ void Visit(ast::udf::AssignStmtAST *) override; + + /** + * Visit a SQLStmtAST node. + */ void Visit(ast::udf::SQLStmtAST *) override; + + /** + * Visit a DynamicSQLStmtAST node. + */ void Visit(ast::udf::DynamicSQLStmtAST *) override; + + /** + * Visit a ForStmtAST node. + */ void Visit(ast::udf::ForStmtAST *) override; + + /** + * Visit a MemberExprAST node. + */ void Visit(ast::udf::MemberExprAST *) override; + /** + * Complete UDF code generation. + * @return The result of code generation as a file + */ execution::ast::File *Finish(); + /** + * Return the string that represents the return value. + * @return The string that represents the return value + */ static const char *GetReturnParamString(); private: + /** + * Translate a SQL type to its corresponding catalog type. + * @param type The SQL type of interest + * @return The corresponding catalog type + */ + catalog::type_oid_t GetCatalogTypeOidFromSQLType(execution::ast::BuiltinType::Kind type); + + // The catalog access used during code generation catalog::CatalogAccessor *accessor_; + // The function builder used during code generation FunctionBuilder *fb_; + // The AST context for the UDF ast::udf::UDFASTContext *udf_ast_context_; + // The code generation instance CodeGen *codegen_; + // The OID of the relevant database + catalog::db_oid_t db_oid_; + // Auxiliary declarations + execution::util::RegionVector aux_decls_; + + // Flag indicating whether this UDF requires an execution context + bool needs_exec_ctx_; + + // The current type during code generation type::TypeId current_type_{type::TypeId::INVALID}; + // The destination expression execution::ast::Expr *dst_; + // Map from human-readable string identifier to internal identifier std::unordered_map str_to_ident_; - execution::util::RegionVector aux_decls_; - catalog::db_oid_t db_oid_; - bool needs_exec_ctx_{false}; }; } // namespace udf diff --git a/src/include/execution/exec/execution_context.h b/src/include/execution/exec/execution_context.h index 1b7921cb36..5d4a85617e 100644 --- a/src/include/execution/exec/execution_context.h +++ b/src/include/execution/exec/execution_context.h @@ -188,12 +188,20 @@ class EXPORT ExecutionContext { */ void InitializeParallelOUFeatureVector(selfdriving::ExecOUFeatureVector *ouvec, pipeline_id_t pipeline_id); - // TODO(Kyle): Document + revisit this. - + /** + * Initialize the UDF parameter stack. + */ void StartParams() { udf_param_stack_.push_back({}); } + /** + * Remove an element from the UDF parameter stack. + */ void PopParams() { udf_param_stack_.pop_back(); } + /** + * Add a parameter to the set of parameters at the top of the UDF parameter stack. + * @param val The parameter to be added + */ void AddParam(common::ManagedPointer val) { udf_param_stack_.back().push_back(val.CastManagedPointerTo()); } @@ -226,10 +234,11 @@ class EXPORT ExecutionContext { } /** + * Get the parameter at the specified index. * @param param_idx index of parameter to access * @return immutable parameter at provided index */ - common::ManagedPointer GetParam(uint32_t param_idx) const { + common::ManagedPointer GetParam(std::size_t param_idx) const { return udf_param_stack_.empty() ? (*params_)[param_idx] : udf_param_stack_.back()[param_idx]; } @@ -255,8 +264,8 @@ class EXPORT ExecutionContext { void AddRowsAffected(int64_t num_rows) { rows_affected_ += num_rows; } /** - * @return On the primary, returns the ID of the last txn sent. - * On a replica, returns the ID of the last txn applied. + * @return On the primary, returns the ID of the last txn sent. + * On a replica, returns the ID of the last txn applied. */ uint64_t ReplicationGetLastTransactionId() const; @@ -374,6 +383,8 @@ class EXPORT ExecutionContext { common::ManagedPointer replication_manager_; common::ManagedPointer recovery_manager_; + // The stack of UDF parameters; each element in the stack + // is itself a (possibly-incomplete) set of parameters std::vector>> udf_param_stack_; bool memory_use_override_ = false; diff --git a/src/include/execution/sema/scope.h b/src/include/execution/sema/scope.h index 3a3411d43a..b30a5659b9 100644 --- a/src/include/execution/sema/scope.h +++ b/src/include/execution/sema/scope.h @@ -68,12 +68,14 @@ class Scope { ast::Type *LookupLocal(ast::Identifier name) const; /** - * TODO(Kyle): Document. + * Get the kind of the scope. + * @return The kind */ Kind GetKind() const; /** - * TODO(Kyle): Document. + * Get the local variables for the scope. + * @return A collection of the scope's locals */ std::vector> GetLocals() const; diff --git a/src/include/execution/vm/bytecode_generator.h b/src/include/execution/vm/bytecode_generator.h index 4876144f38..8b24f3f674 100644 --- a/src/include/execution/vm/bytecode_generator.h +++ b/src/include/execution/vm/bytecode_generator.h @@ -226,7 +226,6 @@ class BytecodeGenerator final : public ast::AstVisitor { ExpressionResultScope *execution_result_{nullptr}; // The loop builder for the current loop. - // TODO(Kyle): seems messy. LoopBuilder *current_loop_{nullptr}; }; diff --git a/src/include/parser/expression/column_value_expression.h b/src/include/parser/expression/column_value_expression.h index 9baebae0d9..0fcc401489 100644 --- a/src/include/parser/expression/column_value_expression.h +++ b/src/include/parser/expression/column_value_expression.h @@ -111,13 +111,11 @@ class ColumnValueExpression : public AbstractExpression { /** @return column oid */ catalog::col_oid_t GetColumnOid() const { return column_oid_; } - // TODO(Kyle): Why are we narrowing here? - /** @return parameter index */ std::int32_t GetParamIdx() const { return param_idx_; } /** @brief set the parameter index */ - void SetParamIdx(std::uint32_t param_idx) { param_idx_ = static_cast(param_idx); } + void SetParamIdx(const std::size_t param_idx) { param_idx_ = static_cast(param_idx); } /** * Get Column Full Name [tbl].[col] diff --git a/src/include/parser/expression/constant_value_expression.h b/src/include/parser/expression/constant_value_expression.h index 5d2a8c1ddc..90a588b2b7 100644 --- a/src/include/parser/expression/constant_value_expression.h +++ b/src/include/parser/expression/constant_value_expression.h @@ -106,8 +106,8 @@ class ConstantValueExpression : public AbstractExpression { } } - // TODO(Kyle): Is this safe? common::ManagedPointer GetVal() const { + NOISEPAGE_ASSERT(std::holds_alternative(value_), "GetVal() bad variant access"); return common::ManagedPointer(&std::get(value_)); } @@ -241,7 +241,8 @@ class ConstantValueExpression : public AbstractExpression { T Peek() const; /** - * TODO(Kyle): Document. + * Peek at the underlying value for the constant value expression. + * @return The underlying value pointer as a SQL value */ const execution::sql::Val *PeekPtr() const; @@ -266,10 +267,13 @@ class ConstantValueExpression : public AbstractExpression { private: friend class binder::BindNodeVisitor; /* value_ may be modified, e.g., when parsing dates. */ void Validate() const; + + // The undelrying constant value std::variant value_{execution::sql::Val(true)}; + // Buffer for inlined string values std::unique_ptr buffer_ = nullptr; }; diff --git a/src/include/parser/udf/udf_parser.h b/src/include/parser/udf/udf_parser.h index fd3d4a5ef3..5624f795cd 100644 --- a/src/include/parser/udf/udf_parser.h +++ b/src/include/parser/udf/udf_parser.h @@ -10,7 +10,6 @@ #include "parser/expression_util.h" #include "parser/postgresparser.h" -// TODO(Kyle): Do we want to place UDF parsing in its own namespace? namespace noisepage { // Forward declaration diff --git a/src/parser/expression/constant_value_expression.cpp b/src/parser/expression/constant_value_expression.cpp index b074d9ab50..b5a381c21e 100644 --- a/src/parser/expression/constant_value_expression.cpp +++ b/src/parser/expression/constant_value_expression.cpp @@ -102,8 +102,8 @@ T ConstantValueExpression::Peek() const { } const execution::sql::Val *ConstantValueExpression::PeekPtr() const { - // TODO(Kyle): seems unsafe. - return reinterpret_cast(&value_); + NOISEPAGE_ASSERT(std::holds_alternative(value_), "PeekPtr() bad variant access"); + return &std::get(value_); } ConstantValueExpression &ConstantValueExpression::operator=(const ConstantValueExpression &other) { diff --git a/src/parser/postgresparser.cpp b/src/parser/postgresparser.cpp index 0e4ff3e2fd..9e2e7b68d2 100644 --- a/src/parser/postgresparser.cpp +++ b/src/parser/postgresparser.cpp @@ -1291,7 +1291,6 @@ std::unique_ptr PostgresParser::CreateFunctionTransform(ParseResul } default: { // TODO(WAN): previous code just ignored it, is this right? - // TODO(Kyle): Good question^ break; } } diff --git a/src/parser/udf/udf_parser.cpp b/src/parser/udf/udf_parser.cpp index 7022ab2726..12d470591c 100644 --- a/src/parser/udf/udf_parser.cpp +++ b/src/parser/udf/udf_parser.cpp @@ -121,8 +121,8 @@ std::unique_ptr PLpgSQLParser::ParseBlock(const nlohmann::json &block) stmts.push_back(ParseIf(stmt[kPLpgSQL_stmt_if])); } else if (stmt_names.key() == kPLpgSQL_stmt_assign) { // TODO[Siva]: Need to fix Assignment expression / statement - const std::string &var_name = - udf_ast_context_->GetVariableAtIndex(stmt[kPLpgSQL_stmt_assign][kVarno].get()); + const auto &var_name = + udf_ast_context_->GetLocalVariableAtIndex(stmt[kPLpgSQL_stmt_assign][kVarno].get()); std::unique_ptr lhs(new VariableExprAST(var_name)); auto rhs = ParseExprSQL(stmt[kPLpgSQL_stmt_assign][kExpr][kPLpgSQL_expr][kQuery].get()); std::unique_ptr ass_expr_ast(new AssignStmtAST(std::move(lhs), std::move(rhs))); From eee35bbcead17f2e0381e39a26b2c9aff54f8916 Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Mon, 24 May 2021 09:14:45 -0400 Subject: [PATCH 035/139] fix build after merge --- src/execution/compiler/udf/udf_codegen.cpp | 6 ++++-- src/include/execution/compiler/executable_query.h | 5 +++++ 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/src/execution/compiler/udf/udf_codegen.cpp b/src/execution/compiler/udf/udf_codegen.cpp index 11b1a9c167..80425c98b9 100644 --- a/src/execution/compiler/udf/udf_codegen.cpp +++ b/src/execution/compiler/udf/udf_codegen.cpp @@ -366,7 +366,7 @@ void UDFCodegen::Visit(ast::udf::ForStmtAST *ast) { } execution::ast::LambdaExpr *lambda_expr{}; - FunctionBuilder fn(codegen_, std::move(params), codegen_->BuiltinType(execution::ast::BuiltinType::Nil)); + FunctionBuilder fn{codegen_, std::move(params), codegen_->BuiltinType(execution::ast::BuiltinType::Nil)}; { std::size_t j{1}; for (auto var : var_idents) { @@ -392,11 +392,12 @@ void UDFCodegen::Visit(ast::udf::ForStmtAST *ast) { // We want to pass something down that will materialize the lambda // function into lambda_expr and will also feed in a lambda_expr to the compiler + // TODO(Kyle): Using a NULL plan metatdata here... execution::exec::ExecutionSettings exec_settings{}; const std::string dummy_query = ""; auto exec_query = execution::compiler::CompilationContext::Compile( *plan, exec_settings, accessor_, execution::compiler::CompilationMode::OneShot, std::nullopt, - common::ManagedPointer(&dummy_query), lambda_expr, codegen_->GetAstContext()); + common::ManagedPointer{}, common::ManagedPointer{&dummy_query}, lambda_expr, codegen_->GetAstContext()); auto decls = exec_query->GetDecls(); aux_decls_.insert(aux_decls_.end(), decls.begin(), decls.end()); @@ -553,6 +554,7 @@ void UDFCodegen::Visit(ast::udf::SQLStmtAST *ast) { const std::string dummy_query = ""; auto exec_query = execution::compiler::CompilationContext::Compile( *plan, exec_settings, accessor_, execution::compiler::CompilationMode::OneShot, std::nullopt, + common::ManagedPointer{}, common::ManagedPointer(&dummy_query), lambda_expr, codegen_->GetAstContext()); auto decls = exec_query->GetDecls(); diff --git a/src/include/execution/compiler/executable_query.h b/src/include/execution/compiler/executable_query.h index 000b27938c..dd01fd6307 100644 --- a/src/include/execution/compiler/executable_query.h +++ b/src/include/execution/compiler/executable_query.h @@ -256,9 +256,14 @@ class ExecutableQuery { */ void SetQueryId(query_id_t query_id) { query_id_ = query_id; } + // The name of the query std::string query_name_; + // The query identitifier query_id_t query_id_; + // TODO(Kyle): What is this for? static std::atomic query_identifier; + // The text of the query + common::ManagedPointer query_text_; // MiniRunners needs to set query_identifier and pipeline_operating_units_. friend class noisepage::runner::ExecutionRunners; From bb64f35564b64bad602e9c7becc0281f96f79525 Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Thu, 17 Jun 2021 11:05:49 -0400 Subject: [PATCH 036/139] various updates to ast, ast tests, and translation from constant value expressions to sql values, binder test is still broken --- src/include/execution/ast/ast.h | 18 +++- .../execution/ast/ast_traversal_visitor.h | 18 ++++ .../execution/exec/execution_context.h | 32 +++---- src/include/execution/vm/bytecode_handlers.h | 10 +- .../expression/constant_value_expression.h | 59 ++++++------ .../expression/constant_value_expression.cpp | 12 +-- src/traffic_cop/traffic_cop.cpp | 2 +- src/util/query_exec_util.cpp | 2 +- test/execution/ast_test.cpp | 6 ++ test/execution/compiler_test.cpp | 93 ++++++++++++------- 10 files changed, 159 insertions(+), 93 deletions(-) diff --git a/src/include/execution/ast/ast.h b/src/include/execution/ast/ast.h index 34cc7de336..27b57d8513 100644 --- a/src/include/execution/ast/ast.h +++ b/src/include/execution/ast/ast.h @@ -714,10 +714,10 @@ class BreakStmt : public Stmt { BreakStmt(const SourcePosition &pos) : Stmt(Kind::BreakStmt, pos) {} /** - * Is the given node a return statement? + * Is the given node a break statement? * Needed as part of the custom AST RTTI infrastructure. * @param node The node to check. - * @return `true` if the node is a return statement, `false` otherwise. + * @return `true` if the node is a break statement, `false` otherwise. */ static bool classof(const AstNode *node) { return node->GetKind() == Kind::BreakStmt; } }; @@ -1130,6 +1130,15 @@ class LambdaExpr : public Expr { */ const util::RegionVector &GetCaptureIdents() const { return capture_idents_; } + /** + * Is the given node a lambda expression? Needed as part of the custom AST RTTI infrastructure. + * @param node The node to check. + * @return `true` if the node is a lambda expression; `false` otherwise. + */ + static bool classof(const AstNode *node) { // NOLINT + return node->GetKind() == Kind::LambdaExpr; + } + private: friend class sema::Sema; // The identifier for the lambda expression. @@ -1930,6 +1939,11 @@ class LambdaTypeRepr : public Expr { */ Expr *FunctionType() const { return fn_type_; } + /** + * Is the given node a lambda type representation? Needed as part of the custom AST RTTI infrastructure. + * @param node The node to check. + * @return `true` if the node is a lambda type representation; `false` otherwise. + */ static bool classof(const AstNode *node) { // NOLINT return node->GetKind() == Kind::LambdaTypeRepr; } diff --git a/src/include/execution/ast/ast_traversal_visitor.h b/src/include/execution/ast/ast_traversal_visitor.h index f4e90e5d32..7a0ff7c203 100644 --- a/src/include/execution/ast/ast_traversal_visitor.h +++ b/src/include/execution/ast/ast_traversal_visitor.h @@ -210,6 +210,12 @@ inline void AstTraversalVisitor::VisitForStmt(ForStmt *node) { RECURSE(Visit(node->Body())); } +template +inline void AstTraversalVisitor::VisitBreakStmt(BreakStmt *node) { + PROCESS_NODE(node); + // TODO(Kyle): Implement this?? +} + template inline void AstTraversalVisitor::VisitForInStmt(ForInStmt *node) { PROCESS_NODE(node); @@ -232,6 +238,12 @@ inline void AstTraversalVisitor::VisitMapTypeRepr(MapTypeRepr *node) { RECURSE(Visit(node->ValType())); } +template +inline void AstTraversalVisitor::VisitLambdaTypeRepr(LambdaTypeRepr *node) { + PROCESS_NODE(node); + // TODO(Kyle): Implement this?? +} + template inline void AstTraversalVisitor::VisitLitExpr(LitExpr *node) { PROCESS_NODE(node); @@ -294,6 +306,12 @@ inline void AstTraversalVisitor::VisitIndexExpr(IndexExpr *node) { RECURSE(Visit(node->Index())); } +template +inline void AstTraversalVisitor::VisitLambdaExpr(LambdaExpr *node) { + PROCESS_NODE(node); + // TODO(Kyle): Implement this?? +} + template inline void AstTraversalVisitor::VisitFunctionTypeRepr(FunctionTypeRepr *node) { PROCESS_NODE(node); diff --git a/src/include/execution/exec/execution_context.h b/src/include/execution/exec/execution_context.h index 5d4a85617e..7b327e5d8c 100644 --- a/src/include/execution/exec/execution_context.h +++ b/src/include/execution/exec/execution_context.h @@ -140,21 +140,17 @@ class EXPORT ExecutionContext { */ static uint32_t ComputeTupleSize(const planner::OutputSchema *schema); - /** - * @return The catalog accessor. - */ + /** @return The catalog accessor. */ catalog::CatalogAccessor *GetAccessor() { return accessor_.Get(); } /** @return The execution settings. */ const exec::ExecutionSettings &GetExecutionSettings() const { return exec_settings_; } - /** - * Start the resource tracker - */ + /** Start the resource tracker. */ void StartResourceTracker(metrics::MetricsComponent component); /** - * End the resource tracker and record the metrics + * End the resource tracker and record the metrics. * @param name the string name get printed out with the time * @param len the length of the string name */ @@ -188,14 +184,10 @@ class EXPORT ExecutionContext { */ void InitializeParallelOUFeatureVector(selfdriving::ExecOUFeatureVector *ouvec, pipeline_id_t pipeline_id); - /** - * Initialize the UDF parameter stack. - */ + /** Initialize the UDF parameter stack. */ void StartParams() { udf_param_stack_.push_back({}); } - /** - * Remove an element from the UDF parameter stack. - */ + /** Remove an element from the UDF parameter stack. */ void PopParams() { udf_param_stack_.pop_back(); } /** @@ -206,15 +198,13 @@ class EXPORT ExecutionContext { udf_param_stack_.back().push_back(val.CastManagedPointerTo()); } - /** - * @return the db oid - */ + /** @return The database OID. */ catalog::db_oid_t DBOid() { return db_oid_; } /** * Set the mode for this execution. - * This only records the mode and serves the metrics collection purpose, which does not have any impact on the - * actual execution. + * This only records the mode and serves the metrics collection purpose, + * which does not have any impact on the actual execution. * @param mode the integer value of the execution mode to record */ void SetExecutionMode(uint8_t mode) { execution_mode_ = mode; } @@ -227,7 +217,7 @@ class EXPORT ExecutionContext { /** * Set the execution parameters. - * @param params The execution parameters. + * @param params The execution parameters */ void SetParams(common::ManagedPointer>> params) { params_ = params; @@ -236,9 +226,11 @@ class EXPORT ExecutionContext { /** * Get the parameter at the specified index. * @param param_idx index of parameter to access - * @return immutable parameter at provided index + * @return An immutable point to the parameter at specified index */ common::ManagedPointer GetParam(std::size_t param_idx) const { + // TODO(Kyle): This logic is confusing, why are we transparently + // switching between the "regular" parameters and the UDF parameters? return udf_param_stack_.empty() ? (*params_)[param_idx] : udf_param_stack_.back()[param_idx]; } diff --git a/src/include/execution/vm/bytecode_handlers.h b/src/include/execution/vm/bytecode_handlers.h index 934a44fd2d..8da087b5a2 100644 --- a/src/include/execution/vm/bytecode_handlers.h +++ b/src/include/execution/vm/bytecode_handlers.h @@ -2172,12 +2172,19 @@ VM_OP_WARM void OpExtractYearFromDate(noisepage::execution::sql::Integer *result } } +// --------------------------------- +// Transaction Calls +// --------------------------------- + VM_OP_WARM void OpAbortTxn(noisepage::execution::exec::ExecutionContext *exec_ctx) { exec_ctx->GetTxn()->SetMustAbort(); throw noisepage::ABORT_EXCEPTION("transaction aborted"); } -// Parameter calls +// --------------------------------- +// Parameter Calls +// --------------------------------- + // TODO(Kyle): this used to have a conditional check; was it safe to remove? #define GEN_SCALAR_PARAM_GET(Name, SqlType) \ VM_OP_HOT void OpGetParam##Name(noisepage::execution::sql::SqlType *ret, \ @@ -2199,7 +2206,6 @@ GEN_SCALAR_PARAM_GET(TimestampVal, TimestampVal) GEN_SCALAR_PARAM_GET(String, StringVal) #undef GEN_SCALAR_PARAM_GET -// Parameter calls #define GEN_SCALAR_PARAM_ADD(Name, SqlType, typeId) \ VM_OP_HOT void OpAddParam##Name(noisepage::execution::exec::ExecutionContext *exec_ctx, \ noisepage::execution::sql::SqlType *ret) { \ diff --git a/src/include/parser/expression/constant_value_expression.h b/src/include/parser/expression/constant_value_expression.h index e435b6bc73..49d479cd1d 100644 --- a/src/include/parser/expression/constant_value_expression.h +++ b/src/include/parser/expression/constant_value_expression.h @@ -76,13 +76,22 @@ class ConstantValueExpression : public AbstractExpression { */ ConstantValueExpression(const ConstantValueExpression &other); + /** + * Compute a hash for the expression. + * @return The hash value + */ common::hash_t Hash() const override; + /** + * Equality comparison. + * @param other The other ConstantValueExpression instance + * @return `true` if the instances are equivalent, `false` otherwise + */ bool operator==(const AbstractExpression &other) const override; /** * Copies this ConstantValueExpression - * @returns copy of this + * @returns A copy of `this` */ std::unique_ptr Copy() const override { return std::unique_ptr{std::make_unique(*this)}; @@ -111,56 +120,44 @@ class ConstantValueExpression : public AbstractExpression { return common::ManagedPointer(&std::get(value_)); } - /** - * @return copy of the underlying Val - */ + /** @return A copy of the underlying Val */ execution::sql::BoolVal GetBoolVal() const { NOISEPAGE_ASSERT(std::holds_alternative(value_), "Invalid variant type for Get."); return std::get(value_); } - /** - * @return copy of the underlying Val - */ + /** @return A copy of the underlying Val */ execution::sql::Integer GetInteger() const { NOISEPAGE_ASSERT(std::holds_alternative(value_), "Invalid variant type for Get."); return std::get(value_); } - /** - * @return copy of the underlying Val - */ + /** @return A copy of the underlying Val */ execution::sql::Real GetReal() const { NOISEPAGE_ASSERT(std::holds_alternative(value_), "Invalid variant type for Get."); return std::get(value_); } - /** - * @return copy of underlying Val - */ + /** @return A copy of underlying Val */ execution::sql::DecimalVal GetDecimalVal() const { NOISEPAGE_ASSERT(std::holds_alternative(value_), "Invalid variant type for Get."); return std::get(value_); } - /** - * @return copy of the underlying Val - */ + /** @return A copy of the underlying Val */ execution::sql::DateVal GetDateVal() const { NOISEPAGE_ASSERT(std::holds_alternative(value_), "Invalid variant type for Get."); return std::get(value_); } - /** - * @return copy of the underlying Val - */ + /** @return A copy of the underlying Val */ execution::sql::TimestampVal GetTimestampVal() const { NOISEPAGE_ASSERT(std::holds_alternative(value_), "Invalid variant type for Get."); return std::get(value_); } /** - * @return copy of the underlying Val + * @return A copy of the underlying Val * @warning StringVal may not have inlined its value, in which case the StringVal returned by this function will hold * a pointer to the buffer in this CVE. In that case, do not destroy this CVE before the copied StringVal */ @@ -196,9 +193,7 @@ class ConstantValueExpression : public AbstractExpression { Validate(); } - /** - * @return true if CVE value represents a NULL - */ + /** @return `true` if CVE value represents a NULL, `false` otherwise */ bool IsNull() const { if (std::holds_alternative(value_) && std::get(value_).is_null_) return true; @@ -231,21 +226,25 @@ class ConstantValueExpression : public AbstractExpression { } /** - * Extracts the underlying execution value as a C++ type + * Extracts the underlying execution value as a C++ type. * @tparam T C++ type to extract * @return copy of the underlying value as the requested type - * @warning std::string_view returned by this function will hold a pointer to the buffer in this CVE. In that case, do - * not destroy this CVE before the std::string_view + * @warning std::string_view returned by this function will hold a pointer to the buffer in this CVE. + * In that case, do not destroy this CVE before the std::string_view */ template T Peek() const; /** - * Peek at the underlying value for the constant value expression. - * @return The underlying value pointer as a SQL value + * Get a pointer to the underlying value as a generic SQL type. + * @return An immutable pointer to the underlying value */ - const execution::sql::Val *PeekPtr() const; + const execution::sql::Val *SqlValue() const; + /** + * Visitor pattern for binder. + * @param v The SqlNodeVisitor + */ void Accept(common::ManagedPointer v) override; /** @return A string representation of this ConstantValueExpression. */ @@ -268,7 +267,7 @@ class ConstantValueExpression : public AbstractExpression { friend class binder::BindNodeVisitor; /* value_ may be modified, e.g., when parsing dates. */ void Validate() const; - // The undelrying constant value + // The underlying constant value std::variant diff --git a/src/parser/expression/constant_value_expression.cpp b/src/parser/expression/constant_value_expression.cpp index b5a381c21e..a36b4eb6d6 100644 --- a/src/parser/expression/constant_value_expression.cpp +++ b/src/parser/expression/constant_value_expression.cpp @@ -68,8 +68,7 @@ T ConstantValueExpression::Peek() const { } // NOLINTNEXTLINE: bugprone-suspicious-semicolon: seems like a false positive because of constexpr if constexpr (std::is_same_v || std::is_same_v || std::is_same_v || - std::is_same_v) { // NOLINT: bugprone-suspicious-semicolon: seems like a false positive - // because of constexpr + std::is_same_v) { return static_cast(GetInteger().val_); } // NOLINTNEXTLINE: bugprone-suspicious-semicolon: seems like a false positive because of constexpr @@ -101,9 +100,11 @@ T ConstantValueExpression::Peek() const { UNREACHABLE("Invalid type for Peek."); } -const execution::sql::Val *ConstantValueExpression::PeekPtr() const { - NOISEPAGE_ASSERT(std::holds_alternative(value_), "PeekPtr() bad variant access"); - return &std::get(value_); +const execution::sql::Val *ConstantValueExpression::SqlValue() const { + // TODO(Kyle): This solution is a bit hacky, we might want to + // consider revisiting (no pun intended) the way that we manage + // parameters provided to the execution context to resolve + return std::visit([](auto &&val) { return static_cast(&val); }, value_); } ConstantValueExpression &ConstantValueExpression::operator=(const ConstantValueExpression &other) { @@ -133,7 +134,6 @@ ConstantValueExpression &ConstantValueExpression::operator=(const ConstantValueE ConstantValueExpression::ConstantValueExpression(const ConstantValueExpression &other) : AbstractExpression(other) { if (std::holds_alternative(other.value_)) { auto string_val = execution::sql::ValueUtil::CreateStringVal(other.GetStringVal()); - value_ = string_val.first; buffer_ = std::move(string_val.second); } else { diff --git a/src/traffic_cop/traffic_cop.cpp b/src/traffic_cop/traffic_cop.cpp index 22e1fb179a..8737027b9c 100644 --- a/src/traffic_cop/traffic_cop.cpp +++ b/src/traffic_cop/traffic_cop.cpp @@ -589,7 +589,7 @@ TrafficCopResult TrafficCop::RunExecutableQuery(const common::ManagedPointer> params{}; params.reserve(portal->Parameters()->size()); std::transform(portal->Parameters()->cbegin(), portal->Parameters()->cend(), std::back_inserter(params), - [](const parser::ConstantValueExpression &cve) { return common::ManagedPointer{cve.PeekPtr()}; }); + [](const parser::ConstantValueExpression &cve) { return common::ManagedPointer{cve.SqlValue()}; }); exec_ctx->SetParams(common::ManagedPointer(¶ms)); const auto exec_query = portal->GetStatement()->GetExecutableQuery(); diff --git a/src/util/query_exec_util.cpp b/src/util/query_exec_util.cpp index 6d1a1e3d86..72bc91ac00 100644 --- a/src/util/query_exec_util.cpp +++ b/src/util/query_exec_util.cpp @@ -288,7 +288,7 @@ bool QueryExecUtil::ExecuteQuery(const std::string &statement, TupleFunction tup std::vector> value_params{}; value_params.reserve(params->size()); std::transform(params->cbegin(), params->cend(), std::back_inserter(value_params), - [](const parser::ConstantValueExpression &cve) { return common::ManagedPointer{cve.PeekPtr()}; }); + [](const parser::ConstantValueExpression &cve) { return common::ManagedPointer{cve.SqlValue()}; }); exec_ctx->SetParams(common::ManagedPointer(&value_params)); NOISEPAGE_ASSERT(!txn->MustAbort(), "Transaction should not be in must-abort state prior to executing"); diff --git a/test/execution/ast_test.cpp b/test/execution/ast_test.cpp index e9a6e5449a..23f369ddcc 100644 --- a/test/execution/ast_test.cpp +++ b/test/execution/ast_test.cpp @@ -64,6 +64,11 @@ TEST_F(AstTest, HierarchyTest) { factory.NewCallExpr(factory.NewNilLiteral(EmptyPos()), util::RegionVector(Region())), factory.NewFunctionLitExpr( factory.NewFunctionType(EmptyPos(), util::RegionVector(Region()), nullptr), nullptr), + factory.NewLambdaExpr( + EmptyPos(), + factory.NewFunctionLitExpr( + factory.NewFunctionType(EmptyPos(), util::RegionVector(Region()), nullptr), nullptr), + util::RegionVector(Region())), factory.NewNilLiteral(EmptyPos()), factory.NewUnaryOpExpr(EmptyPos(), parsing::Token::Type::MINUS, nullptr), factory.NewIdentifierExpr(EmptyPos(), Identifier()), @@ -96,6 +101,7 @@ TEST_F(AstTest, HierarchyTest) { factory.NewDeclStmt(factory.NewVariableDecl(EmptyPos(), Identifier(), nullptr, nullptr)), factory.NewExpressionStmt(factory.NewNilLiteral(EmptyPos())), factory.NewForStmt(EmptyPos(), nullptr, nullptr, nullptr, nullptr), + factory.NewBreakStmt(EmptyPos()), factory.NewIfStmt(EmptyPos(), nullptr, nullptr, nullptr), factory.NewReturnStmt(EmptyPos(), nullptr), }; diff --git a/test/execution/compiler_test.cpp b/test/execution/compiler_test.cpp index e5cef4d5ae..d7fa17bb02 100644 --- a/test/execution/compiler_test.cpp +++ b/test/execution/compiler_test.cpp @@ -70,6 +70,27 @@ class CompilerTest : public SqlBasedTest { static constexpr vm::ExecutionMode MODE = vm::ExecutionMode::Interpret; }; +/** + * Transform the parameters vector supplied to an executable query. + * + * TODO(Kyle): This function is a hack that results from a refactor of + * the API for executable queries. Eventually, when we actually get + * around to refactoring the compiler tests, we should remove this and + * just fix the API itself. + * + * @param parameters The input parameters collection + * @return A non-owning collection of parameters in the format + * expected by the ExecutableQuery API + */ +static std::unique_ptr>> TransformParameters( + const std::vector ¶meters) { + auto params = std::make_unique>>(); + params->reserve(parameters.size()); + std::transform(parameters.cbegin(), parameters.cend(), std::back_inserter(*params), + [](const parser::ConstantValueExpression &cve) { return common::ManagedPointer{cve.SqlValue()}; }); + return params; +} + // NOLINTNEXTLINE TEST_F(CompilerTest, CompileFromSource) { util::Region region{"compiler_test"}; @@ -417,11 +438,13 @@ TEST_F(CompilerTest, SimpleSeqScanWithParamsTest) { MultiOutputCallback callback{std::vector{store, printer}}; exec::OutputCallback callback_fn = callback.ConstructOutputCallback(); auto exec_ctx = MakeExecCtx(&callback_fn, seq_scan->GetOutputSchema().Get()); - std::vector params; - params.emplace_back(type::TypeId::INTEGER, execution::sql::Integer(100)); - params.emplace_back(type::TypeId::INTEGER, execution::sql::Integer(500)); - params.emplace_back(type::TypeId::INTEGER, execution::sql::Integer(3)); - exec_ctx->SetParams(common::ManagedPointer>(¶ms)); + + std::vector param_builder{}; + param_builder.emplace_back(type::TypeId::INTEGER, execution::sql::Integer(100)); + param_builder.emplace_back(type::TypeId::INTEGER, execution::sql::Integer(500)); + param_builder.emplace_back(type::TypeId::INTEGER, execution::sql::Integer(3)); + auto params = TransformParameters(param_builder); + exec_ctx->SetParams(common::ManagedPointer(params)); // Run & Check auto executable = execution::compiler::CompilationContext::Compile(*seq_scan, exec_ctx->GetExecutionSettings(), @@ -3074,10 +3097,12 @@ TEST_F(CompilerTest, InsertIntoSelectWithParamTest) { MultiOutputCallback callback{std::vector{}}; exec::OutputCallback callback_fn = callback.ConstructOutputCallback(); auto exec_ctx = MakeExecCtx(&callback_fn, insert->GetOutputSchema().Get()); - std::vector params; - params.emplace_back(type::TypeId::INTEGER, execution::sql::Integer(495)); - params.emplace_back(type::TypeId::INTEGER, execution::sql::Integer(505)); - exec_ctx->SetParams(common::ManagedPointer>(¶ms)); + std::vector params_builder{}; + params_builder.emplace_back(type::TypeId::INTEGER, execution::sql::Integer(495)); + params_builder.emplace_back(type::TypeId::INTEGER, execution::sql::Integer(505)); + auto params = TransformParameters(params_builder); + exec_ctx->SetParams(common::ManagedPointer(params)); + auto executable = execution::compiler::CompilationContext::Compile(*insert, exec_ctx->GetExecutionSettings(), exec_ctx->GetAccessor()); executable->Run(common::ManagedPointer(exec_ctx), MODE); @@ -3297,28 +3322,33 @@ TEST_F(CompilerTest, SimpleInsertWithParamsTest) { MultiOutputCallback callback{std::vector{}}; exec::OutputCallback callback_fn = callback.ConstructOutputCallback(); auto exec_ctx = MakeExecCtx(&callback_fn, insert->GetOutputSchema().Get()); - std::vector params; + std::vector params_builder{}; + // First parameter list auto str1_val = sql::ValueUtil::CreateStringVal(str1); - params.emplace_back(type::TypeId::VARCHAR, str1_val.first, std::move(str1_val.second)); - params.emplace_back(type::TypeId::DATE, sql::DateVal(date1.val_)); - params.emplace_back(type::TypeId::REAL, sql::Real(real1)); - params.emplace_back(type::TypeId::BOOLEAN, sql::BoolVal(bool1)); - params.emplace_back(type::TypeId::TINYINT, sql::Integer(tinyint1)); - params.emplace_back(type::TypeId::SMALLINT, sql::Integer(smallint1)); - params.emplace_back(type::TypeId::INTEGER, sql::Integer(int1)); - params.emplace_back(type::TypeId::BIGINT, sql::Integer(bigint1)); + params_builder.emplace_back(type::TypeId::VARCHAR, str1_val.first, std::move(str1_val.second)); + params_builder.emplace_back(type::TypeId::DATE, sql::DateVal(date1.val_)); + params_builder.emplace_back(type::TypeId::REAL, sql::Real(real1)); + params_builder.emplace_back(type::TypeId::BOOLEAN, sql::BoolVal(bool1)); + params_builder.emplace_back(type::TypeId::TINYINT, sql::Integer(tinyint1)); + params_builder.emplace_back(type::TypeId::SMALLINT, sql::Integer(smallint1)); + params_builder.emplace_back(type::TypeId::INTEGER, sql::Integer(int1)); + params_builder.emplace_back(type::TypeId::BIGINT, sql::Integer(bigint1)); + // Second parameter list auto str2_val = sql::ValueUtil::CreateStringVal(str2); - params.emplace_back(type::TypeId::VARCHAR, str2_val.first, std::move(str2_val.second)); - params.emplace_back(type::TypeId::DATE, sql::DateVal(date2.val_)); - params.emplace_back(type::TypeId::REAL, sql::Real(real2)); - params.emplace_back(type::TypeId::BOOLEAN, sql::BoolVal(bool2)); - params.emplace_back(type::TypeId::TINYINT, sql::Integer(tinyint2)); - params.emplace_back(type::TypeId::SMALLINT, sql::Integer(smallint2)); - params.emplace_back(type::TypeId::INTEGER, sql::Integer(int2)); - params.emplace_back(type::TypeId::BIGINT, sql::Integer(bigint2)); - exec_ctx->SetParams(common::ManagedPointer>(¶ms)); + params_builder.emplace_back(type::TypeId::VARCHAR, str2_val.first, std::move(str2_val.second)); + params_builder.emplace_back(type::TypeId::DATE, sql::DateVal(date2.val_)); + params_builder.emplace_back(type::TypeId::REAL, sql::Real(real2)); + params_builder.emplace_back(type::TypeId::BOOLEAN, sql::BoolVal(bool2)); + params_builder.emplace_back(type::TypeId::TINYINT, sql::Integer(tinyint2)); + params_builder.emplace_back(type::TypeId::SMALLINT, sql::Integer(smallint2)); + params_builder.emplace_back(type::TypeId::INTEGER, sql::Integer(int2)); + params_builder.emplace_back(type::TypeId::BIGINT, sql::Integer(bigint2)); + + auto params = TransformParameters(params_builder); + exec_ctx->SetParams(common::ManagedPointer(params)); + auto executable = execution::compiler::CompilationContext::Compile(*insert, exec_ctx->GetExecutionSettings(), exec_ctx->GetAccessor()); executable->Run(common::ManagedPointer(exec_ctx), MODE); @@ -3498,12 +3528,13 @@ TEST_F(CompilerTest, SimpleInsertWithParamsTest) { MultiOutputCallback callback{std::vector{store, printer}}; exec::OutputCallback callback_fn = callback.ConstructOutputCallback(); auto exec_ctx = MakeExecCtx(&callback_fn, index_scan->GetOutputSchema().Get()); - std::vector params; + std::vector params_builder{}; auto str1_val = sql::ValueUtil::CreateStringVal(str1); auto str2_val = sql::ValueUtil::CreateStringVal(str2); - params.emplace_back(type::TypeId::VARCHAR, str1_val.first, std::move(str1_val.second)); - params.emplace_back(type::TypeId::VARCHAR, str2_val.first, std::move(str2_val.second)); - exec_ctx->SetParams(common::ManagedPointer>(¶ms)); + params_builder.emplace_back(type::TypeId::VARCHAR, str1_val.first, std::move(str1_val.second)); + params_builder.emplace_back(type::TypeId::VARCHAR, str2_val.first, std::move(str2_val.second)); + auto params = TransformParameters(params_builder); + exec_ctx->SetParams(common::ManagedPointer(params)); auto executable = execution::compiler::CompilationContext::Compile(*index_scan, exec_ctx->GetExecutionSettings(), exec_ctx->GetAccessor()); executable->Run(common::ManagedPointer(exec_ctx), MODE); From decb9d15d034af9d7548b8846da3cd0798994350 Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Thu, 17 Jun 2021 22:20:45 -0400 Subject: [PATCH 037/139] fixed error causing binder tests to fail, starting in on linting errors --- src/binder/bind_node_visitor.cpp | 22 ++++---- src/execution/ast/ast_clone.cpp | 17 +++--- .../expression/column_value_expression.h | 6 +++ src/include/type/type_id.h | 11 ++++ src/type/type_id.cpp | 53 +++++++++++++++++++ 5 files changed, 90 insertions(+), 19 deletions(-) diff --git a/src/binder/bind_node_visitor.cpp b/src/binder/bind_node_visitor.cpp index 15fff55281..603f45c3f8 100644 --- a/src/binder/bind_node_visitor.cpp +++ b/src/binder/bind_node_visitor.cpp @@ -680,14 +680,14 @@ void BindNodeVisitor::Visit(common::ManagedPointerGetDesiredType(expr.CastManagedPointerTo()); + // Before checking with the schema, cache the desired type that expr should have + const auto cached_desired_type = sherpa_->GetDesiredType(expr.CastManagedPointerTo()); // TODO(Ling): consider remove precondition check if the *_oid_ will never be initialized till binder // That is, the object would not be initialized using ColumnValueExpression(database_oid, table_oid, column_oid) // at this point if (expr->GetTableOid() == catalog::INVALID_TABLE_OID) { - std::tuple tuple; + std::tuple tuple{}; std::string table_name = expr->GetTableName(); std::string col_name = expr->GetColumnName(); if (table_name.empty() && col_name.empty() && expr->GetColumnOid() != catalog::INVALID_COLUMN_OID) { @@ -695,11 +695,11 @@ void BindNodeVisitor::Visit(common::ManagedPointerGetColumnOid().UnderlyingValue())), common::ErrorCode::ERRCODE_UNDEFINED_COLUMN); } - // Convert all the names to lower cases + // Convert all the names to lower case std::transform(table_name.begin(), table_name.end(), table_name.begin(), ::tolower); std::transform(col_name.begin(), col_name.end(), col_name.begin(), ::tolower); - // Table name not specified in the expression. Loop through all the table in the binder context. + // Table name not specified in the expression; loop through all the tables in the binder context type::TypeId the_type{}; if (table_name.empty()) { if (udf_ast_context_ != nullptr && udf_ast_context_->GetVariableType(expr->GetColumnName(), &the_type)) { @@ -715,7 +715,7 @@ void BindNodeVisitor::Visit(common::ManagedPointerGetRegularTableObj(table_name, expr, common::ManagedPointer(&tuple))) { if (!BinderContext::ColumnInSchema(std::get<2>(tuple), col_name)) { throw BINDER_EXCEPTION(fmt::format("column \"{}\" does not exist", col_name), @@ -723,7 +723,6 @@ void BindNodeVisitor::Visit(common::ManagedPointerGetVariableType(expr->GetTableName(), &the_type)) { - // record type NOISEPAGE_ASSERT(the_type == type::TypeId::INVALID, "unknown type"); auto &fields = udf_ast_context_->GetRecordType(expr->GetTableName()); auto it = std::find_if(fields.begin(), fields.end(), [=](auto p) { return p.first == expr->GetColumnName(); }); @@ -735,17 +734,18 @@ void BindNodeVisitor::Visit(common::ManagedPointerSetReturnValueType(it->second); expr->SetParamIdx(idx); - } else if (context_ == nullptr || !context_->CheckNestedTableColumn(table_name, col_name, expr)) { - throw BINDER_EXCEPTION(fmt::format("Invalid table reference {}", expr->GetTableName()), - common::ErrorCode::ERRCODE_UNDEFINED_TABLE); } + } else if (context_ == nullptr || !context_->CheckNestedTableColumn(table_name, col_name, expr)) { + throw BINDER_EXCEPTION(fmt::format("Invalid table reference {}", expr->GetTableName()), + common::ErrorCode::ERRCODE_UNDEFINED_TABLE); } } } // The schema is authoritative on what the type of this ColumnValueExpression should be, UNLESS // some specific type was already requested. - desired_type = desired_type == type::TypeId::INVALID ? expr->GetReturnValueType() : desired_type; + const auto desired_type = + cached_desired_type == type::TypeId::INVALID ? expr->GetReturnValueType() : cached_desired_type; sherpa_->SetDesiredType(expr.CastManagedPointerTo(), desired_type); sherpa_->CheckDesiredType(expr.CastManagedPointerTo()); } diff --git a/src/execution/ast/ast_clone.cpp b/src/execution/ast/ast_clone.cpp index 875803a6fc..3b27769807 100644 --- a/src/execution/ast/ast_clone.cpp +++ b/src/execution/ast/ast_clone.cpp @@ -26,24 +26,25 @@ class AstCloneImpl : public AstVisitor { AST_NODES(DECLARE_VISIT_METHOD) #undef DECLARE_VISIT_METHOD - Identifier CloneIdentifier(Identifier &ident) { return new_context_->GetIdentifier(ident.GetData()); } + Identifier CloneIdentifier(const Identifier &ident) { return new_context_->GetIdentifier(ident.GetData()); } - Identifier CloneIdentifier(Identifier &&ident) { + Identifier CloneIdentifier(const Identifier &&ident) { (void)old_context_; return new_context_->GetIdentifier(ident.GetData()); } private: - // The root of the AST to clone. + /** The root of the AST to clone. */ AstNode *root_; - // The AST node factory used to allocate new nodes. + + /** The AST node factory used to allocate new nodes. */ AstNodeFactory *factory_; - // The AST context of the source AST. + + /** The AST context of the source AST. */ Context *old_context_; - // The AST context of the destination AST. - Context *new_context_; - // llvm::DenseMap allocated_strings_; + /** The AST context of the destination AST. */ + Context *new_context_; }; AstNode *AstCloneImpl::VisitFile(File *node) { diff --git a/src/include/parser/expression/column_value_expression.h b/src/include/parser/expression/column_value_expression.h index 08eb3ff12c..8f1b28edc2 100644 --- a/src/include/parser/expression/column_value_expression.h +++ b/src/include/parser/expression/column_value_expression.h @@ -193,19 +193,25 @@ class ColumnValueExpression : public AbstractExpression { private: friend class binder::BinderContext; friend class execution::sql::TableGenerator; + /** @param database_oid Database OID to be assigned to this expression */ void SetDatabaseOID(catalog::db_oid_t database_oid) { database_oid_ = database_oid; } + /** @param table_oid Table OID to be assigned to this expression */ void SetTableOID(catalog::table_oid_t table_oid) { table_oid_ = table_oid; } + /** @param column_oid Column OID to be assigned to this expression */ void SetColumnOID(catalog::col_oid_t column_oid) { column_oid_ = column_oid; } + /** @param table_oid Table OID to be assigned to this expression */ void SetTableName(const std::string &table_name) { table_name_ = std::string(table_name); } + /** @param column_oid Column OID to be assigned to this expression */ void SetColumnName(const std::string &col_name) { column_name_ = std::string(col_name); } /** Table name. */ std::string table_name_; + /** Column name. */ std::string column_name_; diff --git a/src/include/type/type_id.h b/src/include/type/type_id.h index 31a6bdf3f8..1f8d7ab135 100644 --- a/src/include/type/type_id.h +++ b/src/include/type/type_id.h @@ -1,5 +1,8 @@ #pragma once +#include + +#include "common/macros.h" #include "common/strong_typedef.h" namespace noisepage::type { @@ -38,4 +41,12 @@ enum class TypeId : uint8_t { VAR_ARRAY, ///< pg_type requires a distinct type for var_array. }; +/** + * Operator overload for printing TypeId values. + * @param os The output stream + * @param type_id The type ID + * @return The output stream + */ +std::ostream &operator<<(std::ostream &os, TypeId type_id); + } // namespace noisepage::type diff --git a/src/type/type_id.cpp b/src/type/type_id.cpp index d6887a11d5..81fba4481e 100644 --- a/src/type/type_id.cpp +++ b/src/type/type_id.cpp @@ -18,4 +18,57 @@ STRONG_TYPEDEF_BODY(date_t, uint32_t); */ STRONG_TYPEDEF_BODY(timestamp_t, uint64_t); +std::ostream &operator<<(std::ostream &os, TypeId type_id) { + switch (type_id) { + case TypeId::INVALID: + os << "INVALID"; + break; + case TypeId::BOOLEAN: + os << "BOOLEAN"; + break; + case TypeId::TINYINT: + os << "TINYINT"; + break; + case TypeId::SMALLINT: + os << "SMALLINT"; + break; + case TypeId::INTEGER: + os << "INTEGER"; + break; + case TypeId::BIGINT: + os << "BIGINT"; + break; + case TypeId::REAL: + os << "REAL"; + break; + case TypeId::DECIMAL: + os << "DECIMAL"; + break; + case TypeId::TIMESTAMP: + os << "TIMESTAMP"; + break; + case TypeId::DATE: + os << "DATE"; + break; + case TypeId::VARCHAR: + os << "VARCHAR"; + break; + case TypeId::VARBINARY: + os << "VARBINARY"; + break; + case TypeId::PARAMETER_OFFSET: + os << "PARAMETER_OFFSET"; + break; + case TypeId::VARIADIC: + os << "VARIADIC"; + break; + case TypeId::VAR_ARRAY: + os << "VAR_ARRAY"; + break; + default: + NOISEPAGE_ASSERT(false, "Invalid Type ID"); + } + return os; +} + } // namespace noisepage::type From bc4a1e5a85366f4f6d7a013989c4c0567617b113 Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Fri, 18 Jun 2021 07:49:42 -0400 Subject: [PATCH 038/139] progress on linter errors, big refactor in udf parser --- src/include/binder/bind_node_visitor.h | 2 + src/include/execution/ast/ast.h | 2 +- src/include/execution/ast/type.h | 8 +- .../execution/ast/udf/udf_ast_context.h | 5 + .../execution/ast/udf/udf_ast_node_visitor.h | 2 +- src/include/execution/ast/udf/udf_ast_nodes.h | 50 +++++---- .../execution/compiler/executable_query.h | 2 +- .../compiler/expression/function_translator.h | 6 +- .../execution/compiler/udf/udf_codegen.h | 3 + src/include/execution/sema/scope.h | 3 + src/include/parser/udf/udf_parser.h | 7 +- src/parser/udf/udf_parser.cpp | 102 ++++++++---------- 12 files changed, 104 insertions(+), 88 deletions(-) diff --git a/src/include/binder/bind_node_visitor.h b/src/include/binder/bind_node_visitor.h index b3fb4f9b0b..0973a3dd51 100644 --- a/src/include/binder/bind_node_visitor.h +++ b/src/include/binder/bind_node_visitor.h @@ -2,6 +2,8 @@ #include #include +#include +#include #include #include "binder/sql_node_visitor.h" diff --git a/src/include/execution/ast/ast.h b/src/include/execution/ast/ast.h index 27b57d8513..a50021042d 100644 --- a/src/include/execution/ast/ast.h +++ b/src/include/execution/ast/ast.h @@ -711,7 +711,7 @@ class BreakStmt : public Stmt { * Constructor * @param pos source position */ - BreakStmt(const SourcePosition &pos) : Stmt(Kind::BreakStmt, pos) {} + explicit BreakStmt(const SourcePosition &pos) : Stmt(Kind::BreakStmt, pos) {} /** * Is the given node a break statement? diff --git a/src/include/execution/ast/type.h b/src/include/execution/ast/type.h index b8b8b3f62d..e9b2986390 100644 --- a/src/include/execution/ast/type.h +++ b/src/include/execution/ast/type.h @@ -746,7 +746,7 @@ class MapType : public Type { /** * @param type to compare with - * @return whether type is of map type. + * @return whether type is of Map type. */ static bool classof(const Type *type) { return type->GetTypeId() == TypeId::MapType; } // NOLINT @@ -773,10 +773,14 @@ class LambdaType : public Type { */ static LambdaType *Get(FunctionType *fn_type); + /** + * @param type to compare with + * @return whether type is of Lambda type. + */ static bool classof(const Type *type) { return type->GetTypeId() == TypeId::LambdaType; } // NOLINT private: - LambdaType(FunctionType *fn_type); + explicit LambdaType(FunctionType *fn_type); private: FunctionType *fn_type_; diff --git a/src/include/execution/ast/udf/udf_ast_context.h b/src/include/execution/ast/udf/udf_ast_context.h index 9c34b756a1..1e38ee4d9b 100644 --- a/src/include/execution/ast/udf/udf_ast_context.h +++ b/src/include/execution/ast/udf/udf_ast_context.h @@ -1,5 +1,10 @@ #pragma once +#include +#include +#include +#include + #include "type/type_id.h" namespace noisepage { diff --git a/src/include/execution/ast/udf/udf_ast_node_visitor.h b/src/include/execution/ast/udf/udf_ast_node_visitor.h index 2ca9c6a17c..512230995b 100644 --- a/src/include/execution/ast/udf/udf_ast_node_visitor.h +++ b/src/include/execution/ast/udf/udf_ast_node_visitor.h @@ -27,7 +27,7 @@ class FunctionAST; class ASTNodeVisitor { public: - virtual ~ASTNodeVisitor(){}; + virtual ~ASTNodeVisitor() {} virtual void Visit(AbstractAST *ast) = 0; virtual void Visit(StmtAST *ast) = 0; virtual void Visit(ExprAST *ast) = 0; diff --git a/src/include/execution/ast/udf/udf_ast_nodes.h b/src/include/execution/ast/udf/udf_ast_nodes.h index 52fd93e87e..7624ae36cb 100644 --- a/src/include/execution/ast/udf/udf_ast_nodes.h +++ b/src/include/execution/ast/udf/udf_ast_nodes.h @@ -1,5 +1,9 @@ #pragma once +#include +#include +#include + #include "parser/expression/constant_value_expression.h" #include "parser/expression_defs.h" #include "type/type_id.h" @@ -17,7 +21,7 @@ class AbstractAST { public: virtual ~AbstractAST() = default; - virtual void Accept(ASTNodeVisitor *visitor) { visitor->Visit(this); }; + virtual void Accept(ASTNodeVisitor *visitor) { visitor->Visit(this); } }; // StmtAST - Base class for all statement nodes. @@ -25,7 +29,7 @@ class StmtAST : public AbstractAST { public: virtual ~StmtAST() = default; - void Accept(ASTNodeVisitor *visitor) override { visitor->Visit(this); }; + void Accept(ASTNodeVisitor *visitor) override { visitor->Visit(this); } }; // ExprAST - Base class for all expression nodes. @@ -33,7 +37,7 @@ class ExprAST : public StmtAST { public: virtual ~ExprAST() = default; - void Accept(ASTNodeVisitor *visitor) override { visitor->Visit(this); }; + void Accept(ASTNodeVisitor *visitor) override { visitor->Visit(this); } }; // DoubleExprAST - Expression class for numeric literals like "1.1". @@ -41,9 +45,9 @@ class ValueExprAST : public ExprAST { public: std::unique_ptr value_; - ValueExprAST(std::unique_ptr value) : value_(std::move(value)) {} + explicit ValueExprAST(std::unique_ptr value) : value_(std::move(value)) {} - void Accept(ASTNodeVisitor *visitor) override { visitor->Visit(this); }; + void Accept(ASTNodeVisitor *visitor) override { visitor->Visit(this); } }; class IsNullExprAST : public ExprAST { @@ -54,7 +58,7 @@ class IsNullExprAST : public ExprAST { IsNullExprAST(bool is_null_check, std::unique_ptr child) : is_null_check_(is_null_check), child_(std::move(child)) {} - void Accept(ASTNodeVisitor *visitor) override { visitor->Visit(this); }; + void Accept(ASTNodeVisitor *visitor) override { visitor->Visit(this); } }; // VariableExprAST - Expression class for referencing a variable, like "a". @@ -62,9 +66,9 @@ class VariableExprAST : public ExprAST { public: std::string name; - VariableExprAST(const std::string &name) : name(name) {} + explicit VariableExprAST(const std::string &name) : name(name) {} - void Accept(ASTNodeVisitor *visitor) override { visitor->Visit(this); }; + void Accept(ASTNodeVisitor *visitor) override { visitor->Visit(this); } }; // VariableExprAST - Expression class for referencing a variable, like "a". @@ -76,7 +80,7 @@ class MemberExprAST : public ExprAST { MemberExprAST(std::unique_ptr &&object, std::string field) : object(std::move(object)), field(field) {} - void Accept(ASTNodeVisitor *visitor) override { visitor->Visit(this); }; + void Accept(ASTNodeVisitor *visitor) override { visitor->Visit(this); } }; // BinaryExprAST - Expression class for a binary operator. @@ -88,7 +92,7 @@ class BinaryExprAST : public ExprAST { BinaryExprAST(parser::ExpressionType op, std::unique_ptr lhs, std::unique_ptr rhs) : op(op), lhs(std::move(lhs)), rhs(std::move(rhs)) {} - void Accept(ASTNodeVisitor *visitor) override { visitor->Visit(this); }; + void Accept(ASTNodeVisitor *visitor) override { visitor->Visit(this); } }; // CallExprAST - Expression class for function calls. @@ -100,7 +104,7 @@ class CallExprAST : public ExprAST { CallExprAST(const std::string &callee, std::vector> args) : callee(callee), args(std::move(args)) {} - void Accept(ASTNodeVisitor *visitor) override { visitor->Visit(this); }; + void Accept(ASTNodeVisitor *visitor) override { visitor->Visit(this); } }; // SeqStmtAST - Statement class for sequence of statements @@ -108,9 +112,9 @@ class SeqStmtAST : public StmtAST { public: std::vector> stmts; - SeqStmtAST(std::vector> stmts) : stmts(std::move(stmts)) {} + explicit SeqStmtAST(std::vector>&& stmts) : stmts(std::move(stmts)) {} - void Accept(ASTNodeVisitor *visitor) override { visitor->Visit(this); }; + void Accept(ASTNodeVisitor *visitor) override { visitor->Visit(this); } }; // DeclStmtAST - Statement class for sequence of statements @@ -157,7 +161,7 @@ class WhileStmtAST : public StmtAST { std::unique_ptr cond_expr; std::unique_ptr body_stmt; - void Accept(ASTNodeVisitor *visitor) override { visitor->Visit(this); }; + void Accept(ASTNodeVisitor *visitor) override { visitor->Visit(this); } WhileStmtAST(std::unique_ptr cond_expr, std::unique_ptr body_stmt) : cond_expr(std::move(cond_expr)), body_stmt(std::move(body_stmt)) {} @@ -168,9 +172,9 @@ class RetStmtAST : public StmtAST { public: std::unique_ptr expr; - RetStmtAST(std::unique_ptr expr) : expr(std::move(expr)) {} + explicit RetStmtAST(std::unique_ptr expr) : expr(std::move(expr)) {} - void Accept(ASTNodeVisitor *visitor) override { visitor->Visit(this); }; + void Accept(ASTNodeVisitor *visitor) override { visitor->Visit(this); } }; // AssignStmtAST - Expression class for a binary operator. @@ -182,7 +186,7 @@ class AssignStmtAST : public ExprAST { AssignStmtAST(std::unique_ptr lhs, std::unique_ptr rhs) : lhs(std::move(lhs)), rhs(std::move(rhs)) {} - void Accept(ASTNodeVisitor *visitor) override { visitor->Visit(this); }; + void Accept(ASTNodeVisitor *visitor) override { visitor->Visit(this); } }; // SQLStmtAST - Expression class for a SQL Statement. @@ -196,7 +200,7 @@ class SQLStmtAST : public StmtAST { std::unordered_map> &&udf_params) : query(std::move(query)), var_name(std::move(var_name)), udf_params(std::move(udf_params)) {} - void Accept(ASTNodeVisitor *visitor) override { visitor->Visit(this); }; + void Accept(ASTNodeVisitor *visitor) override { visitor->Visit(this); } }; // DynamicSQLStmtAST - Expression class for a SQL Statement. @@ -208,7 +212,7 @@ class DynamicSQLStmtAST : public StmtAST { DynamicSQLStmtAST(std::unique_ptr query, std::string var_name) : query(std::move(query)), var_name(std::move(var_name)) {} - void Accept(ASTNodeVisitor *visitor) override { visitor->Visit(this); }; + void Accept(ASTNodeVisitor *visitor) override { visitor->Visit(this); } }; // FunctionAST - This class represents a function definition itself. @@ -222,12 +226,12 @@ class FunctionAST : public AbstractAST { std::vector &¶m_types) : body(std::move(body)), param_names_(std::move(param_names)), param_types_(std::move(param_types)) {} - void Accept(ASTNodeVisitor *visitor) override { visitor->Visit(this); }; + void Accept(ASTNodeVisitor *visitor) override { visitor->Visit(this); } }; -/*---------------------------------------------------------------- -/// Error* - These are little helper functions for error handling. ------------------------------------------------------------------*/ +// ---------------------------------------------------------------------------- +// Error Handling Helpers +// ---------------------------------------------------------------------------- std::unique_ptr LogError(const char *str); diff --git a/src/include/execution/compiler/executable_query.h b/src/include/execution/compiler/executable_query.h index 6f7a753555..605f56c97a 100644 --- a/src/include/execution/compiler/executable_query.h +++ b/src/include/execution/compiler/executable_query.h @@ -113,7 +113,7 @@ class ExecutableQuery { /** * @return The file. */ - ast::File *GetFile() { return file_; }; + ast::File *GetFile() { return file_; } /** @return The metadata of this module. */ const vm::ModuleMetadata &GetModuleMetadata() const; diff --git a/src/include/execution/compiler/expression/function_translator.h b/src/include/execution/compiler/expression/function_translator.h index f6e48b2099..a4f645a05f 100644 --- a/src/include/execution/compiler/expression/function_translator.h +++ b/src/include/execution/compiler/expression/function_translator.h @@ -1,12 +1,12 @@ #pragma once +#include +#include + #include "execution/compiler/expression/expression_translator.h" #include "execution/functions/function_context.h" #include "execution/util/region_containers.h" -#include -#include - namespace noisepage::parser { class FunctionExpression; } // namespace noisepage::parser diff --git a/src/include/execution/compiler/udf/udf_codegen.h b/src/include/execution/compiler/udf/udf_codegen.h index 20034a0d92..e50e4f667a 100644 --- a/src/include/execution/compiler/udf/udf_codegen.h +++ b/src/include/execution/compiler/udf/udf_codegen.h @@ -1,5 +1,8 @@ #pragma once +#include +#include + #include "execution/ast/udf/udf_ast_context.h" #include "execution/ast/udf/udf_ast_node_visitor.h" #include "execution/compiler/codegen.h" diff --git a/src/include/execution/sema/scope.h b/src/include/execution/sema/scope.h index b30a5659b9..d3b4404681 100644 --- a/src/include/execution/sema/scope.h +++ b/src/include/execution/sema/scope.h @@ -1,5 +1,8 @@ #pragma once +#include +#include + #include #include "execution/ast/identifier.h" diff --git a/src/include/parser/udf/udf_parser.h b/src/include/parser/udf/udf_parser.h index 5624f795cd..4a17b5ad8d 100644 --- a/src/include/parser/udf/udf_parser.h +++ b/src/include/parser/udf/udf_parser.h @@ -1,6 +1,8 @@ #pragma once #include +#include +#include #include #include "catalog/catalog_accessor.h" @@ -43,9 +45,10 @@ class PLpgSQLParser { std::unique_ptr ParseFor(const nlohmann::json &loop); std::unique_ptr ParseSQL(const nlohmann::json &sql_stmt); std::unique_ptr ParseDynamicSQL(const nlohmann::json &sql_stmt); - // Feed the expression (as a sql string) to our parser then transform the + + // Feed the expression (as a SQL string) to our parser then transform the // noisepage expression into ast node - std::unique_ptr ParseExprSQL(const std::string expr_sql_str); + std::unique_ptr ParseExprSQL(const std::string &expr_sql_str); std::unique_ptr ParseExpr(common::ManagedPointer); common::ManagedPointer udf_ast_context_; diff --git a/src/parser/udf/udf_parser.cpp b/src/parser/udf/udf_parser.cpp index 3c44620816..e87d5fb00e 100644 --- a/src/parser/udf/udf_parser.cpp +++ b/src/parser/udf/udf_parser.cpp @@ -11,9 +11,6 @@ namespace noisepage { namespace parser { namespace udf { -using namespace nlohmann; -using namespace execution::ast::udf; - /** * @brief The identifiers used as keys in the parse tree. */ @@ -49,10 +46,10 @@ static const std::string kName = "name"; static const std::string kPLpgSQL_row = "PLpgSQL_row"; static const std::string kPLpgSQL_stmt_dynexecute = "PLpgSQL_stmt_dynexecute"; -std::unique_ptr PLpgSQLParser::ParsePLpgSQL(std::vector &¶m_names, +std::unique_ptr PLpgSQLParser::ParsePLpgSQL(std::vector &¶m_names, std::vector &¶m_types, const std::string &func_body, - common::ManagedPointer ast_context) { + common::ManagedPointer ast_context) { auto result = pg_query_parse_plpgsql(func_body.c_str()); if (result.error) { pg_query_free_plpgsql_parse_result(result); @@ -64,7 +61,7 @@ std::unique_ptr PLpgSQLParser::ParsePLpgSQL(std::vector> ast_json; const auto function_list = ast_json[kFunctionList]; NOISEPAGE_ASSERT(function_list.is_array(), "Function list is not an array"); @@ -78,35 +75,33 @@ std::unique_ptr PLpgSQLParser::ParsePLpgSQL(std::vector(ParseFunction(function), std::move(param_names), std::move(param_types)); + std::make_unique(ParseFunction(function), std::move(param_names), std::move(param_types)); return function_ast; } -std::unique_ptr PLpgSQLParser::ParseFunction(const nlohmann::json &block) { +std::unique_ptr PLpgSQLParser::ParseFunction(const nlohmann::json &block) { const auto decl_list = block[kDatums]; const auto function_body = block[kAction][kPLpgSQL_stmt_block][kBody]; - std::vector> stmts; + std::vector> stmts{}; NOISEPAGE_ASSERT(decl_list.is_array(), "Declaration list is not an array"); - for (uint32_t i = 1; i < decl_list.size(); i++) { + for (std::size_t i = 1UL; i < decl_list.size(); i++) { stmts.push_back(ParseDecl(decl_list[i])); } stmts.push_back(ParseBlock(function_body)); - - std::unique_ptr seq_stmt_ast(new SeqStmtAST(std::move(stmts))); - return std::move(seq_stmt_ast); + return std::make_unique(std::move(stmts)); } -std::unique_ptr PLpgSQLParser::ParseBlock(const nlohmann::json &block) { +std::unique_ptr PLpgSQLParser::ParseBlock(const nlohmann::json &block) { // TODO(boweic): Support statements size other than 1 NOISEPAGE_ASSERT(block.is_array(), "Block isn't array"); if (block.size() == 0) { throw PARSER_EXCEPTION("PL/pgSQL parser : Empty block is not supported"); } - std::vector> stmts; + std::vector> stmts{}; for (uint32_t i = 0; i < block.size(); i++) { const auto stmt = block[i]; @@ -115,18 +110,16 @@ std::unique_ptr PLpgSQLParser::ParseBlock(const nlohmann::json &block) if (stmt_names.key() == kPLpgSQL_stmt_return) { auto expr = ParseExprSQL(stmt[kPLpgSQL_stmt_return][kExpr][kPLpgSQL_expr][kQuery].get()); // TODO(boweic): Handle return stmt w/o expression - std::unique_ptr ret_stmt_ast(new RetStmtAST(std::move(expr))); - stmts.push_back(std::move(ret_stmt_ast)); + stmts.push_back(std::make_unique(std::move(expr))); } else if (stmt_names.key() == kPLpgSQL_stmt_if) { stmts.push_back(ParseIf(stmt[kPLpgSQL_stmt_if])); } else if (stmt_names.key() == kPLpgSQL_stmt_assign) { // TODO[Siva]: Need to fix Assignment expression / statement const auto &var_name = udf_ast_context_->GetLocalVariableAtIndex(stmt[kPLpgSQL_stmt_assign][kVarno].get()); - std::unique_ptr lhs(new VariableExprAST(var_name)); + auto lhs = std::make_unique(var_name); auto rhs = ParseExprSQL(stmt[kPLpgSQL_stmt_assign][kExpr][kPLpgSQL_expr][kQuery].get()); - std::unique_ptr ass_expr_ast(new AssignStmtAST(std::move(lhs), std::move(rhs))); - stmts.push_back(std::move(ass_expr_ast)); + stmts.push_back(std::make_unique(std::move(lhs), std::move(rhs))); } else if (stmt_names.key() == kPLpgSQL_stmt_while) { stmts.push_back(ParseWhile(stmt[kPLpgSQL_stmt_while])); } else if (stmt_names.key() == kPLpgSQL_stmt_fors) { @@ -140,72 +133,72 @@ std::unique_ptr PLpgSQLParser::ParseBlock(const nlohmann::json &block) } } - return std::make_unique(std::move(stmts)); + return std::make_unique(std::move(stmts)); } -std::unique_ptr PLpgSQLParser::ParseDecl(const nlohmann::json &decl) { +std::unique_ptr PLpgSQLParser::ParseDecl(const nlohmann::json &decl) { const auto &decl_names = decl.items().begin(); if (decl_names.key() == kPLpgSQL_var) { auto var_name = decl[kPLpgSQL_var][kRefname].get(); udf_ast_context_->AddVariable(var_name); auto type = decl[kPLpgSQL_var][kDatatype][kPLpgSQL_type][kTypname].get(); - std::unique_ptr initial = nullptr; + std::unique_ptr initial = nullptr; if (decl[kPLpgSQL_var].find(kDefaultVal) != decl[kPLpgSQL_var].end()) { initial = ParseExprSQL(decl[kPLpgSQL_var][kDefaultVal][kPLpgSQL_expr][kQuery].get()); } type::TypeId temp_type{}; if (udf_ast_context_->GetVariableType(var_name, &temp_type)) { - return std::unique_ptr(new DeclStmtAST(var_name, temp_type, std::move(initial))); + return std::make_unique(var_name, temp_type, std::move(initial)); } if ((type.find("integer") != std::string::npos) || type.find("INTEGER") != std::string::npos) { udf_ast_context_->SetVariableType(var_name, type::TypeId::INTEGER); - return std::unique_ptr(new DeclStmtAST(var_name, type::TypeId::INTEGER, std::move(initial))); + return std::make_unique(var_name, type::TypeId::INTEGER, std::move(initial)); } else if (type == "double" || type.rfind("numeric") == 0) { udf_ast_context_->SetVariableType(var_name, type::TypeId::DECIMAL); - return std::unique_ptr(new DeclStmtAST(var_name, type::TypeId::DECIMAL, std::move(initial))); + return std::make_unique(var_name, type::TypeId::DECIMAL, std::move(initial)); } else if (type == "varchar") { udf_ast_context_->SetVariableType(var_name, type::TypeId::VARCHAR); - return std::unique_ptr(new DeclStmtAST(var_name, type::TypeId::VARCHAR, std::move(initial))); + return std::make_unique(var_name, type::TypeId::VARCHAR, std::move(initial)); } else if (type.find("date") != std::string::npos) { udf_ast_context_->SetVariableType(var_name, type::TypeId::DATE); - return std::unique_ptr(new DeclStmtAST(var_name, type::TypeId::DATE, std::move(initial))); + return std::make_unique(var_name, type::TypeId::DATE, std::move(initial)); } else if (type == "record") { udf_ast_context_->SetVariableType(var_name, type::TypeId::INVALID); - return std::unique_ptr(new DeclStmtAST(var_name, type::TypeId::INVALID, std::move(initial))); + return std::make_unique(var_name, type::TypeId::INVALID, std::move(initial)); } else { - NOISEPAGE_ASSERT(false, "Unsupported"); + NOISEPAGE_ASSERT(false, "Unsupported Type"); } } else if (decl_names.key() == kPLpgSQL_row) { auto var_name = decl[kPLpgSQL_row][kRefname].get(); NOISEPAGE_ASSERT(var_name == "*internal*", "Unexpected refname"); // TODO[Siva]: Support row types later udf_ast_context_->SetVariableType(var_name, type::TypeId::INVALID); - return std::unique_ptr(new DeclStmtAST(var_name, type::TypeId::INVALID, nullptr)); + return std::make_unique(var_name, type::TypeId::INVALID, nullptr); } // TODO[Siva]: need to handle other types like row, table etc; throw PARSER_EXCEPTION("Declaration type not supported"); } -std::unique_ptr PLpgSQLParser::ParseIf(const nlohmann::json &branch) { +std::unique_ptr PLpgSQLParser::ParseIf(const nlohmann::json &branch) { auto cond_expr = ParseExprSQL(branch[kCond][kPLpgSQL_expr][kQuery].get()); auto then_stmt = ParseBlock(branch[kThenBody]); - std::unique_ptr else_stmt = nullptr; + std::unique_ptr else_stmt = nullptr; if (branch.find(kElseBody) != branch.end()) { else_stmt = ParseBlock(branch[kElseBody]); } - return std::unique_ptr(new IfStmtAST(std::move(cond_expr), std::move(then_stmt), std::move(else_stmt))); + return std::make_unique(std::move(cond_expr), std::move(then_stmt), std::move(else_stmt)); } -std::unique_ptr PLpgSQLParser::ParseWhile(const nlohmann::json &loop) { +std::unique_ptr PLpgSQLParser::ParseWhile(const nlohmann::json &loop) { auto cond_expr = ParseExprSQL(loop[kCond][kPLpgSQL_expr][kQuery].get()); auto body_stmt = ParseBlock(loop[kBody]); - return std::unique_ptr(new WhileStmtAST(std::move(cond_expr), std::move(body_stmt))); + return std::make_unique(std::move(cond_expr), std::move(body_stmt)); } -std::unique_ptr PLpgSQLParser::ParseFor(const nlohmann::json &loop) { +std::unique_ptr PLpgSQLParser::ParseFor(const nlohmann::json &loop) { auto sql_query = loop[kQuery][kPLpgSQL_expr][kQuery].get(); auto parse_result = PostgresParser::BuildParseTree(sql_query.c_str()); if (parse_result == nullptr) { @@ -217,10 +210,10 @@ std::unique_ptr PLpgSQLParser::ParseFor(const nlohmann::json &loop) { for (auto var : var_array) { var_vec.push_back(var[kName].get()); } - return std::unique_ptr(new ForStmtAST(std::move(var_vec), std::move(parse_result), std::move(body_stmt))); + return std::make_unique(std::move(var_vec), std::move(parse_result), std::move(body_stmt)); } -std::unique_ptr PLpgSQLParser::ParseSQL(const nlohmann::json &sql_stmt) { +std::unique_ptr PLpgSQLParser::ParseSQL(const nlohmann::json &sql_stmt) { auto sql_query = sql_stmt[kSqlstmt][kPLpgSQL_expr][kQuery].get(); auto var_name = sql_stmt[kRow][kPLpgSQL_row][kFields][0][kName].get(); @@ -257,17 +250,17 @@ std::unique_ptr PLpgSQLParser::ParseSQL(const nlohmann::json &sql_stmt) udf_ast_context_->SetRecordType(var_name, std::move(elems)); } - return std::make_unique(std::move(parse_result), std::move(var_name), std::move(query_params)); + return std::make_unique(std::move(parse_result), std::move(var_name), std::move(query_params)); } -std::unique_ptr PLpgSQLParser::ParseDynamicSQL(const nlohmann::json &sql_stmt) { +std::unique_ptr PLpgSQLParser::ParseDynamicSQL(const nlohmann::json &sql_stmt) { auto sql_expr = ParseExprSQL(sql_stmt[kQuery][kPLpgSQL_expr][kQuery].get()); auto var_name = sql_stmt[kRow][kPLpgSQL_row][kFields][0][kName].get(); - return std::unique_ptr(new DynamicSQLStmtAST(std::move(sql_expr), std::move(var_name))); + return std::make_unique(std::move(sql_expr), std::move(var_name)); } -std::unique_ptr PLpgSQLParser::ParseExprSQL(const std::string expr_sql_str) { - auto stmt_list = PostgresParser::BuildParseTree(expr_sql_str.c_str()); +std::unique_ptr PLpgSQLParser::ParseExprSQL(const std::string& expr_sql_str) { + auto stmt_list = PostgresParser::BuildParseTree(expr_sql_str); if (stmt_list == nullptr) { return nullptr; } @@ -281,34 +274,33 @@ std::unique_ptr PLpgSQLParser::ParseExprSQL(const std::string expr_sql_ return PLpgSQLParser::ParseExpr(select_list[0]); } -std::unique_ptr PLpgSQLParser::ParseExpr(common::ManagedPointer expr) { +std::unique_ptr PLpgSQLParser::ParseExpr(common::ManagedPointer expr) { if (expr->GetExpressionType() == parser::ExpressionType::COLUMN_VALUE) { auto cve = expr.CastManagedPointerTo(); if (cve->GetTableName().empty()) { - return std::unique_ptr(new VariableExprAST(cve->GetColumnName())); + return std::make_unique(cve->GetColumnName()); } else { - auto vexpr = std::unique_ptr(new VariableExprAST(cve->GetTableName())); - return std::unique_ptr(new MemberExprAST(std::move(vexpr), cve->GetColumnName())); + auto vexpr = std::make_unique(cve->GetTableName()); + return std::make_unique(std::move(vexpr), cve->GetColumnName()); } } else if ((parser::ExpressionUtil::IsOperatorExpression(expr->GetExpressionType()) && expr->GetChildrenSize() == 2) || (parser::ExpressionUtil::IsComparisonExpression(expr->GetExpressionType()))) { - return std::unique_ptr( - new BinaryExprAST(expr->GetExpressionType(), ParseExpr(expr->GetChild(0)), ParseExpr(expr->GetChild(1)))); + return std::make_unique(expr->GetExpressionType(), ParseExpr(expr->GetChild(0)), ParseExpr(expr->GetChild(1))); } else if (expr->GetExpressionType() == parser::ExpressionType::FUNCTION) { auto func_expr = expr.CastManagedPointerTo(); - std::vector> args; + std::vector> args{}; auto num_args = func_expr->GetChildrenSize(); for (size_t idx = 0; idx < num_args; ++idx) { args.push_back(ParseExpr(func_expr->GetChild(idx))); } - return std::unique_ptr(new CallExprAST(func_expr->GetFuncName(), std::move(args))); + return std::make_unique(func_expr->GetFuncName(), std::move(args)); } else if (expr->GetExpressionType() == parser::ExpressionType::VALUE_CONSTANT) { - return std::unique_ptr(new ValueExprAST(expr->Copy())); + return std::make_unique(expr->Copy()); } else if (expr->GetExpressionType() == parser::ExpressionType::OPERATOR_IS_NOT_NULL) { - return std::unique_ptr(new IsNullExprAST(false, ParseExpr(expr->GetChild(0)))); + return std::make_unique(false, ParseExpr(expr->GetChild(0))); } else if (expr->GetExpressionType() == parser::ExpressionType::OPERATOR_IS_NULL) { - return std::unique_ptr(new IsNullExprAST(true, ParseExpr(expr->GetChild(0)))); + return std::make_unique(true, ParseExpr(expr->GetChild(0))); } throw PARSER_EXCEPTION("PL/pgSQL parser : Expression type not supported"); } From cb909ff78da7aab24054d61ee6edb968397a878d Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Fri, 18 Jun 2021 08:01:14 -0400 Subject: [PATCH 039/139] lint passing --- src/include/execution/ast/udf/udf_ast_nodes.h | 4 +- .../execution/functions/function_context.h | 1 + src/include/execution/sema/scope.h | 5 +- src/include/parser/udf/udf_parser.h | 2 +- src/parser/udf/udf_parser.cpp | 105 +++++++++--------- 5 files changed, 62 insertions(+), 55 deletions(-) diff --git a/src/include/execution/ast/udf/udf_ast_nodes.h b/src/include/execution/ast/udf/udf_ast_nodes.h index 7624ae36cb..21244487dd 100644 --- a/src/include/execution/ast/udf/udf_ast_nodes.h +++ b/src/include/execution/ast/udf/udf_ast_nodes.h @@ -1,6 +1,8 @@ #pragma once #include +#include +#include #include #include @@ -112,7 +114,7 @@ class SeqStmtAST : public StmtAST { public: std::vector> stmts; - explicit SeqStmtAST(std::vector>&& stmts) : stmts(std::move(stmts)) {} + explicit SeqStmtAST(std::vector> &&stmts) : stmts(std::move(stmts)) {} void Accept(ASTNodeVisitor *visitor) override { visitor->Visit(this); } }; diff --git a/src/include/execution/functions/function_context.h b/src/include/execution/functions/function_context.h index 2928366a7b..6f1877e744 100644 --- a/src/include/execution/functions/function_context.h +++ b/src/include/execution/functions/function_context.h @@ -1,5 +1,6 @@ #pragma once +#include #include #include #include diff --git a/src/include/execution/sema/scope.h b/src/include/execution/sema/scope.h index d3b4404681..5be9e6676f 100644 --- a/src/include/execution/sema/scope.h +++ b/src/include/execution/sema/scope.h @@ -1,10 +1,9 @@ #pragma once +#include #include #include -#include - #include "execution/ast/identifier.h" #include "execution/util/execution_common.h" @@ -93,7 +92,7 @@ class Scope { // The scope kind. Kind scope_kind_; // The mapping of identifiers to their types. - llvm::DenseMap decls_; + std::unordered_map decls_; }; } // namespace sema diff --git a/src/include/parser/udf/udf_parser.h b/src/include/parser/udf/udf_parser.h index 4a17b5ad8d..cf617036a3 100644 --- a/src/include/parser/udf/udf_parser.h +++ b/src/include/parser/udf/udf_parser.h @@ -45,7 +45,7 @@ class PLpgSQLParser { std::unique_ptr ParseFor(const nlohmann::json &loop); std::unique_ptr ParseSQL(const nlohmann::json &sql_stmt); std::unique_ptr ParseDynamicSQL(const nlohmann::json &sql_stmt); - + // Feed the expression (as a SQL string) to our parser then transform the // noisepage expression into ast node std::unique_ptr ParseExprSQL(const std::string &expr_sql_str); diff --git a/src/parser/udf/udf_parser.cpp b/src/parser/udf/udf_parser.cpp index e87d5fb00e..168fbc41a1 100644 --- a/src/parser/udf/udf_parser.cpp +++ b/src/parser/udf/udf_parser.cpp @@ -14,49 +14,49 @@ namespace udf { /** * @brief The identifiers used as keys in the parse tree. */ -static const std::string kFunctionList = "FunctionList"; -static const std::string kDatums = "datums"; -static const std::string kPLpgSQL_var = "PLpgSQL_var"; -static const std::string kRefname = "refname"; -static const std::string kDatatype = "datatype"; -static const std::string kDefaultVal = "default_val"; -static const std::string kPLpgSQL_type = "PLpgSQL_type"; -static const std::string kTypname = "typname"; -static const std::string kAction = "action"; -static const std::string kPLpgSQL_function = "PLpgSQL_function"; -static const std::string kBody = "body"; -static const std::string kPLpgSQL_stmt_block = "PLpgSQL_stmt_block"; -static const std::string kPLpgSQL_stmt_return = "PLpgSQL_stmt_return"; -static const std::string kPLpgSQL_stmt_if = "PLpgSQL_stmt_if"; -static const std::string kPLpgSQL_stmt_while = "PLpgSQL_stmt_while"; -static const std::string kPLpgSQL_stmt_fors = "PLpgSQL_stmt_fors"; -static const std::string kCond = "cond"; -static const std::string kThenBody = "then_body"; -static const std::string kElseBody = "else_body"; -static const std::string kExpr = "expr"; -static const std::string kQuery = "query"; -static const std::string kPLpgSQL_expr = "PLpgSQL_expr"; -static const std::string kPLpgSQL_stmt_assign = "PLpgSQL_stmt_assign"; -static const std::string kVarno = "varno"; -static const std::string kPLpgSQL_stmt_execsql = "PLpgSQL_stmt_execsql"; -static const std::string kSqlstmt = "sqlstmt"; -static const std::string kRow = "row"; -static const std::string kFields = "fields"; -static const std::string kName = "name"; -static const std::string kPLpgSQL_row = "PLpgSQL_row"; -static const std::string kPLpgSQL_stmt_dynexecute = "PLpgSQL_stmt_dynexecute"; - -std::unique_ptr PLpgSQLParser::ParsePLpgSQL(std::vector &¶m_names, - std::vector &¶m_types, - const std::string &func_body, - common::ManagedPointer ast_context) { +static constexpr const char kFunctionList[] = "FunctionList"; +static constexpr const char kDatums[] = "datums"; +static constexpr const char kPLpgSQL_var[] = "PLpgSQL_var"; +static constexpr const char kRefname[] = "refname"; +static constexpr const char kDatatype[] = "datatype"; +static constexpr const char kDefaultVal[] = "default_val"; +static constexpr const char kPLpgSQL_type[] = "PLpgSQL_type"; +static constexpr const char kTypname[] = "typname"; +static constexpr const char kAction[] = "action"; +static constexpr const char kPLpgSQL_function[] = "PLpgSQL_function"; +static constexpr const char kBody[] = "body"; +static constexpr const char kPLpgSQL_stmt_block[] = "PLpgSQL_stmt_block"; +static constexpr const char kPLpgSQL_stmt_return[] = "PLpgSQL_stmt_return"; +static constexpr const char kPLpgSQL_stmt_if[] = "PLpgSQL_stmt_if"; +static constexpr const char kPLpgSQL_stmt_while[] = "PLpgSQL_stmt_while"; +static constexpr const char kPLpgSQL_stmt_fors[] = "PLpgSQL_stmt_fors"; +static constexpr const char kCond[] = "cond"; +static constexpr const char kThenBody[] = "then_body"; +static constexpr const char kElseBody[] = "else_body"; +static constexpr const char kExpr[] = "expr"; +static constexpr const char kQuery[] = "query"; +static constexpr const char kPLpgSQL_expr[] = "PLpgSQL_expr"; +static constexpr const char kPLpgSQL_stmt_assign[] = "PLpgSQL_stmt_assign"; +static constexpr const char kVarno[] = "varno"; +static constexpr const char kPLpgSQL_stmt_execsql[] = "PLpgSQL_stmt_execsql"; +static constexpr const char kSqlstmt[] = "sqlstmt"; +static constexpr const char kRow[] = "row"; +static constexpr const char kFields[] = "fields"; +static constexpr const char kName[] = "name"; +static constexpr const char kPLpgSQL_row[] = "PLpgSQL_row"; +static constexpr const char kPLpgSQL_stmt_dynexecute[] = "PLpgSQL_stmt_dynexecute"; + +std::unique_ptr PLpgSQLParser::ParsePLpgSQL( + std::vector &¶m_names, std::vector &¶m_types, const std::string &func_body, + common::ManagedPointer ast_context) { auto result = pg_query_parse_plpgsql(func_body.c_str()); if (result.error) { pg_query_free_plpgsql_parse_result(result); throw PARSER_EXCEPTION("PL/pgSQL parsing error"); } // The result is a list, we need to wrap it - const auto ast_json_str = "{ \"" + kFunctionList + "\" : " + std::string{result.plpgsql_funcs} + " }"; + const auto ast_json_str = + "{ \"" + std::string{kFunctionList} + "\" : " + std::string{result.plpgsql_funcs} + " }"; // NOLINT pg_query_free_plpgsql_parse_result(result); @@ -74,8 +74,8 @@ std::unique_ptr PLpgSQLParser::ParsePLpgSQL(st udf_ast_context_->SetVariableType(udf_name, param_types[i++]); } const auto function = function_list[0][kPLpgSQL_function]; - auto function_ast = - std::make_unique(ParseFunction(function), std::move(param_names), std::move(param_types)); + auto function_ast = std::make_unique( + ParseFunction(function), std::move(param_names), std::move(param_types)); return function_ast; } @@ -109,12 +109,12 @@ std::unique_ptr PLpgSQLParser::ParseBlock(const nl if (stmt_names.key() == kPLpgSQL_stmt_return) { auto expr = ParseExprSQL(stmt[kPLpgSQL_stmt_return][kExpr][kPLpgSQL_expr][kQuery].get()); - // TODO(boweic): Handle return stmt w/o expression + // TODO(Kyle): Handle return stmt w/o expression stmts.push_back(std::make_unique(std::move(expr))); } else if (stmt_names.key() == kPLpgSQL_stmt_if) { stmts.push_back(ParseIf(stmt[kPLpgSQL_stmt_if])); } else if (stmt_names.key() == kPLpgSQL_stmt_assign) { - // TODO[Siva]: Need to fix Assignment expression / statement + // TODO(Kyle): Need to fix Assignment expression / statement const auto &var_name = udf_ast_context_->GetLocalVariableAtIndex(stmt[kPLpgSQL_stmt_assign][kVarno].get()); auto lhs = std::make_unique(var_name); @@ -174,11 +174,11 @@ std::unique_ptr PLpgSQLParser::ParseDecl(const nlo } else if (decl_names.key() == kPLpgSQL_row) { auto var_name = decl[kPLpgSQL_row][kRefname].get(); NOISEPAGE_ASSERT(var_name == "*internal*", "Unexpected refname"); - // TODO[Siva]: Support row types later + // TODO(Kyle): Support row types later udf_ast_context_->SetVariableType(var_name, type::TypeId::INVALID); return std::make_unique(var_name, type::TypeId::INVALID, nullptr); } - // TODO[Siva]: need to handle other types like row, table etc; + // TODO(Kyle): Need to handle other types like row, table etc; throw PARSER_EXCEPTION("Declaration type not supported"); } @@ -189,7 +189,8 @@ std::unique_ptr PLpgSQLParser::ParseIf(const nlohm if (branch.find(kElseBody) != branch.end()) { else_stmt = ParseBlock(branch[kElseBody]); } - return std::make_unique(std::move(cond_expr), std::move(then_stmt), std::move(else_stmt)); + return std::make_unique(std::move(cond_expr), std::move(then_stmt), + std::move(else_stmt)); } std::unique_ptr PLpgSQLParser::ParseWhile(const nlohmann::json &loop) { @@ -210,7 +211,8 @@ std::unique_ptr PLpgSQLParser::ParseFor(const nloh for (auto var : var_array) { var_vec.push_back(var[kName].get()); } - return std::make_unique(std::move(var_vec), std::move(parse_result), std::move(body_stmt)); + return std::make_unique(std::move(var_vec), std::move(parse_result), + std::move(body_stmt)); } std::unique_ptr PLpgSQLParser::ParseSQL(const nlohmann::json &sql_stmt) { @@ -250,7 +252,8 @@ std::unique_ptr PLpgSQLParser::ParseSQL(const nloh udf_ast_context_->SetRecordType(var_name, std::move(elems)); } - return std::make_unique(std::move(parse_result), std::move(var_name), std::move(query_params)); + return std::make_unique(std::move(parse_result), std::move(var_name), + std::move(query_params)); } std::unique_ptr PLpgSQLParser::ParseDynamicSQL(const nlohmann::json &sql_stmt) { @@ -259,7 +262,7 @@ std::unique_ptr PLpgSQLParser::ParseDynamicSQL(con return std::make_unique(std::move(sql_expr), std::move(var_name)); } -std::unique_ptr PLpgSQLParser::ParseExprSQL(const std::string& expr_sql_str) { +std::unique_ptr PLpgSQLParser::ParseExprSQL(const std::string &expr_sql_str) { auto stmt_list = PostgresParser::BuildParseTree(expr_sql_str); if (stmt_list == nullptr) { return nullptr; @@ -274,7 +277,8 @@ std::unique_ptr PLpgSQLParser::ParseExprSQL(const return PLpgSQLParser::ParseExpr(select_list[0]); } -std::unique_ptr PLpgSQLParser::ParseExpr(common::ManagedPointer expr) { +std::unique_ptr PLpgSQLParser::ParseExpr( + common::ManagedPointer expr) { if (expr->GetExpressionType() == parser::ExpressionType::COLUMN_VALUE) { auto cve = expr.CastManagedPointerTo(); if (cve->GetTableName().empty()) { @@ -286,7 +290,8 @@ std::unique_ptr PLpgSQLParser::ParseExpr(common::M } else if ((parser::ExpressionUtil::IsOperatorExpression(expr->GetExpressionType()) && expr->GetChildrenSize() == 2) || (parser::ExpressionUtil::IsComparisonExpression(expr->GetExpressionType()))) { - return std::make_unique(expr->GetExpressionType(), ParseExpr(expr->GetChild(0)), ParseExpr(expr->GetChild(1))); + return std::make_unique(expr->GetExpressionType(), ParseExpr(expr->GetChild(0)), + ParseExpr(expr->GetChild(1))); } else if (expr->GetExpressionType() == parser::ExpressionType::FUNCTION) { auto func_expr = expr.CastManagedPointerTo(); std::vector> args{}; @@ -307,4 +312,4 @@ std::unique_ptr PLpgSQLParser::ParseExpr(common::M } // namespace udf } // namespace parser -} // namespace noisepage \ No newline at end of file +} // namespace noisepage From 6c20c68993a266a60a6406bb0ca4285c4381d290 Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Fri, 18 Jun 2021 08:24:10 -0400 Subject: [PATCH 040/139] now fighting clang-tidy, almost there but going to need a big refactor on udf_ast_nodes now --- src/execution/ast/ast_clone.cpp | 3 +- .../expression/function_translator.cpp | 7 +- .../compiler/operator/output_translator.cpp | 14 ++-- src/execution/sema/sema_builtin.cpp | 75 +++++++++---------- src/execution/sema/sema_expr.cpp | 2 +- src/execution/sema/sema_type.cpp | 2 +- src/include/execution/ast/ast.h | 2 +- .../execution/ast/udf/udf_ast_node_visitor.h | 2 +- src/include/execution/ast/udf/udf_ast_nodes.h | 4 +- .../execution/exec/execution_context.h | 2 +- .../execution/vm/bytecode_function_info.h | 2 +- .../expression/constant_value_expression.cpp | 2 +- src/traffic_cop/traffic_cop.cpp | 30 ++++---- 13 files changed, 71 insertions(+), 76 deletions(-) diff --git a/src/execution/ast/ast_clone.cpp b/src/execution/ast/ast_clone.cpp index 3b27769807..8413d18491 100644 --- a/src/execution/ast/ast_clone.cpp +++ b/src/execution/ast/ast_clone.cpp @@ -121,9 +121,8 @@ AstNode *AstCloneImpl::VisitIfStmt(IfStmt *node) { AstNode *AstCloneImpl::VisitReturnStmt(ReturnStmt *node) { if (node->Ret() == nullptr) { return factory_->NewReturnStmt(node->Position(), nullptr); - } else { - return factory_->NewReturnStmt(node->Position(), reinterpret_cast(Visit(node->Ret()))); } + return factory_->NewReturnStmt(node->Position(), reinterpret_cast(Visit(node->Ret()))); } AstNode *AstCloneImpl::VisitCallExpr(CallExpr *node) { diff --git a/src/execution/compiler/expression/function_translator.cpp b/src/execution/compiler/expression/function_translator.cpp index ec1fb819f4..1d9f045565 100644 --- a/src/execution/compiler/expression/function_translator.cpp +++ b/src/execution/compiler/expression/function_translator.cpp @@ -38,11 +38,8 @@ ast::Expr *FunctionTranslator::DeriveValue(WorkContext *ctx, const ColumnValuePr if (!func_context->IsBuiltin()) { auto ident_expr = main_fn_; - std::vector args; - for (auto &expr : params) { - args.emplace_back(expr); - } - return GetCodeGen()->Call(ident_expr, std::move(args)); + std::vector args{params.cbegin(), params.cbegin()}; + return GetCodeGen()->Call(ident_expr, args); } return codegen->CallBuiltin(func_context->GetBuiltin(), params); diff --git a/src/execution/compiler/operator/output_translator.cpp b/src/execution/compiler/operator/output_translator.cpp index 52d3a264fa..c02fcf52bf 100644 --- a/src/execution/compiler/operator/output_translator.cpp +++ b/src/execution/compiler/operator/output_translator.cpp @@ -29,7 +29,7 @@ OutputTranslator::OutputTranslator(const planner::AbstractPlanNode &plan, Compil } void OutputTranslator::InitializePipelineState(const Pipeline &pipeline, FunctionBuilder *function) const { - if (GetCompilationContext()->GetOutputCallback()) { + if (GetCompilationContext()->GetOutputCallback() != nullptr) { return; } @@ -41,7 +41,7 @@ void OutputTranslator::InitializePipelineState(const Pipeline &pipeline, Functio } void OutputTranslator::TearDownPipelineState(const Pipeline &pipeline, FunctionBuilder *function) const { - if (GetCompilationContext()->GetOutputCallback()) { + if (GetCompilationContext()->GetOutputCallback() != nullptr) { return; } @@ -55,7 +55,7 @@ void OutputTranslator::PerformPipelineWork(noisepage::execution::compiler::WorkC auto out_buffer = output_buffer_.Get(GetCodeGen()); ast::Expr *cast_call; auto callback = GetCompilationContext()->GetOutputCallback(); - if (callback) { + if (callback != nullptr) { auto output = GetCodeGen()->MakeFreshIdentifier("output_row"); auto *row_alloc = GetCodeGen()->DeclareVarNoInit(output, GetCodeGen()->MakeExpr(output_struct_)); function->Append(row_alloc); @@ -76,13 +76,13 @@ void OutputTranslator::PerformPipelineWork(noisepage::execution::compiler::WorkC ast::Expr *lhs = GetCodeGen()->AccessStructMember(GetCodeGen()->MakeExpr(output_var_), attr_name); ast::Expr *rhs = child_translator->GetOutput(context, attr_idx); function->Append(GetCodeGen()->Assign(lhs, rhs)); - if (callback) { + if (callback != nullptr) { callback_args.push_back(lhs); } } - if (callback) { - function->Append(GetCodeGen()->Call(callback->As()->GetName(), std::move(callback_args))); + if (callback != nullptr) { + function->Append(GetCodeGen()->Call(callback->As()->GetName(), callback_args)); } CounterAdd(function, num_output_, 1); @@ -105,7 +105,7 @@ void OutputTranslator::EndParallelPipelineWork(const Pipeline &pipeline, Functio } void OutputTranslator::FinishPipelineWork(const Pipeline &pipeline, FunctionBuilder *function) const { - if (GetCompilationContext()->GetOutputCallback()) { + if (GetCompilationContext()->GetOutputCallback() != nullptr) { return; } diff --git a/src/execution/sema/sema_builtin.cpp b/src/execution/sema/sema_builtin.cpp index 11d93e9d8f..54286cbacb 100644 --- a/src/execution/sema/sema_builtin.cpp +++ b/src/execution/sema/sema_builtin.cpp @@ -3148,46 +3148,45 @@ void Sema::CheckBuiltinParamCall(ast::CallExpr *call, ast::Builtin builtin) { // Return sql type call->SetType(ast::BuiltinType::Get(GetContext(), sql_type)); return; - } else { - if (builtin > ast::Builtin::FinishNewParams) { - ast::BuiltinType::Kind add_sql_type; - switch (builtin) { - case ast::Builtin::AddParamBool: { - add_sql_type = ast::BuiltinType::Boolean; - break; - } - case ast::Builtin::AddParamTinyInt: - case ast::Builtin::AddParamSmallInt: - case ast::Builtin::AddParamInt: - case ast::Builtin::AddParamBigInt: { - add_sql_type = ast::BuiltinType::Integer; - break; - } - case ast::Builtin::AddParamReal: - case ast::Builtin::AddParamDouble: { - add_sql_type = ast::BuiltinType::Real; - break; - } - case ast::Builtin::AddParamDate: { - add_sql_type = ast::BuiltinType::Date; - break; - } - case ast::Builtin::AddParamTimestamp: { - add_sql_type = ast::BuiltinType::Timestamp; - break; - } - case ast::Builtin::AddParamString: { - add_sql_type = ast::BuiltinType::StringVal; - break; - } - default: { - UNREACHABLE("Undefined parameter call!!"); - } + } + if (builtin > ast::Builtin::FinishNewParams) { + ast::BuiltinType::Kind add_sql_type; + switch (builtin) { + case ast::Builtin::AddParamBool: { + add_sql_type = ast::BuiltinType::Boolean; + break; } - if (call->Arguments()[1]->GetType() != GetBuiltinType(add_sql_type)) { - ReportIncorrectCallArg(call, 1, GetBuiltinType(add_sql_type)); - return; + case ast::Builtin::AddParamTinyInt: + case ast::Builtin::AddParamSmallInt: + case ast::Builtin::AddParamInt: + case ast::Builtin::AddParamBigInt: { + add_sql_type = ast::BuiltinType::Integer; + break; + } + case ast::Builtin::AddParamReal: + case ast::Builtin::AddParamDouble: { + add_sql_type = ast::BuiltinType::Real; + break; + } + case ast::Builtin::AddParamDate: { + add_sql_type = ast::BuiltinType::Date; + break; } + case ast::Builtin::AddParamTimestamp: { + add_sql_type = ast::BuiltinType::Timestamp; + break; + } + case ast::Builtin::AddParamString: { + add_sql_type = ast::BuiltinType::StringVal; + break; + } + default: { + UNREACHABLE("Undefined parameter call!!"); + } + } + if (call->Arguments()[1]->GetType() != GetBuiltinType(add_sql_type)) { + ReportIncorrectCallArg(call, 1, GetBuiltinType(add_sql_type)); + return; } } call->SetType(ast::BuiltinType::Get(GetContext(), ast::BuiltinType::Nil)); diff --git a/src/execution/sema/sema_expr.cpp b/src/execution/sema/sema_expr.cpp index ab0b347088..fc8c68945a 100644 --- a/src/execution/sema/sema_expr.cpp +++ b/src/execution/sema/sema_expr.cpp @@ -173,7 +173,7 @@ void Sema::VisitLambdaExpr(ast::LambdaExpr *node) { for (auto expr : node->GetCaptureIdents()) { auto ident = expr->As(); Resolve(ident); - if (ident->GetType()->SafeAs()) { + if (ident->GetType()->SafeAs() != nullptr) { auto type_repr = factory->NewPointerType( SourcePosition(), factory->NewIdentifierExpr( diff --git a/src/execution/sema/sema_type.cpp b/src/execution/sema/sema_type.cpp index e467ce31a6..49b00f3972 100644 --- a/src/execution/sema/sema_type.cpp +++ b/src/execution/sema/sema_type.cpp @@ -91,7 +91,7 @@ void Sema::VisitMapTypeRepr(ast::MapTypeRepr *node) { } void Sema::VisitLambdaTypeRepr(ast::LambdaTypeRepr *node) { - ast::FunctionType *fn_type = Resolve(node->FunctionType())->SafeAs(); + auto *fn_type = Resolve(node->FunctionType())->SafeAs(); if (fn_type == nullptr) { return; } diff --git a/src/include/execution/ast/ast.h b/src/include/execution/ast/ast.h index a50021042d..8969dd2f1f 100644 --- a/src/include/execution/ast/ast.h +++ b/src/include/execution/ast/ast.h @@ -719,7 +719,7 @@ class BreakStmt : public Stmt { * @param node The node to check. * @return `true` if the node is a break statement, `false` otherwise. */ - static bool classof(const AstNode *node) { return node->GetKind() == Kind::BreakStmt; } + static bool classof(const AstNode *node) { return node->GetKind() == Kind::BreakStmt; } // NOLINT }; /** diff --git a/src/include/execution/ast/udf/udf_ast_node_visitor.h b/src/include/execution/ast/udf/udf_ast_node_visitor.h index 512230995b..3fd2b739d3 100644 --- a/src/include/execution/ast/udf/udf_ast_node_visitor.h +++ b/src/include/execution/ast/udf/udf_ast_node_visitor.h @@ -27,7 +27,7 @@ class FunctionAST; class ASTNodeVisitor { public: - virtual ~ASTNodeVisitor() {} + virtual ~ASTNodeVisitor() = default; virtual void Visit(AbstractAST *ast) = 0; virtual void Visit(StmtAST *ast) = 0; virtual void Visit(ExprAST *ast) = 0; diff --git a/src/include/execution/ast/udf/udf_ast_nodes.h b/src/include/execution/ast/udf/udf_ast_nodes.h index 21244487dd..10ae1812ca 100644 --- a/src/include/execution/ast/udf/udf_ast_nodes.h +++ b/src/include/execution/ast/udf/udf_ast_nodes.h @@ -29,7 +29,7 @@ class AbstractAST { // StmtAST - Base class for all statement nodes. class StmtAST : public AbstractAST { public: - virtual ~StmtAST() = default; + ~StmtAST() override = default; void Accept(ASTNodeVisitor *visitor) override { visitor->Visit(this); } }; @@ -37,7 +37,7 @@ class StmtAST : public AbstractAST { // ExprAST - Base class for all expression nodes. class ExprAST : public StmtAST { public: - virtual ~ExprAST() = default; + ~ExprAST() override = default; void Accept(ASTNodeVisitor *visitor) override { visitor->Visit(this); } }; diff --git a/src/include/execution/exec/execution_context.h b/src/include/execution/exec/execution_context.h index 7b327e5d8c..b22a374220 100644 --- a/src/include/execution/exec/execution_context.h +++ b/src/include/execution/exec/execution_context.h @@ -185,7 +185,7 @@ class EXPORT ExecutionContext { void InitializeParallelOUFeatureVector(selfdriving::ExecOUFeatureVector *ouvec, pipeline_id_t pipeline_id); /** Initialize the UDF parameter stack. */ - void StartParams() { udf_param_stack_.push_back({}); } + void StartParams() { udf_param_stack_.emplace_back(); } /** Remove an element from the UDF parameter stack. */ void PopParams() { udf_param_stack_.pop_back(); } diff --git a/src/include/execution/vm/bytecode_function_info.h b/src/include/execution/vm/bytecode_function_info.h index 9eeb1f2f22..93eae8bbec 100644 --- a/src/include/execution/vm/bytecode_function_info.h +++ b/src/include/execution/vm/bytecode_function_info.h @@ -297,7 +297,7 @@ class FunctionInfo { * the body of the lambda. This action is evaluated when we * later visit the declaration for the function itself. */ - void DeferAction(const std::function action) { actions_.push_back(action); } + void DeferAction(std::function &&action) { actions_.push_back(std::move(action)); } /** * @return `true` if the TBC function represented by this object diff --git a/src/parser/expression/constant_value_expression.cpp b/src/parser/expression/constant_value_expression.cpp index a36b4eb6d6..717aed049d 100644 --- a/src/parser/expression/constant_value_expression.cpp +++ b/src/parser/expression/constant_value_expression.cpp @@ -68,7 +68,7 @@ T ConstantValueExpression::Peek() const { } // NOLINTNEXTLINE: bugprone-suspicious-semicolon: seems like a false positive because of constexpr if constexpr (std::is_same_v || std::is_same_v || std::is_same_v || - std::is_same_v) { + std::is_same_v) { // NOLINT return static_cast(GetInteger().val_); } // NOLINTNEXTLINE: bugprone-suspicious-semicolon: seems like a false positive because of constexpr diff --git a/src/traffic_cop/traffic_cop.cpp b/src/traffic_cop/traffic_cop.cpp index 8737027b9c..febd0a7fcc 100644 --- a/src/traffic_cop/traffic_cop.cpp +++ b/src/traffic_cop/traffic_cop.cpp @@ -220,7 +220,7 @@ TrafficCopResult TrafficCop::ExecuteSetStatement(common::ManagedPointerWriteRowDescription(cols, {network::FieldFormat::text}); out->WriteDataRow(reinterpret_cast(&result), cols, {network::FieldFormat::text}); - return {ResultType::COMPLETE, 0u}; + return {ResultType::COMPLETE, 0U}; } TrafficCopResult TrafficCop::ExecuteCreateStatement( @@ -269,35 +269,35 @@ TrafficCopResult TrafficCop::ExecuteCreateStatement( if (execution::sql::DDLExecutors::CreateTableExecutor( physical_plan.CastManagedPointerTo(), connection_ctx->Accessor(), connection_ctx->GetDatabaseOid())) { - return {ResultType::COMPLETE, 0u}; + return {ResultType::COMPLETE, 0U}; } break; } case network::QueryType::QUERY_CREATE_DB: { if (execution::sql::DDLExecutors::CreateDatabaseExecutor( physical_plan.CastManagedPointerTo(), connection_ctx->Accessor())) { - return {ResultType::COMPLETE, 0u}; + return {ResultType::COMPLETE, 0U}; } break; } case network::QueryType::QUERY_CREATE_INDEX: { if (execution::sql::DDLExecutors::CreateIndexExecutor( physical_plan.CastManagedPointerTo(), connection_ctx->Accessor())) { - return {ResultType::COMPLETE, 0u}; + return {ResultType::COMPLETE, 0U}; } break; } case network::QueryType::QUERY_CREATE_SCHEMA: { if (execution::sql::DDLExecutors::CreateNamespaceExecutor( physical_plan.CastManagedPointerTo(), connection_ctx->Accessor())) { - return {ResultType::COMPLETE, 0u}; + return {ResultType::COMPLETE, 0U}; } break; } case network::QueryType::QUERY_CREATE_FUNCTION: { if (execution::sql::DDLExecutors::CreateFunctionExecutor( physical_plan.CastManagedPointerTo(), connection_ctx->Accessor())) { - return {ResultType::COMPLETE, 0}; + return {ResultType::COMPLETE, 0U}; } break; } @@ -329,7 +329,7 @@ TrafficCopResult TrafficCop::ExecuteDropStatement( case network::QueryType::QUERY_DROP_TABLE: { if (execution::sql::DDLExecutors::DropTableExecutor( physical_plan.CastManagedPointerTo(), connection_ctx->Accessor())) { - return {ResultType::COMPLETE, 0u}; + return {ResultType::COMPLETE, 0U}; } break; } @@ -337,21 +337,21 @@ TrafficCopResult TrafficCop::ExecuteDropStatement( if (execution::sql::DDLExecutors::DropDatabaseExecutor( physical_plan.CastManagedPointerTo(), connection_ctx->Accessor(), connection_ctx->GetDatabaseOid())) { - return {ResultType::COMPLETE, 0u}; + return {ResultType::COMPLETE, 0U}; } break; } case network::QueryType::QUERY_DROP_INDEX: { if (execution::sql::DDLExecutors::DropIndexExecutor( physical_plan.CastManagedPointerTo(), connection_ctx->Accessor())) { - return {ResultType::COMPLETE, 0u}; + return {ResultType::COMPLETE, 0U}; } break; } case network::QueryType::QUERY_DROP_SCHEMA: { if (execution::sql::DDLExecutors::DropNamespaceExecutor( physical_plan.CastManagedPointerTo(), connection_ctx->Accessor())) { - return {ResultType::COMPLETE, 0u}; + return {ResultType::COMPLETE, 0U}; } break; } @@ -411,7 +411,7 @@ TrafficCopResult TrafficCop::ExecuteExplainStatement( out->WriteDataRow(reinterpret_cast(&plan_string_val), output_columns, {network::FieldFormat::text}); - return {ResultType::COMPLETE, 0u}; + return {ResultType::COMPLETE, 0U}; } std::variant, common::ErrorData> TrafficCop::ParseQuery( @@ -468,7 +468,7 @@ TrafficCopResult TrafficCop::BindQuery( return {ResultType::ERROR, error}; } - return {ResultType::COMPLETE, 0u}; + return {ResultType::COMPLETE, 0U}; } TrafficCopResult TrafficCop::CodegenPhysicalPlan( @@ -495,7 +495,7 @@ TrafficCopResult TrafficCop::CodegenPhysicalPlan( if (portal->GetStatement()->GetExecutableQuery() != nullptr && use_query_cache_) { // We've already codegen'd this, move on... - return {ResultType::COMPLETE, 0u}; + return {ResultType::COMPLETE, 0U}; } // TODO(WAN): see #1047 @@ -531,7 +531,7 @@ TrafficCopResult TrafficCop::CodegenPhysicalPlan( portal->GetStatement()->SetExecutableQuery(std::move(exec_query)); - return {ResultType::COMPLETE, 0u}; + return {ResultType::COMPLETE, 0U}; } TrafficCopResult TrafficCop::RunExecutableQuery(const common::ManagedPointer connection_ctx, From 3189debdb3bc9d4f5f4fe6484bb665ba9badce94 Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Fri, 18 Jun 2021 11:34:22 -0400 Subject: [PATCH 041/139] big refactor of UDF AST in progress --- src/execution/compiler/udf/udf_codegen.cpp | 105 ++- src/include/execution/ast/udf/udf_ast_nodes.h | 633 +++++++++++++++--- 2 files changed, 587 insertions(+), 151 deletions(-) diff --git a/src/execution/compiler/udf/udf_codegen.cpp b/src/execution/compiler/udf/udf_codegen.cpp index afa81ce5a1..b066e1da58 100644 --- a/src/execution/compiler/udf/udf_codegen.cpp +++ b/src/execution/compiler/udf/udf_codegen.cpp @@ -84,12 +84,11 @@ execution::ast::File *UDFCodegen::Finish() { } void UDFCodegen::Visit(ast::udf::CallExprAST *ast) { - auto &args = ast->args; - std::vector args_ast; - std::vector args_ast_region_vec; - std::vector arg_types; + std::vector args_ast{}; + std::vector args_ast_region_vec{}; + std::vector arg_types{}; - for (auto &arg : args) { + for (auto &arg : ast->Args()) { arg->Accept(this); args_ast.push_back(dst_); args_ast_region_vec.push_back(dst_); @@ -98,14 +97,14 @@ void UDFCodegen::Visit(ast::udf::CallExprAST *ast) { NOISEPAGE_ASSERT(builtin->IsSqlValueType(), "Param is not SQL Value Type"); arg_types.push_back(GetCatalogTypeOidFromSQLType(builtin->GetKind())); } - auto proc_oid = accessor_->GetProcOid(ast->callee, arg_types); + auto proc_oid = accessor_->GetProcOid(ast->Callee(), arg_types); NOISEPAGE_ASSERT(proc_oid != catalog::INVALID_PROC_OID, "Invalid call"); auto context = accessor_->GetProcCtxPtr(proc_oid); if (context->IsBuiltin()) { fb_->Append(codegen_->MakeStmt(codegen_->CallBuiltin(context->GetBuiltin(), std::move(args_ast)))); } else { - auto it = str_to_ident_.find(ast->callee); + auto it = str_to_ident_.find(ast->Callee()); execution::ast::Identifier ident_expr; if (it != str_to_ident_.end()) { ident_expr = it->second; @@ -121,8 +120,6 @@ void UDFCodegen::Visit(ast::udf::CallExprAST *ast) { } fb_->Append(codegen_->MakeStmt(codegen_->Call(ident_expr, args_ast_region_vec))); } - - // fb_->Append(codegen_->Call) } void UDFCodegen::Visit(ast::udf::StmtAST *ast) { UNREACHABLE("Not implemented"); } @@ -130,17 +127,17 @@ void UDFCodegen::Visit(ast::udf::StmtAST *ast) { UNREACHABLE("Not implemented"); void UDFCodegen::Visit(ast::udf::ExprAST *ast) { UNREACHABLE("Not implemented"); } void UDFCodegen::Visit(ast::udf::DeclStmtAST *ast) { - if (ast->name == "*internal*") { + if (ast->Name() == "*internal*") { return; } - execution::ast::Identifier ident = codegen_->MakeFreshIdentifier(ast->name); - str_to_ident_.emplace(ast->name, ident); + execution::ast::Identifier ident = codegen_->MakeFreshIdentifier(ast->Name()); + str_to_ident_.emplace(ast->Name(), ident); auto prev_type = current_type_; execution::ast::Expr *tpl_type = nullptr; - if (ast->type == type::TypeId::INVALID) { + if (ast->Type() == type::TypeId::INVALID) { // record type execution::util::RegionVector fields(codegen_->GetAstContext()->GetRegion()); - for (auto p : udf_ast_context_->GetRecordType(ast->name)) { + for (auto p : udf_ast_context_->GetRecordType(ast->Name())) { fields.push_back(codegen_->MakeField(codegen_->MakeIdentifier(p.first), codegen_->TplType(execution::sql::GetTypeId(p.second)))); } @@ -148,12 +145,11 @@ void UDFCodegen::Visit(ast::udf::DeclStmtAST *ast) { aux_decls_.push_back(record_decl); tpl_type = record_decl->TypeRepr(); } else { - tpl_type = codegen_->TplType(execution::sql::GetTypeId(ast->type)); + tpl_type = codegen_->TplType(execution::sql::GetTypeId(ast->Type())); } - current_type_ = ast->type; - if (ast->initial != nullptr) { - // Visit(ast->initial.get()); - ast->initial->Accept(this); + current_type_ = ast->Type(); + if (ast->Initial() != nullptr) { + ast->Initial()->Accept(this); fb_->Append(codegen_->DeclareVar(ident, tpl_type, dst_)); } else { fb_->Append(codegen_->DeclareVarNoInit(ident, tpl_type)); @@ -162,21 +158,21 @@ void UDFCodegen::Visit(ast::udf::DeclStmtAST *ast) { } void UDFCodegen::Visit(ast::udf::FunctionAST *ast) { - for (size_t i = 0; i < ast->param_types_.size(); i++) { + for (size_t i = 0; i < ast->ParameterTypes().size(); i++) { // auto param_type = codegen_->TplType(ast->param_types_[i]); - str_to_ident_.emplace(ast->param_names_[i], codegen_->MakeFreshIdentifier("udf")); + str_to_ident_.emplace(ast->ParameterNames().at(i), codegen_->MakeFreshIdentifier("udf")); } - ast->body.get()->Accept(this); + ast->Body()->Accept(this); } void UDFCodegen::Visit(ast::udf::VariableExprAST *ast) { - auto it = str_to_ident_.find(ast->name); + auto it = str_to_ident_.find(ast->Name()); NOISEPAGE_ASSERT(it != str_to_ident_.end(), "variable not declared"); dst_ = codegen_->MakeExpr(it->second); } void UDFCodegen::Visit(ast::udf::ValueExprAST *ast) { - auto val = common::ManagedPointer(ast->value_).CastManagedPointerTo(); + auto val = common::ManagedPointer(ast->Value()).CastManagedPointerTo(); if (val->IsNull()) { dst_ = codegen_->ConstNull(current_type_); return; @@ -211,29 +207,24 @@ void UDFCodegen::Visit(ast::udf::ValueExprAST *ast) { void UDFCodegen::Visit(ast::udf::AssignStmtAST *ast) { type::TypeId left_type = type::TypeId::INVALID; - udf_ast_context_->GetVariableType(ast->lhs->name, &left_type); + udf_ast_context_->GetVariableType(ast->Destination()->Name(), &left_type); current_type_ = left_type; - reinterpret_cast(ast->rhs.get())->Accept(this); + reinterpret_cast(ast->Source())->Accept(this); auto rhs_expr = dst_; - auto it = str_to_ident_.find(ast->lhs->name); + auto it = str_to_ident_.find(ast->Destination()->Name()); NOISEPAGE_ASSERT(it != str_to_ident_.end(), "Variable not found"); auto left_codegen_ident = it->second; auto *left_expr = codegen_->MakeExpr(left_codegen_ident); - - // auto right_type = rhs_expr->GetType()->GetTypeId(); - - // if (left_type == type::TypeId::VARCHAR) { fb_->Append(codegen_->Assign(left_expr, rhs_expr)); - // } } void UDFCodegen::Visit(ast::udf::BinaryExprAST *ast) { execution::parsing::Token::Type op_token; bool compare = false; - switch (ast->op) { + switch (ast->Op()) { case noisepage::parser::ExpressionType::OPERATOR_DIVIDE: op_token = execution::parsing::Token::Type::SLASH; break; @@ -279,10 +270,10 @@ void UDFCodegen::Visit(ast::udf::BinaryExprAST *ast) { // TODO(tanujnay112): figure out concatenation operation from expressions? UNREACHABLE("Unsupported expression"); } - ast->lhs->Accept(this); + ast->Left()->Accept(this); auto lhs_expr = dst_; - ast->rhs->Accept(this); + ast->Right()->Accept(this); auto rhs_expr = dst_; if (compare) { dst_ = codegen_->Compare(op_token, lhs_expr, rhs_expr); @@ -292,38 +283,38 @@ void UDFCodegen::Visit(ast::udf::BinaryExprAST *ast) { } void UDFCodegen::Visit(ast::udf::IfStmtAST *ast) { - ast->cond_expr->Accept(this); + ast->Condition()->Accept(this); auto cond = dst_; If branch(fb_, cond); - ast->then_stmt->Accept(this); - if (ast->else_stmt != nullptr) { + ast->Then()->Accept(this); + if (ast->Else() != nullptr) { branch.Else(); - ast->else_stmt->Accept(this); + ast->Else()->Accept(this); } branch.EndIf(); } void UDFCodegen::Visit(ast::udf::IsNullExprAST *ast) { - ast->child_->Accept(this); + ast->Child()->Accept(this); auto chld = dst_; dst_ = codegen_->CallBuiltin(execution::ast::Builtin::IsValNull, {chld}); - if (!ast->is_null_check_) { + if (!ast->IsNullCheck()) { dst_ = codegen_->UnaryOp(execution::parsing::Token::Type::BANG, dst_); } } void UDFCodegen::Visit(ast::udf::SeqStmtAST *ast) { - for (auto &stmt : ast->stmts) { + for (auto &stmt : ast->Statements()) { stmt->Accept(this); } } void UDFCodegen::Visit(ast::udf::WhileStmtAST *ast) { - ast->cond_expr->Accept(this); + ast->Condition()->Accept(this); auto cond = dst_; Loop loop(fb_, cond); - ast->body_stmt->Accept(this); + ast->Body()->Accept(this); loop.EndLoop(); } @@ -331,7 +322,7 @@ void UDFCodegen::Visit(ast::udf::ForStmtAST *ast) { // Once we encounter a For-statement we know we need an execution context needs_exec_ctx_ = true; - const auto query = common::ManagedPointer(ast->query_); + const auto query = common::ManagedPointer(ast->Query()); auto exec_ctx = fb_->GetParameterByPosition(0); binder::BindNodeVisitor visitor{common::ManagedPointer(accessor_), db_oid_}; @@ -352,7 +343,7 @@ void UDFCodegen::Visit(ast::udf::ForStmtAST *ast) { exec_ctx->As()->Name(), codegen_->PointerType(codegen_->BuiltinType(execution::ast::BuiltinType::Kind::ExecutionContext)))); std::size_t i{0}; - for (const auto &var : ast->vars_) { + for (const auto &var : ast->Variables()) { var_idents.push_back(str_to_ident_.find(var)->second); auto var_ident = var_idents.back(); NOISEPAGE_ASSERT(plan->GetOutputSchema()->GetColumns().size() == 1, "Can't support non scalars yet!"); @@ -375,7 +366,7 @@ void UDFCodegen::Visit(ast::udf::ForStmtAST *ast) { } auto prev_fb = fb_; fb_ = &fn; - ast->body_stmt_->Accept(this); + ast->Body()->Accept(this); fb_ = prev_fb; } @@ -473,7 +464,7 @@ void UDFCodegen::Visit(ast::udf::ForStmtAST *ast) { } void UDFCodegen::Visit(ast::udf::RetStmtAST *ast) { - ast->expr->Accept(reinterpret_cast(this)); + ast->Return()->Accept(reinterpret_cast(this)); auto ret_expr = dst_; fb_->Append(codegen_->Return(ret_expr)); } @@ -483,11 +474,11 @@ void UDFCodegen::Visit(ast::udf::SQLStmtAST *ast) { // we know we need an execution context needs_exec_ctx_ = true; auto exec_ctx = fb_->GetParameterByPosition(0); - const auto query = common::ManagedPointer(ast->query); + const auto query = common::ManagedPointer(ast->Query()); binder::BindNodeVisitor visitor(common::ManagedPointer(accessor_), db_oid_); - auto &query_params = ast->udf_params; + const auto &query_params = ast->Parameters(); // NOTE(Kyle): Assumptions: // - This is a valid optimizer timeout @@ -514,12 +505,12 @@ void UDFCodegen::Visit(ast::udf::SQLStmtAST *ast) { std::vector assignees{}; execution::util::RegionVector captures{codegen_->GetAstContext()->GetRegion()}; for (auto &col : cols) { - execution::ast::Expr *capture_var = codegen_->MakeExpr(str_to_ident_.find(ast->var_name)->second); + execution::ast::Expr *capture_var = codegen_->MakeExpr(str_to_ident_.find(ast->Name())->second); type::TypeId udf_type{}; - udf_ast_context_->GetVariableType(ast->var_name, &udf_type); + udf_ast_context_->GetVariableType(ast->Name(), &udf_type); if (udf_type == type::TypeId::INVALID) { // Record type - auto &struct_vars = udf_ast_context_->GetRecordType(ast->var_name); + auto &struct_vars = udf_ast_context_->GetRecordType(ast->Name()); if (captures.empty()) { captures.push_back(capture_var); } @@ -570,7 +561,7 @@ void UDFCodegen::Visit(ast::udf::SQLStmtAST *ast) { // Set its execution context to whatever exec context was passed in here fb_->Append(codegen_->CallBuiltin(execution::ast::Builtin::StartNewParams, {exec_ctx})); - std::vector>::iterator> sorted_vec{}; + std::vector>::const_iterator> sorted_vec{}; for (auto it = query_params.begin(); it != query_params.end(); it++) { sorted_vec.push_back(it); } @@ -630,7 +621,7 @@ void UDFCodegen::Visit(ast::udf::SQLStmtAST *ast) { codegen_->AccessStructMember(codegen_->MakeExpr(query_state), codegen_->MakeIdentifier("execCtx")), exec_ctx)); for (auto &col : cols) { - execution::ast::Expr *capture_var = codegen_->MakeExpr(str_to_ident_.find(ast->var_name)->second); + execution::ast::Expr *capture_var = codegen_->MakeExpr(str_to_ident_.find(ast->Name())->second); auto lhs = capture_var; if (cols.size() > 1) { // Record struct type @@ -653,9 +644,9 @@ void UDFCodegen::Visit(ast::udf::SQLStmtAST *ast) { } void UDFCodegen::Visit(ast::udf::MemberExprAST *ast) { - ast->object->Accept(reinterpret_cast(this)); + ast->Object()->Accept(reinterpret_cast(this)); auto object = dst_; - dst_ = codegen_->AccessStructMember(object, codegen_->MakeIdentifier(ast->field)); + dst_ = codegen_->AccessStructMember(object, codegen_->MakeIdentifier(ast->FieldName())); } } // namespace udf diff --git a/src/include/execution/ast/udf/udf_ast_nodes.h b/src/include/execution/ast/udf/udf_ast_nodes.h index 10ae1812ca..03304c49dc 100644 --- a/src/include/execution/ast/udf/udf_ast_nodes.h +++ b/src/include/execution/ast/udf/udf_ast_nodes.h @@ -18,217 +18,662 @@ namespace execution { namespace ast { namespace udf { -// AbstractAST - Base class for all AST nodes. +/** + * The AbstractAST class serves as a base class for all AST nodes. + */ class AbstractAST { public: + /** + * Destroy the AST node. + */ virtual ~AbstractAST() = default; + /** + * AST visitor pattern. + * @param visitor The visitor + */ virtual void Accept(ASTNodeVisitor *visitor) { visitor->Visit(this); } }; -// StmtAST - Base class for all statement nodes. +/** + * The StmtAST class serves as the base class for all statement nodes. + */ class StmtAST : public AbstractAST { public: + /** + * Destroy the AST node. + */ ~StmtAST() override = default; + /** + * AST visitor pattern. + * @param visitor The visitor + */ void Accept(ASTNodeVisitor *visitor) override { visitor->Visit(this); } }; -// ExprAST - Base class for all expression nodes. +/** + * The ExprAST class serves as the base class for all expression nodes. + */ class ExprAST : public StmtAST { public: + /** + * Destroy the AST node. + */ ~ExprAST() override = default; + /** + * AST visitor pattern. + * @param visitor The visitor + */ void Accept(ASTNodeVisitor *visitor) override { visitor->Visit(this); } }; -// DoubleExprAST - Expression class for numeric literals like "1.1". +/** + * The ValueExprAST class represents literal values. + */ class ValueExprAST : public ExprAST { public: - std::unique_ptr value_; + /** + * Construct a new ValueExprAST instance. + * @param value The AbstractExpression that represents the value + */ + explicit ValueExprAST(std::unique_ptr &&value) : value_(std::move(value)) {} + + /** + * AST visitor pattern. + * @param visitor The visitor + */ + void Accept(ASTNodeVisitor *visitor) override { visitor->Visit(this); } - explicit ValueExprAST(std::unique_ptr value) : value_(std::move(value)) {} + /** @return A mutable pointer to the value expression */ + parser::AbstractExpression *Value() { return value_.get(); } - void Accept(ASTNodeVisitor *visitor) override { visitor->Visit(this); } + /** @return An immutable pointer to the value expression */ + const parser::AbstractExpression *Value() const { return value_.get(); } + + private: + /** The expression that represents the value */ + std::unique_ptr value_; }; +/** + * The IsNullExprAST class represents an expression that performs a NULL check. + */ class IsNullExprAST : public ExprAST { public: - bool is_null_check_; - std::unique_ptr child_; + /** + * Construct a new IsNullExprAST instance. + * @param is_null_check The NULL check flag + * @param child The child expression + */ + IsNullExprAST(bool is_null_check, std::unique_ptr &&child) + : is_null_check_{is_null_check}, child_{std::move(child)} {} + + /** + * AST visitor pattern. + * @param visitor The visitor + */ + void Accept(ASTNodeVisitor *visitor) override { visitor->Visit(this); } - IsNullExprAST(bool is_null_check, std::unique_ptr child) - : is_null_check_(is_null_check), child_(std::move(child)) {} + /** @return `true` if the NULL check is performed, `false` otherwise */ + bool IsNullCheck() const { return is_null_check_; } - void Accept(ASTNodeVisitor *visitor) override { visitor->Visit(this); } + /** @return The child expression */ + ExprAST *Child() { return child_.get(); } + + /** @return The child expression */ + const ExprAST *Child() const { return child_.get(); } + + private: + /** The NULL check flag */ + bool is_null_check_; + + /** The child expression */ + std::unique_ptr child_; }; -// VariableExprAST - Expression class for referencing a variable, like "a". +/** + * The VariableExprAST class represents an expression that references a variable. + */ class VariableExprAST : public ExprAST { public: - std::string name; + /** + * Construct a new VariableExprAST instance. + * @param name The name of the variable + */ + explicit VariableExprAST(std::string name) : name_{std::move(name)} {} + + /** + * AST visitor pattern. + * @param visitor The visitor + */ + void Accept(ASTNodeVisitor *visitor) override { visitor->Visit(this); } - explicit VariableExprAST(const std::string &name) : name(name) {} + /** @return The name of the variable */ + const std::string &Name() const { return name_; } - void Accept(ASTNodeVisitor *visitor) override { visitor->Visit(this); } + private: + /** The name of the variable */ + const std::string name_; }; -// VariableExprAST - Expression class for referencing a variable, like "a". +/** + * The MemberExprAST class represents a structure member expression. + */ class MemberExprAST : public ExprAST { public: - std::unique_ptr object; - std::string field; - + /** + * Construct a new MemberExprAST instance. + * @param object The structure + * @param field The name of the field in the structure + */ MemberExprAST(std::unique_ptr &&object, std::string field) - : object(std::move(object)), field(field) {} + : object_{std::move(object)}, field_(std::move(field)) {} + /** + * AST visitor pattern. + * @param visitor The visitor + */ void Accept(ASTNodeVisitor *visitor) override { visitor->Visit(this); } + + /** @return The object */ + VariableExprAST *Object() { return object_.get(); } + + /** @return The object */ + const VariableExprAST *Object() const { return object_.get(); } + + /** @return The name of the field */ + const std::string &FieldName() const { return field_; } + + private: + /** The expression for the object */ + std::unique_ptr object_; + + /** The identifier for the field in the object */ + std::string field_; }; -// BinaryExprAST - Expression class for a binary operator. +/** + * The BinaryExprAST class represents a generic binary expression. + */ class BinaryExprAST : public ExprAST { public: - parser::ExpressionType op; - std::unique_ptr lhs, rhs; + /** + * Construct a new BinaryExprAST instance. + * @param op The expression type for the operation + * @param lhs The expression on the left-hande side of the operation + * @param rhs The expression on the right-hand side of the operation + */ + BinaryExprAST(parser::ExpressionType op, std::unique_ptr &&lhs, std::unique_ptr &&rhs) + : op_{op}, lhs_{std::move(lhs)}, rhs_{std::move(rhs)} {} + + /** + * AST visitor pattern. + * @param visitor The visitor + */ + void Accept(ASTNodeVisitor *visitor) override { visitor->Visit(this); } - BinaryExprAST(parser::ExpressionType op, std::unique_ptr lhs, std::unique_ptr rhs) - : op(op), lhs(std::move(lhs)), rhs(std::move(rhs)) {} + /** @return The expression type for the operation */ + parser::ExpressionType Op() const { return op_; } - void Accept(ASTNodeVisitor *visitor) override { visitor->Visit(this); } + /** @return A mutable pointer to the left expression */ + ExprAST *Left() { return lhs_.get(); } + + /** @return An immutable pointer to the left expression */ + const ExprAST *Left() const { return lhs_.get(); } + + /** @return A mutable pointer to the right expression */ + ExprAST *Right() { return rhs_.get(); } + + /** @return An immutable pointer to the right expression */ + const ExprAST *Right() const { return rhs_.get(); } + + private: + /** The expression type for the operation */ + parser::ExpressionType op_; + + /** The expression on the left-hand side of the operation */ + std::unique_ptr lhs_; + + /** The expression on the right-hand side of the operation */ + std::unique_ptr rhs_; }; -// CallExprAST - Expression class for function calls. +/** + * The CallExprAST class represents a function call expression. + */ class CallExprAST : public ExprAST { public: - std::string callee; - std::vector> args; + /** + * Construct a new CallExprAST instance. + * @param callee The name of the called function + * @param args The arguments to the function call + */ + CallExprAST(std::string callee, std::vector> &&args) + : callee_{std::move(callee)}, args_{std::move(args)} {} + + /** + * AST visitor pattern. + * @param visitor The visitor + */ + void Accept(ASTNodeVisitor *visitor) override { visitor->Visit(this); } - CallExprAST(const std::string &callee, std::vector> args) - : callee(callee), args(std::move(args)) {} + /** @return The name of the called function */ + const std::string &Callee() const { return callee_; } - void Accept(ASTNodeVisitor *visitor) override { visitor->Visit(this); } + /** @return A mutable reference to the function call arguments */ + std::vector> &Args() { return args_; } + + /** @return An immutable reference to the function call arguments */ + const std::vector> &Args() const { return args_; } + + private: + /** The name of the called function */ + std::string callee_; + + /** The arguments to the function call */ + std::vector> args_; }; -// SeqStmtAST - Statement class for sequence of statements +/** + * The SeqStmtAST class represents a sequence of statements. + */ class SeqStmtAST : public StmtAST { public: - std::vector> stmts; + /** + * Construct a new SeqStmtAST instance. + * @param statements The collection of statements in the sequence + */ + explicit SeqStmtAST(std::vector> &&statements) : statements_(std::move(statements)) {} + + /** + * AST visitor pattern. + * @param visitor The visitor + */ + void Accept(ASTNodeVisitor *visitor) override { visitor->Visit(this); } - explicit SeqStmtAST(std::vector> &&stmts) : stmts(std::move(stmts)) {} + /** @return A mutable reference to the statements in the sequence */ + std::vector> &Statements() { return statements_; } - void Accept(ASTNodeVisitor *visitor) override { visitor->Visit(this); } + /** @return An immutable reference to the statements in the sequence */ + const std::vector> &Statements() const { return statements_; } + + private: + /** The collection of statements in the sequence */ + std::vector> statements_; }; // DeclStmtAST - Statement class for sequence of statements +/** + * The DeclStmtAST class represents a declaration statement. + */ class DeclStmtAST : public StmtAST { public: - std::string name; - type::TypeId type; - std::unique_ptr initial; + /** + * Construct a new DeclStmtAST instance. + * @param name The name of the variable that is declared + * @param type The type of the declared variable + * @param initial The initial value in the declaration + */ + DeclStmtAST(std::string name, type::TypeId type, std::unique_ptr &&initial) + : name_{std::move(name)}, type_(type), initial_{std::move(initial)} {} + + /** + * AST visitor pattern. + * @param visitor The visitor + */ + void Accept(ASTNodeVisitor *visitor) override { visitor->Visit(this); }; - DeclStmtAST(std::string name, type::TypeId type, std::unique_ptr initial) - : name(std::move(name)), type(std::move(type)), initial(std::move(initial)) {} + /** @return The name of the declared variable */ + const std::string &Name() const { return name_; } - void Accept(ASTNodeVisitor *visitor) override { visitor->Visit(this); }; + /** @return The type of the declared variable */ + type::TypeId Type() const { return type_; } + + /** @return A mutable pointer to the initial value expression */ + ExprAST *Initial() { return initial_.get(); } + + /** @return An immutable pointer to the initial value expression */ + const ExprAST *Initial() const { return initial_.get(); } + + private: + /** The name of the variable declared in the statement */ + std::string name_; + + /** The type of the declared variable */ + type::TypeId type_; + + /** The initial value of the declaration */ + std::unique_ptr initial_; }; -// IfStmtAST - Statement class for if/then/else. +/** + * The IfStmtAST class represents an IF/THEN/ELSE construct. + */ class IfStmtAST : public StmtAST { public: - std::unique_ptr cond_expr; - std::unique_ptr then_stmt, else_stmt; - + /** + * Construct a new IfStmtAST instance. + * @param cond_expr The conditional expression + * @param then_stmt The `then` statement + * @param else_stmt The `else` statement + */ + IfStmtAST(std::unique_ptr &&cond_expr, std::unique_ptr &&then_stmt, + std::unique_ptr &&else_stmt) + : cond_expr_{std::move(cond_expr)}, then_stmt_{std::move(then_stmt)}, else_stmt_{std::move(else_stmt)} {} + + /** + * AST visitor pattern. + * @param visitor The visitor + */ void Accept(ASTNodeVisitor *visitor) override { visitor->Visit(this); }; - IfStmtAST(std::unique_ptr cond_expr, std::unique_ptr then_stmt, std::unique_ptr else_stmt) - : cond_expr(std::move(cond_expr)), then_stmt(std::move(then_stmt)), else_stmt(std::move(else_stmt)) {} + /** @return The conditional expression */ + ExprAST *Condition() { return cond_expr_.get(); } + + /** @return The conditional expression */ + const ExprAST *Condition() const { return cond_expr_.get(); } + + /** @return The `then` statement */ + StmtAST *Then() { return then_stmt_.get(); } + + /** @return The `then` statement */ + const StmtAST *Then() const { return then_stmt_.get(); } + + /** @return The `else` statement */ + StmtAST *Else() { return else_stmt_.get(); } + + /** @return The `else` statement */ + const StmtAST *Else() const { return else_stmt_.get(); } + + private: + /** The conditional expression */ + std::unique_ptr cond_expr_; + + /** The `then` statement */ + std::unique_ptr then_stmt_; + + /** The `else` statement */ + std::unique_ptr else_stmt_; }; +/** + * The ForStmtAST class represents a `for`-loop construct. + */ class ForStmtAST : public StmtAST { public: - std::vector vars_; - std::unique_ptr query_; - std::unique_ptr body_stmt_; - + /** + * Construct a new ForStmtAST instance. + * @param variables The collection of variables in the loop + * @param query The associated query + * @param body The body of the loop + */ + ForStmtAST(std::vector &&variables, std::unique_ptr &&query, + std::unique_ptr body) + : variables_{std::move(variables)}, query_{std::move(query)}, body_{std::move(body)} {} + + /** + * AST visitor pattern. + * @param visitor The visitor + */ void Accept(ASTNodeVisitor *visitor) override { visitor->Visit(this); }; - ForStmtAST(std::vector &&vars_vec, std::unique_ptr query, - std::unique_ptr body_stmt) - : vars_(std::move(vars_vec)), query_(std::move(query)), body_stmt_(std::move(body_stmt)) {} + /** @return The collection of loop variables */ + const std::vector &Variables() const { return variables_; } + + /** @return The associated query */ + parser::ParseResult *Query() { return query_.get(); } + + /** @return The associated query */ + const parser::ParseResult *Query() const { return query_.get(); } + + /** @return The loop body statement */ + StmtAST *Body() { return body_.get(); } + + /** @return The loop body statement */ + const StmtAST *Body() const { return body_.get(); } + + private: + /** The collection of loop variables */ + std::vector variables_; + + /** The associated query */ + std::unique_ptr query_; + + /** The loop body statement */ + std::unique_ptr body_; }; -// WhileAST - Statement class for while loop +/** + * The WhileStmtAST represents a `while`-loop construct. + */ class WhileStmtAST : public StmtAST { public: - std::unique_ptr cond_expr; - std::unique_ptr body_stmt; - + /** + * Construct a new WhileStmtAST instance. + * @param condition The loop condition + * @param body The loop body statement + */ + WhileStmtAST(std::unique_ptr &&condition, std::unique_ptr &&body) + : condition_{std::move(condition)}, body_{std::move(body)} {} + + /** + * AST visitor pattern. + * @param visitor The visitor + */ void Accept(ASTNodeVisitor *visitor) override { visitor->Visit(this); } - WhileStmtAST(std::unique_ptr cond_expr, std::unique_ptr body_stmt) - : cond_expr(std::move(cond_expr)), body_stmt(std::move(body_stmt)) {} + /** @return The loop condition */ + ExprAST *Condition() { return condition_.get(); } + + /** @return The loop condition */ + const ExprAST *Condition() const { return condition_.get(); } + + /** @return The loop body statement */ + StmtAST *Body() { return body_.get(); } + + /** @return The loop body statement */ + const StmtAST *Body() const { return body_.get(); } + + private: + /** The loop condition */ + std::unique_ptr condition_; + + /** The loop body statement */ + std::unique_ptr body_; }; -// RetStmtAST - Statement class for sequence of statements +/** + * The RetStmtAST class represents a `return` statement. + */ class RetStmtAST : public StmtAST { public: - std::unique_ptr expr; + /** + * Construct a new RetStmtAST instance. + * @param ret_expr The `return` expression + */ + explicit RetStmtAST(std::unique_ptr &&ret_expr) : ret_expr_{std::move(ret_expr)} {} + + /** + * AST visitor pattern. + * @param visitor The visitor + */ + void Accept(ASTNodeVisitor *visitor) override { visitor->Visit(this); } - explicit RetStmtAST(std::unique_ptr expr) : expr(std::move(expr)) {} + /** @return The `return` expression */ + ExprAST *Return() { return ret_expr_.get(); } - void Accept(ASTNodeVisitor *visitor) override { visitor->Visit(this); } + /** @return The `return` expression */ + const ExprAST *Return() const { return ret_expr_.get(); } + + private: + /** The `return` expression */ + std::unique_ptr ret_expr_; }; -// AssignStmtAST - Expression class for a binary operator. +/** + * The AssignStmtAST class represents an assignment statement. + */ class AssignStmtAST : public ExprAST { public: - std::unique_ptr lhs; - std::unique_ptr rhs; + /** + * Construct a new AssignStmtAST instance. + * @param dst The variable that represents the destination of the assignment + * @param src The expression that represents the source of the assignment + */ + AssignStmtAST(std::unique_ptr &&dst, std::unique_ptr &&src) + : dst_{std::move(dst)}, src_{std::move(src)} {} + + /** + * AST visitor pattern. + * @param visitor The visitor + */ + void Accept(ASTNodeVisitor *visitor) override { visitor->Visit(this); } - AssignStmtAST(std::unique_ptr lhs, std::unique_ptr rhs) - : lhs(std::move(lhs)), rhs(std::move(rhs)) {} + /** @return The destination variable of the assignment */ + VariableExprAST *Destination() { return dst_.get(); } - void Accept(ASTNodeVisitor *visitor) override { visitor->Visit(this); } + /** @return The destination variable of the assignment */ + const VariableExprAST *Destination() const { return dst_.get(); } + + /** @return The source expression of the assignment */ + ExprAST *Source() { return src_.get(); } + + /** @return The source expression of the assignment */ + const ExprAST *Source() const { return src_.get(); } + + private: + /** The destination of the assignment */ + std::unique_ptr dst_; + + /** The source of the assignment */ + std::unique_ptr src_; }; -// SQLStmtAST - Expression class for a SQL Statement. +/** + * The SQLStmtAST class represents a SQL statement. + */ class SQLStmtAST : public StmtAST { public: - std::unique_ptr query; - std::string var_name; - std::unordered_map> udf_params; + /** + * Construct a new SQLStmtAST instance. + * @param query The result of parsing the SQL query + * @param name The name of the variable to which results of the query are bound + * @param parameters The parameters to the query + */ + SQLStmtAST(std::unique_ptr &&query, std::string name, + std::unordered_map> &¶meters) + : query_{std::move(query)}, name_{std::move(name)}, parameters_(std::move(parameters)) {} + + /** + * AST visitor pattern. + * @param visitor The visitor + */ + void Accept(ASTNodeVisitor *visitor) override { visitor->Visit(this); } - SQLStmtAST(std::unique_ptr query, std::string var_name, - std::unordered_map> &&udf_params) - : query(std::move(query)), var_name(std::move(var_name)), udf_params(std::move(udf_params)) {} + /** @return The result of parsing the SQL query */ + parser::ParseResult *Query() { return query_.get(); } - void Accept(ASTNodeVisitor *visitor) override { visitor->Visit(this); } + /** @return The result of parsing the SQL query */ + const parser::ParseResult *Query() const { return query_.get(); } + + /** @return The variable name to which results are bound */ + const std::string &Name() const { return name_; } + + /** @return The parameters to the query */ + const std::unordered_map> &Parameters() const { return parameters_; } + + private: + /** The result of parsing the SQL query */ + std::unique_ptr query_; + + /** The variable name to which results of the query are bound */ + std::string name_; + + /** The parameters to the query */ + std::unordered_map> parameters_; }; -// DynamicSQLStmtAST - Expression class for a SQL Statement. +/** + * The DynamicSQLStmtAST class represents a dynamic SQL statement. + */ class DynamicSQLStmtAST : public StmtAST { public: - std::unique_ptr query; - std::string var_name; + /** + * Construct a new DynamicSQLStmtAST instance. + * @param query The expression that represents the query + * @param name The name of the variable to which results are bound + */ + DynamicSQLStmtAST(std::unique_ptr &&query, std::string var_name) + : query_{std::move(query)}, name_{std::move(name)} {} + + /** + * AST visitor pattern. + * @param visitor The visitor + */ + void Accept(ASTNodeVisitor *visitor) override { visitor->Visit(this); } - DynamicSQLStmtAST(std::unique_ptr query, std::string var_name) - : query(std::move(query)), var_name(std::move(var_name)) {} + /** @return The expression that represents the query */ + const ExprAST *Query() const { return query_.get(); } - void Accept(ASTNodeVisitor *visitor) override { visitor->Visit(this); } + /** @return The name of the variable to which results are bound */ + const std::string &Name() const { return name_; } + + private: + /** The expression that represents the query */ + std::unique_ptr query_; + + /** The name of the variable to which results are bound */ + std::string name_; }; -// FunctionAST - This class represents a function definition itself. +/** + * The FunctionAST class represents a function definition. + */ class FunctionAST : public AbstractAST { public: - std::unique_ptr body; - std::vector param_names_; - std::vector param_types_; + /** + * Construct a new FunctionAST instance. + * @param body The body of the function + * @param parameter_names The names of the parameters to the function + * @param parameter_type The types of the parameters to the function + */ + FunctionAST(std::unique_ptr &&body, std::vector &¶meter_names, + std::vector &¶meter_types) + : body_{std::move(body)}, + parameter_names_{std::move(parameter_names)}, + parameter_types_{std::move(parameter_types)} { + NOISEPAGE_ASSERT(parameter_names_.size() == parameter_types_.size(), "Parameter Name and Type Mismatch"); + } + + /** + * AST visitor pattern. + * @param visitor The visitor + */ + void Accept(ASTNodeVisitor *visitor) override { visitor->Visit(this); } - FunctionAST(std::unique_ptr body, std::vector &¶m_names, - std::vector &¶m_types) - : body(std::move(body)), param_names_(std::move(param_names)), param_types_(std::move(param_types)) {} + /** @return The function body */ + StmtAST *Body() { return body_.get(); } - void Accept(ASTNodeVisitor *visitor) override { visitor->Visit(this); } + /** @return The function body */ + const StmtAST *Body() const { return body_.get(); } + + /** The function parameter names */ + const std::vector &ParameterNames() const { return parameter_names_; } + + /** @return The function parameter types */ + const std::vector &ParameterTypes() const { return parameter_types_; } + + private: + /** The body of the function */ + std::unique_ptr body_; + + /** The names of the parameters to the function */ + std::vector parameter_names_; + + /** The types of the parameters to the function */ + std::vector parameter_types_; }; // ---------------------------------------------------------------------------- From 9bad8e6e7370d4336b8f88682d78aa42b75fea56 Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Fri, 18 Jun 2021 12:11:38 -0400 Subject: [PATCH 042/139] still a work in progress on refactoring UDF code generation and UDF ast --- .../compiler/compilation_context.cpp | 2 +- src/execution/compiler/udf/udf_codegen.cpp | 25 ++- src/execution/sql/ddl_executors.cpp | 2 +- src/include/execution/ast/udf/udf_ast_nodes.h | 2 +- .../execution/compiler/udf/udf_codegen.h | 88 ++++++---- src/include/parser/udf/udf_parser.h | 6 +- src/parser/udf/udf_parser.cpp | 156 +++++++++--------- 7 files changed, 149 insertions(+), 132 deletions(-) diff --git a/src/execution/compiler/compilation_context.cpp b/src/execution/compiler/compilation_context.cpp index f214d27773..fd5572f67b 100644 --- a/src/execution/compiler/compilation_context.cpp +++ b/src/execution/compiler/compilation_context.cpp @@ -96,7 +96,7 @@ CompilationContext::CompilationContext(ExecutableQuery *query, query_id_t query_ query_state_(query_state_type_, [this](CodeGen *codegen) { return codegen->MakeExpr(query_state_var_); }), output_callback_(output_callback), counters_enabled_(settings.GetIsCountersEnabled()), - pipeline_metrics_enabled_(output_callback ? false : settings.GetIsPipelineMetricsEnabled()) {} + pipeline_metrics_enabled_((output_callback != nullptr) ? false : settings.GetIsPipelineMetricsEnabled()) {} // TODO(Kyle): Why disable pipeline metrics whenever we have an output callback? diff --git a/src/execution/compiler/udf/udf_codegen.cpp b/src/execution/compiler/udf/udf_codegen.cpp index b066e1da58..0d0dff0720 100644 --- a/src/execution/compiler/udf/udf_codegen.cpp +++ b/src/execution/compiler/udf/udf_codegen.cpp @@ -27,10 +27,7 @@ #include "planner/plannodes/abstract_plan_node.h" -namespace noisepage { -namespace execution { -namespace compiler { -namespace udf { +namespace noisepage::execution::compiler::udf { UDFCodegen::UDFCodegen(catalog::CatalogAccessor *accessor, FunctionBuilder *fb, ast::udf::UDFASTContext *udf_ast_context, CodeGen *codegen, catalog::db_oid_t db_oid) @@ -102,7 +99,7 @@ void UDFCodegen::Visit(ast::udf::CallExprAST *ast) { auto context = accessor_->GetProcCtxPtr(proc_oid); if (context->IsBuiltin()) { - fb_->Append(codegen_->MakeStmt(codegen_->CallBuiltin(context->GetBuiltin(), std::move(args_ast)))); + fb_->Append(codegen_->MakeStmt(codegen_->CallBuiltin(context->GetBuiltin(), args_ast))); } else { auto it = str_to_ident_.find(ast->Callee()); execution::ast::Identifier ident_expr; @@ -137,7 +134,7 @@ void UDFCodegen::Visit(ast::udf::DeclStmtAST *ast) { if (ast->Type() == type::TypeId::INVALID) { // record type execution::util::RegionVector fields(codegen_->GetAstContext()->GetRegion()); - for (auto p : udf_ast_context_->GetRecordType(ast->Name())) { + for (const auto &p : udf_ast_context_->GetRecordType(ast->Name())) { fields.push_back(codegen_->MakeField(codegen_->MakeIdentifier(p.first), codegen_->TplType(execution::sql::GetTypeId(p.second)))); } @@ -371,11 +368,12 @@ void UDFCodegen::Visit(ast::udf::ForStmtAST *ast) { } execution::util::RegionVector captures{codegen_->GetAstContext()->GetRegion()}; - for (auto it : str_to_ident_) { - if (it.first == "executionCtx") { + for (const auto &[name, identifier] : str_to_ident_) { + // TODO(Kyle): Why do we skip this particular identifier? + if (name == "executionCtx") { continue; } - captures.push_back(codegen_->MakeExpr(it.second)); + captures.push_back(codegen_->MakeExpr(identifier)); } lambda_expr = fn.FinishLambda(std::move(captures)); @@ -385,7 +383,7 @@ void UDFCodegen::Visit(ast::udf::ForStmtAST *ast) { // function into lambda_expr and will also feed in a lambda_expr to the compiler // TODO(Kyle): Using a NULL plan metatdata here... execution::exec::ExecutionSettings exec_settings{}; - const std::string dummy_query = ""; + const std::string dummy_query{}; auto exec_query = execution::compiler::CompilationContext::Compile( *plan, exec_settings, accessor_, execution::compiler::CompilationMode::OneShot, std::nullopt, common::ManagedPointer{}, common::ManagedPointer{&dummy_query}, @@ -543,7 +541,7 @@ void UDFCodegen::Visit(ast::udf::SQLStmtAST *ast) { // We want to pass something down that will materialize the lambda function // into lambda_expr and will also feed in a lambda_expr to the compiler execution::exec::ExecutionSettings exec_settings{}; - const std::string dummy_query = ""; + const std::string dummy_query{}; auto exec_query = execution::compiler::CompilationContext::Compile( *plan, exec_settings, accessor_, execution::compiler::CompilationMode::OneShot, std::nullopt, common::ManagedPointer{}, common::ManagedPointer(&dummy_query), @@ -649,7 +647,4 @@ void UDFCodegen::Visit(ast::udf::MemberExprAST *ast) { dst_ = codegen_->AccessStructMember(object, codegen_->MakeIdentifier(ast->FieldName())); } -} // namespace udf -} // namespace compiler -} // namespace execution -} // namespace noisepage +} // namespace noisepage::execution::compiler::udf diff --git a/src/execution/sql/ddl_executors.cpp b/src/execution/sql/ddl_executors.cpp index 27ec09ee15..545ebf0feb 100644 --- a/src/execution/sql/ddl_executors.cpp +++ b/src/execution/sql/ddl_executors.cpp @@ -101,7 +101,7 @@ bool DDLExecutors::CreateFunctionExecutor(const common::ManagedPointerGetReturnType())))}; compiler::udf::UDFCodegen udf_codegen{accessor.Get(), &fb, &udf_ast_context, &codegen, node->GetDatabaseOid()}; - udf_codegen.GenerateUDF(ast->body.get()); + udf_codegen.GenerateUDF(ast->Body()); auto *file = udf_codegen.Finish(); { diff --git a/src/include/execution/ast/udf/udf_ast_nodes.h b/src/include/execution/ast/udf/udf_ast_nodes.h index 03304c49dc..885cdd2496 100644 --- a/src/include/execution/ast/udf/udf_ast_nodes.h +++ b/src/include/execution/ast/udf/udf_ast_nodes.h @@ -605,7 +605,7 @@ class DynamicSQLStmtAST : public StmtAST { * @param query The expression that represents the query * @param name The name of the variable to which results are bound */ - DynamicSQLStmtAST(std::unique_ptr &&query, std::string var_name) + DynamicSQLStmtAST(std::unique_ptr &&query, std::string name) : query_{std::move(query)}, name_{std::move(name)} {} /** diff --git a/src/include/execution/compiler/udf/udf_codegen.h b/src/include/execution/compiler/udf/udf_codegen.h index e50e4f667a..6b913237cf 100644 --- a/src/include/execution/compiler/udf/udf_codegen.h +++ b/src/include/execution/compiler/udf/udf_codegen.h @@ -57,7 +57,10 @@ class UDFCodegen : ast::udf::ASTNodeVisitor { UDFCodegen(catalog::CatalogAccessor *accessor, FunctionBuilder *fb, ast::udf::UDFASTContext *udf_ast_context, CodeGen *codegen, catalog::db_oid_t db_oid); - ~UDFCodegen() = default; + /** + * Destroy the UDF code generation context. + */ + ~UDFCodegen() override = default; /** * Generate a UDF from the given abstract syntax tree. @@ -67,98 +70,116 @@ class UDFCodegen : ast::udf::ASTNodeVisitor { /** * Visit an AbstractAST node. + * @param ast The AST node to visit */ - void Visit(ast::udf::AbstractAST *) override; + void Visit(ast::udf::AbstractAST *ast) override; /** * Visit a FunctionAST node. + * @param ast The AST node to visit */ - void Visit(ast::udf::FunctionAST *) override; + void Visit(ast::udf::FunctionAST *ast) override; /** * Visit a StmtAST node. + * @param ast The AST node to visit */ - void Visit(ast::udf::StmtAST *) override; + void Visit(ast::udf::StmtAST *ast) override; /** * Visit an ExprAST node. + * @param ast The AST node to visit */ - void Visit(ast::udf::ExprAST *) override; + void Visit(ast::udf::ExprAST *ast) override; /** * Visit a ValueExprAST node. + * @param ast The AST node to visit */ - void Visit(ast::udf::ValueExprAST *) override; + void Visit(ast::udf::ValueExprAST *ast) override; /** * Visit a VariableExprAST node. */ - void Visit(ast::udf::VariableExprAST *) override; + void Visit(ast::udf::VariableExprAST *ast) override; /** * Visit a BinaryExprAST node. + * @param ast The AST node to visit */ - void Visit(ast::udf::BinaryExprAST *) override; + void Visit(ast::udf::BinaryExprAST *ast) override; /** * Visit a CallExprAST node. + * @param ast The AST node to visit */ - void Visit(ast::udf::CallExprAST *) override; + void Visit(ast::udf::CallExprAST *ast) override; /** * Visit an IsNullExprAST node. + * @param ast The AST node to visit */ - void Visit(ast::udf::IsNullExprAST *) override; + void Visit(ast::udf::IsNullExprAST *ast) override; /** * Visit a SeqStmtAST node. + * @param ast The AST node to visit */ - void Visit(ast::udf::SeqStmtAST *) override; + void Visit(ast::udf::SeqStmtAST *ast) override; /** * Visit a DeclStmtNode node. + * @param ast The AST node to visit */ - void Visit(ast::udf::DeclStmtAST *) override; + void Visit(ast::udf::DeclStmtAST *ast) override; /** * Visit a IfStmtAST node. + * @param ast The AST node to visit */ - void Visit(ast::udf::IfStmtAST *) override; + void Visit(ast::udf::IfStmtAST *ast) override; /** * Visit a WhileStmtAST node. + * @param ast The AST node to visit */ - void Visit(ast::udf::WhileStmtAST *) override; + void Visit(ast::udf::WhileStmtAST *ast) override; /** * Visit a RetStmtAST node. + * @param ast The AST node to visit */ - void Visit(ast::udf::RetStmtAST *) override; + void Visit(ast::udf::RetStmtAST *ast) override; /** * Visit an AssignStmtAST node. + * @param ast The AST node to visit */ - void Visit(ast::udf::AssignStmtAST *) override; + void Visit(ast::udf::AssignStmtAST *ast) override; /** * Visit a SQLStmtAST node. + * @param ast The AST node to visit */ - void Visit(ast::udf::SQLStmtAST *) override; + void Visit(ast::udf::SQLStmtAST *ast) override; /** * Visit a DynamicSQLStmtAST node. + * @param ast The AST node to visit */ - void Visit(ast::udf::DynamicSQLStmtAST *) override; + void Visit(ast::udf::DynamicSQLStmtAST *ast) override; /** * Visit a ForStmtAST node. + * @param ast The AST node to visit */ - void Visit(ast::udf::ForStmtAST *) override; + void Visit(ast::udf::ForStmtAST *ast) override; /** * Visit a MemberExprAST node. + * @param ast The AST node to visit */ - void Visit(ast::udf::MemberExprAST *) override; + void Visit(ast::udf::MemberExprAST *ast) override; /** * Complete UDF code generation. @@ -180,27 +201,34 @@ class UDFCodegen : ast::udf::ASTNodeVisitor { */ catalog::type_oid_t GetCatalogTypeOidFromSQLType(execution::ast::BuiltinType::Kind type); - // The catalog access used during code generation + /** The catalog access used during code generation */ catalog::CatalogAccessor *accessor_; - // The function builder used during code generation + + /** The function builder used during code generation */ FunctionBuilder *fb_; - // The AST context for the UDF + + /** The AST context for the UDF */ ast::udf::UDFASTContext *udf_ast_context_; - // The code generation instance + + /** The code generation instance */ CodeGen *codegen_; - // The OID of the relevant database + + /** The OID of the relevant database */ catalog::db_oid_t db_oid_; - // Auxiliary declarations + + /** Auxiliary declarations */ execution::util::RegionVector aux_decls_; - // Flag indicating whether this UDF requires an execution context + /** Flag indicating whether this UDF requires an execution context */ bool needs_exec_ctx_; - // The current type during code generation + /** The current type during code generation */ type::TypeId current_type_{type::TypeId::INVALID}; - // The destination expression + + /** The destination expression */ execution::ast::Expr *dst_; - // Map from human-readable string identifier to internal identifier + + /** Map from human-readable string identifier to internal identifier */ std::unordered_map str_to_ident_; }; diff --git a/src/include/parser/udf/udf_parser.h b/src/include/parser/udf/udf_parser.h index cf617036a3..52c02b86a4 100644 --- a/src/include/parser/udf/udf_parser.h +++ b/src/include/parser/udf/udf_parser.h @@ -46,10 +46,8 @@ class PLpgSQLParser { std::unique_ptr ParseSQL(const nlohmann::json &sql_stmt); std::unique_ptr ParseDynamicSQL(const nlohmann::json &sql_stmt); - // Feed the expression (as a SQL string) to our parser then transform the - // noisepage expression into ast node - std::unique_ptr ParseExprSQL(const std::string &expr_sql_str); - std::unique_ptr ParseExpr(common::ManagedPointer); + std::unique_ptr ParseExprSQL(const std::string &sql); + std::unique_ptr ParseExpr(common::ManagedPointer expr); common::ManagedPointer udf_ast_context_; const common::ManagedPointer accessor_; diff --git a/src/parser/udf/udf_parser.cpp b/src/parser/udf/udf_parser.cpp index 168fbc41a1..c81e469b1a 100644 --- a/src/parser/udf/udf_parser.cpp +++ b/src/parser/udf/udf_parser.cpp @@ -7,44 +7,42 @@ #include "libpg_query/pg_query.h" #include "nlohmann/json.hpp" -namespace noisepage { -namespace parser { -namespace udf { +namespace noisepage::parser::udf { /** * @brief The identifiers used as keys in the parse tree. */ -static constexpr const char kFunctionList[] = "FunctionList"; -static constexpr const char kDatums[] = "datums"; -static constexpr const char kPLpgSQL_var[] = "PLpgSQL_var"; -static constexpr const char kRefname[] = "refname"; -static constexpr const char kDatatype[] = "datatype"; -static constexpr const char kDefaultVal[] = "default_val"; -static constexpr const char kPLpgSQL_type[] = "PLpgSQL_type"; -static constexpr const char kTypname[] = "typname"; -static constexpr const char kAction[] = "action"; -static constexpr const char kPLpgSQL_function[] = "PLpgSQL_function"; -static constexpr const char kBody[] = "body"; -static constexpr const char kPLpgSQL_stmt_block[] = "PLpgSQL_stmt_block"; -static constexpr const char kPLpgSQL_stmt_return[] = "PLpgSQL_stmt_return"; -static constexpr const char kPLpgSQL_stmt_if[] = "PLpgSQL_stmt_if"; -static constexpr const char kPLpgSQL_stmt_while[] = "PLpgSQL_stmt_while"; -static constexpr const char kPLpgSQL_stmt_fors[] = "PLpgSQL_stmt_fors"; -static constexpr const char kCond[] = "cond"; -static constexpr const char kThenBody[] = "then_body"; -static constexpr const char kElseBody[] = "else_body"; -static constexpr const char kExpr[] = "expr"; -static constexpr const char kQuery[] = "query"; -static constexpr const char kPLpgSQL_expr[] = "PLpgSQL_expr"; -static constexpr const char kPLpgSQL_stmt_assign[] = "PLpgSQL_stmt_assign"; -static constexpr const char kVarno[] = "varno"; -static constexpr const char kPLpgSQL_stmt_execsql[] = "PLpgSQL_stmt_execsql"; -static constexpr const char kSqlstmt[] = "sqlstmt"; -static constexpr const char kRow[] = "row"; -static constexpr const char kFields[] = "fields"; -static constexpr const char kName[] = "name"; -static constexpr const char kPLpgSQL_row[] = "PLpgSQL_row"; -static constexpr const char kPLpgSQL_stmt_dynexecute[] = "PLpgSQL_stmt_dynexecute"; +static constexpr const char K_FUNCTION_LIST[] = "FunctionList"; +static constexpr const char K_DATUMS[] = "datums"; +static constexpr const char K_PLPGSQL_VAR[] = "PLpgSQL_var"; +static constexpr const char K_REFNAME[] = "refname"; +static constexpr const char K_DATATYPE[] = "datatype"; +static constexpr const char K_DEFAULT_VAL[] = "default_val"; +static constexpr const char K_PLPGSQL_TYPE[] = "PLpgSQL_type"; +static constexpr const char K_TYPENAME[] = "typname"; +static constexpr const char K_ACTION[] = "action"; +static constexpr const char K_PLPGSQL_FUNCTION[] = "PLpgSQL_function"; +static constexpr const char K_BODY[] = "body"; +static constexpr const char K_PLPGSQL_STMT_BLOCK[] = "PLpgSQL_stmt_block"; +static constexpr const char K_PLPGSQL_STMT_RETURN[] = "PLpgSQL_stmt_return"; +static constexpr const char K_PLPGSQL_STMT_IF[] = "PLpgSQL_stmt_if"; +static constexpr const char K_PLPGSQL_STMT_WHILE[] = "PLpgSQL_stmt_while"; +static constexpr const char K_PLPGSQL_STMT_FORS[] = "PLpgSQL_stmt_fors"; +static constexpr const char K_COND[] = "cond"; +static constexpr const char K_THEN_BODY[] = "then_body"; +static constexpr const char K_ELSE_BODY[] = "else_body"; +static constexpr const char K_EXPR[] = "expr"; +static constexpr const char K_QUERY[] = "query"; +static constexpr const char K_PLPGSQL_EXPR[] = "PLpgSQL_expr"; +static constexpr const char K_PLPGSQL_STMT_ASSIGN[] = "PLpgSQL_stmt_assign"; +static constexpr const char K_VARNO[] = "varno"; +static constexpr const char K_PLGPSQL_STMT_EXECSQL[] = "PLpgSQL_stmt_execsql"; +static constexpr const char K_SQLSTMT[] = "sqlstmt"; +static constexpr const char K_ROW[] = "row"; +static constexpr const char K_FIELDS[] = "fields"; +static constexpr const char K_NAME[] = "name"; +static constexpr const char K_PLPGSQL_ROW[] = "PLpgSQL_row"; +static constexpr const char K_PLPGSQL_STMT_DYNEXECUTE[] = "PLpgSQL_stmt_dynexecute"; std::unique_ptr PLpgSQLParser::ParsePLpgSQL( std::vector &¶m_names, std::vector &¶m_types, const std::string &func_body, @@ -56,14 +54,14 @@ std::unique_ptr PLpgSQLParser::ParsePLpgSQL( } // The result is a list, we need to wrap it const auto ast_json_str = - "{ \"" + std::string{kFunctionList} + "\" : " + std::string{result.plpgsql_funcs} + " }"; // NOLINT + "{ \"" + std::string{K_FUNCTION_LIST} + "\" : " + std::string{result.plpgsql_funcs} + " }"; // NOLINT pg_query_free_plpgsql_parse_result(result); std::istringstream ss{ast_json_str}; nlohmann::json ast_json{}; ss >> ast_json; - const auto function_list = ast_json[kFunctionList]; + const auto function_list = ast_json[K_FUNCTION_LIST]; NOISEPAGE_ASSERT(function_list.is_array(), "Function list is not an array"); if (function_list.size() != 1) { throw PARSER_EXCEPTION("Function list has size other than 1"); @@ -73,15 +71,15 @@ std::unique_ptr PLpgSQLParser::ParsePLpgSQL( for (const auto &udf_name : param_names) { udf_ast_context_->SetVariableType(udf_name, param_types[i++]); } - const auto function = function_list[0][kPLpgSQL_function]; + const auto function = function_list[0][K_PLPGSQL_FUNCTION]; auto function_ast = std::make_unique( ParseFunction(function), std::move(param_names), std::move(param_types)); return function_ast; } std::unique_ptr PLpgSQLParser::ParseFunction(const nlohmann::json &block) { - const auto decl_list = block[kDatums]; - const auto function_body = block[kAction][kPLpgSQL_stmt_block][kBody]; + const auto decl_list = block[K_DATUMS]; + const auto function_body = block[K_ACTION][K_PLPGSQL_STMT_BLOCK][K_BODY]; std::vector> stmts{}; @@ -107,27 +105,27 @@ std::unique_ptr PLpgSQLParser::ParseBlock(const nl const auto stmt = block[i]; const auto stmt_names = stmt.items().begin(); - if (stmt_names.key() == kPLpgSQL_stmt_return) { - auto expr = ParseExprSQL(stmt[kPLpgSQL_stmt_return][kExpr][kPLpgSQL_expr][kQuery].get()); + if (stmt_names.key() == K_PLPGSQL_STMT_RETURN) { + auto expr = ParseExprSQL(stmt[K_PLPGSQL_STMT_RETURN][K_EXPR][K_PLPGSQL_EXPR][K_QUERY].get()); // TODO(Kyle): Handle return stmt w/o expression stmts.push_back(std::make_unique(std::move(expr))); - } else if (stmt_names.key() == kPLpgSQL_stmt_if) { - stmts.push_back(ParseIf(stmt[kPLpgSQL_stmt_if])); - } else if (stmt_names.key() == kPLpgSQL_stmt_assign) { + } else if (stmt_names.key() == K_PLPGSQL_STMT_IF) { + stmts.push_back(ParseIf(stmt[K_PLPGSQL_STMT_IF])); + } else if (stmt_names.key() == K_PLPGSQL_STMT_ASSIGN) { // TODO(Kyle): Need to fix Assignment expression / statement const auto &var_name = - udf_ast_context_->GetLocalVariableAtIndex(stmt[kPLpgSQL_stmt_assign][kVarno].get()); + udf_ast_context_->GetLocalVariableAtIndex(stmt[K_PLPGSQL_STMT_ASSIGN][K_VARNO].get()); auto lhs = std::make_unique(var_name); - auto rhs = ParseExprSQL(stmt[kPLpgSQL_stmt_assign][kExpr][kPLpgSQL_expr][kQuery].get()); + auto rhs = ParseExprSQL(stmt[K_PLPGSQL_STMT_ASSIGN][K_EXPR][K_PLPGSQL_EXPR][K_QUERY].get()); stmts.push_back(std::make_unique(std::move(lhs), std::move(rhs))); - } else if (stmt_names.key() == kPLpgSQL_stmt_while) { - stmts.push_back(ParseWhile(stmt[kPLpgSQL_stmt_while])); - } else if (stmt_names.key() == kPLpgSQL_stmt_fors) { - stmts.push_back(ParseFor(stmt[kPLpgSQL_stmt_fors])); - } else if (stmt_names.key() == kPLpgSQL_stmt_execsql) { - stmts.push_back(ParseSQL(stmt[kPLpgSQL_stmt_execsql])); - } else if (stmt_names.key() == kPLpgSQL_stmt_dynexecute) { - stmts.push_back(ParseDynamicSQL(stmt[kPLpgSQL_stmt_dynexecute])); + } else if (stmt_names.key() == K_PLPGSQL_STMT_WHILE) { + stmts.push_back(ParseWhile(stmt[K_PLPGSQL_STMT_WHILE])); + } else if (stmt_names.key() == K_PLPGSQL_STMT_FORS) { + stmts.push_back(ParseFor(stmt[K_PLPGSQL_STMT_FORS])); + } else if (stmt_names.key() == K_PLGPSQL_STMT_EXECSQL) { + stmts.push_back(ParseSQL(stmt[K_PLGPSQL_STMT_EXECSQL])); + } else if (stmt_names.key() == K_PLPGSQL_STMT_DYNEXECUTE) { + stmts.push_back(ParseDynamicSQL(stmt[K_PLPGSQL_STMT_DYNEXECUTE])); } else { throw PARSER_EXCEPTION("Statement type not supported"); } @@ -139,13 +137,13 @@ std::unique_ptr PLpgSQLParser::ParseBlock(const nl std::unique_ptr PLpgSQLParser::ParseDecl(const nlohmann::json &decl) { const auto &decl_names = decl.items().begin(); - if (decl_names.key() == kPLpgSQL_var) { - auto var_name = decl[kPLpgSQL_var][kRefname].get(); + if (decl_names.key() == K_PLPGSQL_VAR) { + auto var_name = decl[K_PLPGSQL_VAR][K_REFNAME].get(); udf_ast_context_->AddVariable(var_name); - auto type = decl[kPLpgSQL_var][kDatatype][kPLpgSQL_type][kTypname].get(); + auto type = decl[K_PLPGSQL_VAR][K_DATATYPE][K_PLPGSQL_TYPE][K_TYPENAME].get(); std::unique_ptr initial = nullptr; - if (decl[kPLpgSQL_var].find(kDefaultVal) != decl[kPLpgSQL_var].end()) { - initial = ParseExprSQL(decl[kPLpgSQL_var][kDefaultVal][kPLpgSQL_expr][kQuery].get()); + if (decl[K_PLPGSQL_VAR].find(K_DEFAULT_VAL) != decl[K_PLPGSQL_VAR].end()) { + initial = ParseExprSQL(decl[K_PLPGSQL_VAR][K_DEFAULT_VAL][K_PLPGSQL_EXPR][K_QUERY].get()); } type::TypeId temp_type{}; @@ -171,8 +169,8 @@ std::unique_ptr PLpgSQLParser::ParseDecl(const nlo } else { NOISEPAGE_ASSERT(false, "Unsupported Type"); } - } else if (decl_names.key() == kPLpgSQL_row) { - auto var_name = decl[kPLpgSQL_row][kRefname].get(); + } else if (decl_names.key() == K_PLPGSQL_ROW) { + auto var_name = decl[K_PLPGSQL_ROW][K_REFNAME].get(); NOISEPAGE_ASSERT(var_name == "*internal*", "Unexpected refname"); // TODO(Kyle): Support row types later udf_ast_context_->SetVariableType(var_name, type::TypeId::INVALID); @@ -183,41 +181,41 @@ std::unique_ptr PLpgSQLParser::ParseDecl(const nlo } std::unique_ptr PLpgSQLParser::ParseIf(const nlohmann::json &branch) { - auto cond_expr = ParseExprSQL(branch[kCond][kPLpgSQL_expr][kQuery].get()); - auto then_stmt = ParseBlock(branch[kThenBody]); + auto cond_expr = ParseExprSQL(branch[K_COND][K_PLPGSQL_EXPR][K_QUERY].get()); + auto then_stmt = ParseBlock(branch[K_THEN_BODY]); std::unique_ptr else_stmt = nullptr; - if (branch.find(kElseBody) != branch.end()) { - else_stmt = ParseBlock(branch[kElseBody]); + if (branch.find(K_ELSE_BODY) != branch.end()) { + else_stmt = ParseBlock(branch[K_ELSE_BODY]); } return std::make_unique(std::move(cond_expr), std::move(then_stmt), std::move(else_stmt)); } std::unique_ptr PLpgSQLParser::ParseWhile(const nlohmann::json &loop) { - auto cond_expr = ParseExprSQL(loop[kCond][kPLpgSQL_expr][kQuery].get()); - auto body_stmt = ParseBlock(loop[kBody]); + auto cond_expr = ParseExprSQL(loop[K_COND][K_PLPGSQL_EXPR][K_QUERY].get()); + auto body_stmt = ParseBlock(loop[K_BODY]); return std::make_unique(std::move(cond_expr), std::move(body_stmt)); } std::unique_ptr PLpgSQLParser::ParseFor(const nlohmann::json &loop) { - auto sql_query = loop[kQuery][kPLpgSQL_expr][kQuery].get(); + auto sql_query = loop[K_QUERY][K_PLPGSQL_EXPR][K_QUERY].get(); auto parse_result = PostgresParser::BuildParseTree(sql_query.c_str()); if (parse_result == nullptr) { return nullptr; } - auto body_stmt = ParseBlock(loop[kBody]); - auto var_array = loop[kRow][kPLpgSQL_row][kFields]; + auto body_stmt = ParseBlock(loop[K_BODY]); + auto var_array = loop[K_ROW][K_PLPGSQL_ROW][K_FIELDS]; std::vector var_vec; for (auto var : var_array) { - var_vec.push_back(var[kName].get()); + var_vec.push_back(var[K_NAME].get()); } return std::make_unique(std::move(var_vec), std::move(parse_result), std::move(body_stmt)); } std::unique_ptr PLpgSQLParser::ParseSQL(const nlohmann::json &sql_stmt) { - auto sql_query = sql_stmt[kSqlstmt][kPLpgSQL_expr][kQuery].get(); - auto var_name = sql_stmt[kRow][kPLpgSQL_row][kFields][0][kName].get(); + auto sql_query = sql_stmt[K_SQLSTMT][K_PLPGSQL_EXPR][K_QUERY].get(); + auto var_name = sql_stmt[K_ROW][K_PLPGSQL_ROW][K_FIELDS][0][K_NAME].get(); auto parse_result = PostgresParser::BuildParseTree(sql_query.c_str()); if (parse_result == nullptr) { @@ -257,13 +255,13 @@ std::unique_ptr PLpgSQLParser::ParseSQL(const nloh } std::unique_ptr PLpgSQLParser::ParseDynamicSQL(const nlohmann::json &sql_stmt) { - auto sql_expr = ParseExprSQL(sql_stmt[kQuery][kPLpgSQL_expr][kQuery].get()); - auto var_name = sql_stmt[kRow][kPLpgSQL_row][kFields][0][kName].get(); + auto sql_expr = ParseExprSQL(sql_stmt[K_QUERY][K_PLPGSQL_EXPR][K_QUERY].get()); + auto var_name = sql_stmt[K_ROW][K_PLPGSQL_ROW][K_FIELDS][0][K_NAME].get(); return std::make_unique(std::move(sql_expr), std::move(var_name)); } -std::unique_ptr PLpgSQLParser::ParseExprSQL(const std::string &expr_sql_str) { - auto stmt_list = PostgresParser::BuildParseTree(expr_sql_str); +std::unique_ptr PLpgSQLParser::ParseExprSQL(const std::string &sql) { + auto stmt_list = PostgresParser::BuildParseTree(sql); if (stmt_list == nullptr) { return nullptr; } @@ -310,6 +308,4 @@ std::unique_ptr PLpgSQLParser::ParseExpr( throw PARSER_EXCEPTION("PL/pgSQL parser : Expression type not supported"); } -} // namespace udf -} // namespace parser -} // namespace noisepage +} // namespace noisepage::parser::udf From 8171dfbfdea9140013083c961103569147d62e69 Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Fri, 18 Jun 2021 16:01:40 -0400 Subject: [PATCH 043/139] linting and clang-tidy passing, I am sure we will still fail doxygen but we are now 2 steps closer to being ready to merge --- src/execution/compiler/pipeline.cpp | 26 +++++--- src/execution/sql/ddl_executors.cpp | 20 +++++- src/execution/vm/bytecode_generator.cpp | 7 ++- src/execution/vm/llvm_engine.cpp | 8 +-- src/include/binder/binder_sherpa.h | 8 +-- src/parser/postgresparser.cpp | 11 ++-- src/parser/udf/udf_parser.cpp | 83 ++++++++++++++----------- 7 files changed, 95 insertions(+), 68 deletions(-) diff --git a/src/execution/compiler/pipeline.cpp b/src/execution/compiler/pipeline.cpp index 5966c52a1d..c170026344 100644 --- a/src/execution/compiler/pipeline.cpp +++ b/src/execution/compiler/pipeline.cpp @@ -155,10 +155,15 @@ util::RegionVector Pipeline::PipelineParams() const { void Pipeline::LinkSourcePipeline(Pipeline *dependency) { NOISEPAGE_ASSERT(dependency != nullptr, "Source cannot be null"); + // Add pipeline `dependency` as a nested pipeline dependencies_.push_back(dependency); + // Remove ourselves from the nested pipeline of dependency, if present + // TODO(Kyle): Is this possible? If so, is this a broken invariant? if (std::find(dependency->nested_pipelines_.begin(), dependency->nested_pipelines_.end(), this) != dependency->nested_pipelines_.end()) { - std::remove(dependency->nested_pipelines_.begin(), dependency->nested_pipelines_.end(), this); + dependency->nested_pipelines_.erase( + std::remove(dependency->nested_pipelines_.begin(), dependency->nested_pipelines_.end(), this), + dependency->nested_pipelines_.end()); } } @@ -313,7 +318,7 @@ ast::FunctionDecl *Pipeline::GenerateInitPipelineFunction(ast::LambdaExpr *outpu ast::FieldDecl *p_state_ptr = nullptr; auto &state = GetPipelineStateDescriptor(); uint32_t p_state_ind = 0; - if (nested_ || output_callback) { + if (nested_ || output_callback != nullptr) { p_state_ptr = codegen_->MakeField(codegen_->MakeFreshIdentifier("pipeline_state"), codegen_->PointerType(codegen_->MakeExpr(state.GetTypeName()))); params.push_back(p_state_ptr); @@ -326,9 +331,10 @@ ast::FunctionDecl *Pipeline::GenerateInitPipelineFunction(ast::LambdaExpr *outpu ast::Expr *exec_ctx = compilation_context_->GetExecutionContextPtrFromQueryState(); ast::Identifier tls = codegen_->MakeFreshIdentifier("threadStateContainer"); builder.Append(codegen_->DeclareVarWithInit(tls, codegen_->ExecCtxGetTLS(exec_ctx))); + // @tlsReset(tls, @sizeOf(ThreadState), init, tearDown, queryState) ast::Expr *state_ptr = query_state->GetStatePointer(codegen_); - if (!nested_ && !output_callback) { + if (!nested_ && output_callback == nullptr) { builder.Append(codegen_->TLSReset(codegen_->MakeExpr(tls), state.GetTypeName(), GetSetupPipelineStateFunctionName(), GetTearDownPipelineStateFunctionName(), state_ptr)); @@ -450,13 +456,13 @@ ast::FunctionDecl *Pipeline::GenerateRunPipelineFunction(query_id_t query_id, as bool started_tracker = false; auto name = codegen_->MakeIdentifier(CreatePipelineFunctionName("Run")); auto params = compilation_context_->QueryParams(); - if (nested_ || output_callback) { + if (nested_ || output_callback != nullptr) { params.push_back(codegen_->MakeField(state_var_, codegen_->PointerType(state_.GetTypeName()))); } for (auto field : extra_pipeline_params_) { params.push_back(field); } - if (output_callback) { + if (output_callback != nullptr) { params.push_back(codegen_->MakeField(output_callback->GetName(), codegen_->LambdaType(output_callback->GetFunctionLitExpr()->TypeRepr()))); } @@ -497,10 +503,10 @@ ast::FunctionDecl *Pipeline::GenerateRunPipelineFunction(query_id_t query_id, as arg = builder.GetParameterByPosition(i++); } } - if (output_callback && !nested_) { + if (output_callback != nullptr && !nested_) { args.push_back(codegen_->MakeExpr(output_callback->GetName())); } - builder.Append(codegen_->Call(GetWorkFunctionName(), std::move(args))); + builder.Append(codegen_->Call(GetWorkFunctionName(), args)); } // TODO(abalakum): This shouldn't actually be dependent on order and the loop can be simplified @@ -524,7 +530,7 @@ ast::FunctionDecl *Pipeline::GenerateTearDownPipelineFunction(ast::LambdaExpr *o ast::FieldDecl *p_state_ptr = nullptr; auto &state = GetPipelineStateDescriptor(); uint32_t p_state_index = 0; - if (nested_ || output_callback) { + if (nested_ || output_callback != nullptr) { p_state_ptr = codegen_->MakeField(codegen_->MakeFreshIdentifier("pipeline_state"), codegen_->PointerType(codegen_->MakeExpr(state.GetTypeName()))); params.push_back(p_state_ptr); @@ -535,7 +541,7 @@ ast::FunctionDecl *Pipeline::GenerateTearDownPipelineFunction(ast::LambdaExpr *o { // Begin a new code scope for fresh variables. CodeGen::CodeScope code_scope(codegen_); - if (!nested_ && !output_callback) { + if (!nested_ && output_callback == nullptr) { // Tear down thread local state if parallel pipeline. ast::Expr *exec_ctx = compilation_context_->GetExecutionContextPtrFromQueryState(); builder.Append(codegen_->TLSClear(codegen_->ExecCtxGetTLS(exec_ctx))); @@ -570,7 +576,7 @@ void Pipeline::GeneratePipeline(ExecutableQueryFragmentBuilder *builder, query_i builder->DeclareFunction(teardown); // Register the main init, run, tear-down functions as steps, in that order. - if (output_callback) { + if (output_callback != nullptr) { auto fn = GeneratePipelineWrapperFunction(output_callback); builder->DeclareFunction(fn); builder->RegisterStep(fn); diff --git a/src/execution/sql/ddl_executors.cpp b/src/execution/sql/ddl_executors.cpp index 545ebf0feb..58b1c7c02e 100644 --- a/src/execution/sql/ddl_executors.cpp +++ b/src/execution/sql/ddl_executors.cpp @@ -48,15 +48,16 @@ bool DDLExecutors::CreateFunctionExecutor(const common::ManagedPointer accessor) { // Request permission from the Catalog to see if this a valid namespace name NOISEPAGE_ASSERT(node->GetUDFLanguage() == parser::PLType::PL_PGSQL, "Unsupported language"); - NOISEPAGE_ASSERT(node->GetFunctionBody().size() >= 1, "Unsupported function body contents"); + NOISEPAGE_ASSERT(!node->GetFunctionBody().empty(), "Unsupported function body contents"); // I don't like how we have to separate the two here std::vector param_type_ids{}; std::vector param_types{}; - for (auto t : node->GetFunctionParameterTypes()) { + for (const auto t : node->GetFunctionParameterTypes()) { param_type_ids.push_back(parser::FuncParameter::DataTypeToTypeId(t)); param_types.push_back(accessor->GetTypeOidFromTypeId(parser::FuncParameter::DataTypeToTypeId(t))); } + auto body = node->GetFunctionBody().front(); auto proc_id = accessor->CreateProcedure( node->GetFunctionName(), catalog::postgres::PgLanguage::PLPGSQL_LANGUAGE_OID, node->GetNamespaceOid(), @@ -69,6 +70,7 @@ bool DDLExecutors::CreateFunctionExecutor(const common::ManagedPointerGetDatabaseOid()}; + std::unique_ptr ast{}; try { ast = udf_parser.ParsePLpgSQL(node->GetFunctionParameterNames(), std::move(param_type_ids), body, @@ -77,6 +79,7 @@ bool DDLExecutors::CreateFunctionExecutor(const common::ManagedPointerGetFunctionName()); sema::ErrorReporter error_reporter{region}; @@ -114,8 +117,19 @@ bool DDLExecutors::CreateFunctionExecutor(const common::ManagedPointer types{}; + types.reserve(node->GetFunctionParameterTypes().size()); + std::transform(node->GetFunctionParameterTypes().cbegin(), node->GetFunctionParameterTypes().cend(), + std::back_inserter(types), [](const parser::BaseFunctionParameter::DataType &type) -> type::TypeId { + return parser::FuncParameter::DataTypeToTypeId(type); + }); + auto udf_context = std::make_unique( - node->GetFunctionName(), parser::ReturnType::DataTypeToTypeId(node->GetReturnType()), std::move(param_type_ids), + node->GetFunctionName(), parser::ReturnType::DataTypeToTypeId(node->GetReturnType()), std::move(types), std::unique_ptr(region), std::move(ast_context), file); if (!accessor->SetFunctionContextPointer(proc_id, udf_context.get())) { return false; diff --git a/src/execution/vm/bytecode_generator.cpp b/src/execution/vm/bytecode_generator.cpp index 0dcec49f9f..52e095e215 100644 --- a/src/execution/vm/bytecode_generator.cpp +++ b/src/execution/vm/bytecode_generator.cpp @@ -4043,7 +4043,7 @@ FunctionInfo *BytecodeGenerator::AllocateFunc(const std::string &func_name, ast: // Cache func_map_[func->GetName()] = func->GetId(); - for (auto action : deferred_function_create_actions_[func->GetName()]) { + for (const auto &action : deferred_function_create_actions_[func->GetName()]) { action(func->GetId()); } @@ -4062,7 +4062,7 @@ FunctionInfo *BytecodeGenerator::AllocateFunc(const std::string &func_name, ast: func->NewParameterLocal(return_type->PointerTo(), "hiddenRv"); } - // lambda captures + // Lambda captures func->NewParameterLocal(capture_type->PointerTo(), "hiddenCaptures"); // Register parameters @@ -4072,9 +4072,10 @@ FunctionInfo *BytecodeGenerator::AllocateFunc(const std::string &func_name, ast: // Cache func_map_[func->GetName()] = func->GetId(); - for (auto action : deferred_function_create_actions_[func->GetName()]) { + for (const auto &action : deferred_function_create_actions_[func->GetName()]) { action(func->GetId()); } + return func; } diff --git a/src/execution/vm/llvm_engine.cpp b/src/execution/vm/llvm_engine.cpp index bc962f412d..3796a156f2 100644 --- a/src/execution/vm/llvm_engine.cpp +++ b/src/execution/vm/llvm_engine.cpp @@ -324,8 +324,8 @@ class LLVMEngine::FunctionLocalsMap { LLVMEngine::FunctionLocalsMap::FunctionLocalsMap(const FunctionInfo &func_info, llvm::Function *func, TypeMap *type_map, llvm::IRBuilder<> *ir_builder) : ir_builder_(ir_builder) { - uint32_t local_idx = 0; - + // The local variable index used throughout function body + std::size_t local_idx = 0; const auto &func_locals = func_info.GetLocals(); // Make an allocation for the return value, if it's direct. @@ -350,9 +350,7 @@ LLVMEngine::FunctionLocalsMap::FunctionLocalsMap(const FunctionInfo &func_info, params_[capture_local.GetOffset()] = new_capture_param; } - auto calling_context = func_info; - - // Allocate all local variables up front. + // Allocate all local variables up front for (; local_idx < func_info.GetLocals().size(); local_idx++) { const LocalInfo &local_info = func_locals[local_idx]; llvm::Type *llvm_type = type_map->GetLLVMType(local_info.GetType()); diff --git a/src/include/binder/binder_sherpa.h b/src/include/binder/binder_sherpa.h index fb032c3497..d400f96e8a 100644 --- a/src/include/binder/binder_sherpa.h +++ b/src/include/binder/binder_sherpa.h @@ -50,10 +50,10 @@ class BinderSherpa { * Add a parameter to the binder sherpa state. * @param param The parameter expression. */ - void AddParameter(const parser::ConstantValueExpression param) { - parameters_->push_back(param); - desired_parameter_types_->push_back(param.GetReturnValueType()); - } + // void AddParameter(const parser::ConstantValueExpression& param) { + // parameters_->push_back(param); + // desired_parameter_types_->push_back(param.GetReturnValueType()); + // } /** * @param expr The expression whose type constraints we want to look up. diff --git a/src/parser/postgresparser.cpp b/src/parser/postgresparser.cpp index 32908d9b55..044882b5e9 100644 --- a/src/parser/postgresparser.cpp +++ b/src/parser/postgresparser.cpp @@ -1306,7 +1306,8 @@ std::unique_ptr PostgresParser::CreateFunctionTransform(ParseResul std::string func_name = (reinterpret_cast(root->funcname_->tail->data.ptr_value)->val_.str_); std::vector func_body{}; - func_body.push_back(std::string(query_string.c_str())); + func_body.push_back(query_string); + AsType as_type = AsType::INVALID; PLType pl_type = PLType::INVALID; @@ -1337,11 +1338,9 @@ std::unique_ptr PostgresParser::CreateFunctionTransform(ParseResul } } - auto result = - std::make_unique(replace, std::move(func_name), std::move(func_body), - std::move(return_type), std::move(func_parameters), pl_type, as_type); - - return result; + return std::make_unique(replace, std::move(func_name), std::move(func_body), + std::move(return_type), std::move(func_parameters), pl_type, + as_type); } // Postgres.IndexStmt -> noisepage.CreateStatement diff --git a/src/parser/udf/udf_parser.cpp b/src/parser/udf/udf_parser.cpp index c81e469b1a..160f488ceb 100644 --- a/src/parser/udf/udf_parser.cpp +++ b/src/parser/udf/udf_parser.cpp @@ -48,7 +48,7 @@ std::unique_ptr PLpgSQLParser::ParsePLpgSQL( std::vector &¶m_names, std::vector &¶m_types, const std::string &func_body, common::ManagedPointer ast_context) { auto result = pg_query_parse_plpgsql(func_body.c_str()); - if (result.error) { + if (result.error != nullptr) { pg_query_free_plpgsql_parse_result(result); throw PARSER_EXCEPTION("PL/pgSQL parsing error"); } @@ -95,16 +95,13 @@ std::unique_ptr PLpgSQLParser::ParseFunction(const std::unique_ptr PLpgSQLParser::ParseBlock(const nlohmann::json &block) { // TODO(boweic): Support statements size other than 1 NOISEPAGE_ASSERT(block.is_array(), "Block isn't array"); - if (block.size() == 0) { + if (block.empty()) { throw PARSER_EXCEPTION("PL/pgSQL parser : Empty block is not supported"); } std::vector> stmts{}; - - for (uint32_t i = 0; i < block.size(); i++) { - const auto stmt = block[i]; + for (const auto &stmt : block) { const auto stmt_names = stmt.items().begin(); - if (stmt_names.key() == K_PLPGSQL_STMT_RETURN) { auto expr = ParseExprSQL(stmt[K_PLPGSQL_STMT_RETURN][K_EXPR][K_PLPGSQL_EXPR][K_QUERY].get()); // TODO(Kyle): Handle return stmt w/o expression @@ -150,25 +147,27 @@ std::unique_ptr PLpgSQLParser::ParseDecl(const nlo if (udf_ast_context_->GetVariableType(var_name, &temp_type)) { return std::make_unique(var_name, temp_type, std::move(initial)); } - if ((type.find("integer") != std::string::npos) || type.find("INTEGER") != std::string::npos) { udf_ast_context_->SetVariableType(var_name, type::TypeId::INTEGER); return std::make_unique(var_name, type::TypeId::INTEGER, std::move(initial)); - } else if (type == "double" || type.rfind("numeric") == 0) { + } + if (type == "double" || type.rfind("numeric") == 0) { udf_ast_context_->SetVariableType(var_name, type::TypeId::DECIMAL); return std::make_unique(var_name, type::TypeId::DECIMAL, std::move(initial)); - } else if (type == "varchar") { + } + if (type == "varchar") { udf_ast_context_->SetVariableType(var_name, type::TypeId::VARCHAR); return std::make_unique(var_name, type::TypeId::VARCHAR, std::move(initial)); - } else if (type.find("date") != std::string::npos) { + } + if (type.find("date") != std::string::npos) { udf_ast_context_->SetVariableType(var_name, type::TypeId::DATE); return std::make_unique(var_name, type::TypeId::DATE, std::move(initial)); - } else if (type == "record") { + } + if (type == "record") { udf_ast_context_->SetVariableType(var_name, type::TypeId::INVALID); return std::make_unique(var_name, type::TypeId::INVALID, std::move(initial)); - } else { - NOISEPAGE_ASSERT(false, "Unsupported Type"); } + NOISEPAGE_ASSERT(false, "Unsupported Type"); } else if (decl_names.key() == K_PLPGSQL_ROW) { auto var_name = decl[K_PLPGSQL_ROW][K_REFNAME].get(); NOISEPAGE_ASSERT(var_name == "*internal*", "Unexpected refname"); @@ -198,8 +197,8 @@ std::unique_ptr PLpgSQLParser::ParseWhile(const nl } std::unique_ptr PLpgSQLParser::ParseFor(const nlohmann::json &loop) { - auto sql_query = loop[K_QUERY][K_PLPGSQL_EXPR][K_QUERY].get(); - auto parse_result = PostgresParser::BuildParseTree(sql_query.c_str()); + const auto sql_query = loop[K_QUERY][K_PLPGSQL_EXPR][K_QUERY].get(); + auto parse_result = PostgresParser::BuildParseTree(sql_query); if (parse_result == nullptr) { return nullptr; } @@ -214,10 +213,12 @@ std::unique_ptr PLpgSQLParser::ParseFor(const nloh } std::unique_ptr PLpgSQLParser::ParseSQL(const nlohmann::json &sql_stmt) { - auto sql_query = sql_stmt[K_SQLSTMT][K_PLPGSQL_EXPR][K_QUERY].get(); + // The query text + const auto sql_query = sql_stmt[K_SQLSTMT][K_PLPGSQL_EXPR][K_QUERY].get(); + // The variable name (non-const for later std::move) auto var_name = sql_stmt[K_ROW][K_PLPGSQL_ROW][K_FIELDS][0][K_NAME].get(); - auto parse_result = PostgresParser::BuildParseTree(sql_query.c_str()); + auto parse_result = PostgresParser::BuildParseTree(sql_query); if (parse_result == nullptr) { return nullptr; } @@ -281,31 +282,39 @@ std::unique_ptr PLpgSQLParser::ParseExpr( auto cve = expr.CastManagedPointerTo(); if (cve->GetTableName().empty()) { return std::make_unique(cve->GetColumnName()); - } else { - auto vexpr = std::make_unique(cve->GetTableName()); - return std::make_unique(std::move(vexpr), cve->GetColumnName()); } - } else if ((parser::ExpressionUtil::IsOperatorExpression(expr->GetExpressionType()) && - expr->GetChildrenSize() == 2) || - (parser::ExpressionUtil::IsComparisonExpression(expr->GetExpressionType()))) { + auto vexpr = std::make_unique(cve->GetTableName()); + return std::make_unique(std::move(vexpr), cve->GetColumnName()); + } + + if ((parser::ExpressionUtil::IsOperatorExpression(expr->GetExpressionType()) && expr->GetChildrenSize() == 2) || + (parser::ExpressionUtil::IsComparisonExpression(expr->GetExpressionType()))) { return std::make_unique(expr->GetExpressionType(), ParseExpr(expr->GetChild(0)), ParseExpr(expr->GetChild(1))); - } else if (expr->GetExpressionType() == parser::ExpressionType::FUNCTION) { - auto func_expr = expr.CastManagedPointerTo(); - std::vector> args{}; - auto num_args = func_expr->GetChildrenSize(); - for (size_t idx = 0; idx < num_args; ++idx) { - args.push_back(ParseExpr(func_expr->GetChild(idx))); + } + + // TODO(Kyle): I am not a fan of non-exhaustive switch statements; + // is there a way that we can refactor this logic to make it better? + + switch (expr->GetExpressionType()) { + case parser::ExpressionType::FUNCTION: { + auto func_expr = expr.CastManagedPointerTo(); + std::vector> args{}; + auto num_args = func_expr->GetChildrenSize(); + for (size_t idx = 0; idx < num_args; ++idx) { + args.push_back(ParseExpr(func_expr->GetChild(idx))); + } + return std::make_unique(func_expr->GetFuncName(), std::move(args)); } - return std::make_unique(func_expr->GetFuncName(), std::move(args)); - } else if (expr->GetExpressionType() == parser::ExpressionType::VALUE_CONSTANT) { - return std::make_unique(expr->Copy()); - } else if (expr->GetExpressionType() == parser::ExpressionType::OPERATOR_IS_NOT_NULL) { - return std::make_unique(false, ParseExpr(expr->GetChild(0))); - } else if (expr->GetExpressionType() == parser::ExpressionType::OPERATOR_IS_NULL) { - return std::make_unique(true, ParseExpr(expr->GetChild(0))); + case parser::ExpressionType::VALUE_CONSTANT: + return std::make_unique(expr->Copy()); + case parser::ExpressionType::OPERATOR_IS_NOT_NULL: + return std::make_unique(false, ParseExpr(expr->GetChild(0))); + case parser::ExpressionType::OPERATOR_IS_NULL: + return std::make_unique(true, ParseExpr(expr->GetChild(0))); + default: + throw PARSER_EXCEPTION("PL/pgSQL parser : Expression type not supported"); } - throw PARSER_EXCEPTION("PL/pgSQL parser : Expression type not supported"); } } // namespace noisepage::parser::udf From a1b177b4bc52e181a56eaac77f2c20214c40f1bb Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Mon, 21 Jun 2021 09:55:49 -0400 Subject: [PATCH 044/139] make pointer explicit --- src/execution/compiler/udf/udf_codegen.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/execution/compiler/udf/udf_codegen.cpp b/src/execution/compiler/udf/udf_codegen.cpp index 0d0dff0720..2bad367f74 100644 --- a/src/execution/compiler/udf/udf_codegen.cpp +++ b/src/execution/compiler/udf/udf_codegen.cpp @@ -620,7 +620,7 @@ void UDFCodegen::Visit(ast::udf::SQLStmtAST *ast) { for (auto &col : cols) { execution::ast::Expr *capture_var = codegen_->MakeExpr(str_to_ident_.find(ast->Name())->second); - auto lhs = capture_var; + auto* lhs = capture_var; if (cols.size() > 1) { // Record struct type lhs = codegen_->AccessStructMember(capture_var, codegen_->MakeIdentifier(col.GetName())); From f36b19a6dc7bcf752d9ee3e71306d8cf97a031cd Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Thu, 24 Jun 2021 11:11:05 -0400 Subject: [PATCH 045/139] refactoring in UDF parsing and AST construction, add documentation for all undocumented APIs --- src/execution/compiler/udf/udf_codegen.cpp | 2 +- src/execution/sql/ddl_executors.cpp | 4 +- src/include/binder/binder_sherpa.h | 14 +-- src/include/execution/ast/ast_clone.h | 12 ++- src/include/execution/ast/type.h | 6 +- .../execution/ast/udf/udf_ast_context.h | 7 ++ .../execution/ast/udf/udf_ast_node_visitor.h | 102 ++++++++++++++++++ src/include/execution/ast/udf/udf_ast_nodes.h | 2 +- .../execution/compiler/udf/udf_codegen.h | 5 + .../execution/functions/function_context.h | 31 ++++-- src/include/execution/sql/ddl_executors.h | 2 +- .../expression/constant_value_expression.h | 2 + src/include/parser/udf/udf_parser.h | 96 +++++++++++++++-- src/parser/udf/udf_parser.cpp | 8 +- test/catalog/catalog_test.cpp | 2 +- 15 files changed, 248 insertions(+), 47 deletions(-) diff --git a/src/execution/compiler/udf/udf_codegen.cpp b/src/execution/compiler/udf/udf_codegen.cpp index 2bad367f74..8e0089dace 100644 --- a/src/execution/compiler/udf/udf_codegen.cpp +++ b/src/execution/compiler/udf/udf_codegen.cpp @@ -620,7 +620,7 @@ void UDFCodegen::Visit(ast::udf::SQLStmtAST *ast) { for (auto &col : cols) { execution::ast::Expr *capture_var = codegen_->MakeExpr(str_to_ident_.find(ast->Name())->second); - auto* lhs = capture_var; + auto *lhs = capture_var; if (cols.size() > 1) { // Record struct type lhs = codegen_->AccessStructMember(capture_var, codegen_->MakeIdentifier(col.GetName())); diff --git a/src/execution/sql/ddl_executors.cpp b/src/execution/sql/ddl_executors.cpp index 58b1c7c02e..06aba0c712 100644 --- a/src/execution/sql/ddl_executors.cpp +++ b/src/execution/sql/ddl_executors.cpp @@ -73,8 +73,8 @@ bool DDLExecutors::CreateFunctionExecutor(const common::ManagedPointer ast{}; try { - ast = udf_parser.ParsePLpgSQL(node->GetFunctionParameterNames(), std::move(param_type_ids), body, - (common::ManagedPointer(&udf_ast_context))); + ast = udf_parser.Parse(node->GetFunctionParameterNames(), std::move(param_type_ids), body, + (common::ManagedPointer(&udf_ast_context))); } catch (Exception &e) { return false; } diff --git a/src/include/binder/binder_sherpa.h b/src/include/binder/binder_sherpa.h index d400f96e8a..ee77fa44c8 100644 --- a/src/include/binder/binder_sherpa.h +++ b/src/include/binder/binder_sherpa.h @@ -41,21 +41,13 @@ class BinderSherpa { common::ManagedPointer GetParseResult() const { return parse_result_; } /** - * @return parameters for the query being bound - * @warning can be nullptr if there are no parameters + * @return The parameters for the query being bound + * @warning May be `nullptr` if there are no parameters */ common::ManagedPointer> GetParameters() const { return parameters_; } /** - * Add a parameter to the binder sherpa state. - * @param param The parameter expression. - */ - // void AddParameter(const parser::ConstantValueExpression& param) { - // parameters_->push_back(param); - // desired_parameter_types_->push_back(param.GetReturnValueType()); - // } - - /** + * Get the desired type for the expression. * @param expr The expression whose type constraints we want to look up. * @return The previously recorded type constraints, or the expression's current return value type if none exist. */ diff --git a/src/include/execution/ast/ast_clone.h b/src/include/execution/ast/ast_clone.h index 2e345cf9e2..540a881777 100644 --- a/src/include/execution/ast/ast_clone.h +++ b/src/include/execution/ast/ast_clone.h @@ -9,15 +9,17 @@ namespace noisepage::execution::ast { class AstNode; +/** + * The AstClone class encapsulates the logic necessary to clone an AST. + */ class AstClone { public: /** * Clones an ASTNode and its descendants. - * @param node The root of the AST to clone. - * @param factory The AstNodeFactory instance from which AST nodes are allocated. - * @param prefix The - * @param old_context - * @param new_context + * @param node The root of the AST to clone + * @param factory The AstNodeFactory instance from which AST nodes are allocated + * @param old_context The old AST context + * @param new_context The new AST context * @return */ static AstNode *Clone(AstNode *node, AstNodeFactory *factory, Context *old_context, Context *new_context); diff --git a/src/include/execution/ast/type.h b/src/include/execution/ast/type.h index e9b2986390..108cc4dee5 100644 --- a/src/include/execution/ast/type.h +++ b/src/include/execution/ast/type.h @@ -658,14 +658,12 @@ class FunctionType : public Type { */ bool IsEqual(const FunctionType *other); - /** - * @return `true` if this function is a lambda, `false` otherwise. - */ + /** @return `true` if this function is a lambda, `false` otherwise. */ bool IsLambda() const { return is_lambda_; } /** * Set the lambda disposition for this function. - * @param `true` if this function is a lambda, `false` otherwise. + * @param is_lambda `true` if this function is a lambda, `false` otherwise. */ void SetIsLambda(bool is_lambda) { is_lambda_ = is_lambda; } diff --git a/src/include/execution/ast/udf/udf_ast_context.h b/src/include/execution/ast/udf/udf_ast_context.h index 1e38ee4d9b..2b9cc3c547 100644 --- a/src/include/execution/ast/udf/udf_ast_context.h +++ b/src/include/execution/ast/udf/udf_ast_context.h @@ -12,8 +12,15 @@ namespace execution { namespace ast { namespace udf { +/** + * The UDFASTContext class maintains state that is utilized + * throughout construction of the UDF abstract syntax tree. + */ class UDFASTContext { public: + /** + * Construct a new UDFASTContext. + */ UDFASTContext() = default; /** diff --git a/src/include/execution/ast/udf/udf_ast_node_visitor.h b/src/include/execution/ast/udf/udf_ast_node_visitor.h index 3fd2b739d3..4c24b4a2a2 100644 --- a/src/include/execution/ast/udf/udf_ast_node_visitor.h +++ b/src/include/execution/ast/udf/udf_ast_node_visitor.h @@ -25,27 +25,129 @@ class DynamicSQLStmtAST; class ForStmtAST; class FunctionAST; +/** + * The ASTNodeVisitor class defines the interface for + * visitors of the UDF abstract syntax tree. + */ class ASTNodeVisitor { public: + /** + * Destroy the visitor. + */ virtual ~ASTNodeVisitor() = default; + + /** + * Visit an AbstractAST node. + * @param ast The node to visit + */ virtual void Visit(AbstractAST *ast) = 0; + + /** + * Visit an StmtAST node. + * @param ast The node to visit + */ virtual void Visit(StmtAST *ast) = 0; + + /** + * Visit an ExprAST node. + * @param ast The node to visit + */ virtual void Visit(ExprAST *ast) = 0; + + /** + * Visit an FunctionAST node. + * @param ast The node to visit + */ virtual void Visit(FunctionAST *ast) = 0; + + /** + * Visit an ValueExprAST node. + * @param ast The node to visit + */ virtual void Visit(ValueExprAST *ast) = 0; + + /** + * Visit an VariableExprAST node. + * @param ast The node to visit + */ virtual void Visit(VariableExprAST *ast) = 0; + + /** + * Visit an BinaryExprAST node. + * @param ast The node to visit + */ virtual void Visit(BinaryExprAST *ast) = 0; + + /** + * Visit an IsNullExprAST node. + * @param ast The node to visit + */ virtual void Visit(IsNullExprAST *ast) = 0; + + /** + * Visit an CallExprAST node. + * @param ast The node to visit + */ virtual void Visit(CallExprAST *ast) = 0; + + /** + * Visit an MemberExprAST node. + * @param ast The node to visit + */ virtual void Visit(MemberExprAST *ast) = 0; + + /** + * Visit an SeqStmtAST node. + * @param ast The node to visit + */ virtual void Visit(SeqStmtAST *ast) = 0; + + /** + * Visit an DeclStmtAST node. + * @param ast The node to visit + */ virtual void Visit(DeclStmtAST *ast) = 0; + + /** + * Visit an IfStmtAST node. + * @param ast The node to visit + */ virtual void Visit(IfStmtAST *ast) = 0; + + /** + * Visit an WhileStmtAST node. + * @param ast The node to visit + */ virtual void Visit(WhileStmtAST *ast) = 0; + + /** + * Visit an RetStmtAST node. + * @param ast The node to visit + */ virtual void Visit(RetStmtAST *ast) = 0; + + /** + * Visit an AssignStmtAST node. + * @param ast The node to visit + */ virtual void Visit(AssignStmtAST *ast) = 0; + + /** + * Visit an ForStmtAST node. + * @param ast The node to visit + */ virtual void Visit(ForStmtAST *ast) = 0; + + /** + * Visit an SQLStmtAST node. + * @param ast The node to visit + */ virtual void Visit(SQLStmtAST *ast) = 0; + + /** + * Visit an DynamicSQLStmtAST node. + * @param ast The node to visit + */ virtual void Visit(DynamicSQLStmtAST *ast) = 0; }; diff --git a/src/include/execution/ast/udf/udf_ast_nodes.h b/src/include/execution/ast/udf/udf_ast_nodes.h index 885cdd2496..78f71e2755 100644 --- a/src/include/execution/ast/udf/udf_ast_nodes.h +++ b/src/include/execution/ast/udf/udf_ast_nodes.h @@ -637,7 +637,7 @@ class FunctionAST : public AbstractAST { * Construct a new FunctionAST instance. * @param body The body of the function * @param parameter_names The names of the parameters to the function - * @param parameter_type The types of the parameters to the function + * @param parameter_types The types of the parameters to the function */ FunctionAST(std::unique_ptr &&body, std::vector &¶meter_names, std::vector &¶meter_types) diff --git a/src/include/execution/compiler/udf/udf_codegen.h b/src/include/execution/compiler/udf/udf_codegen.h index 6b913237cf..d2a8da5d2e 100644 --- a/src/include/execution/compiler/udf/udf_codegen.h +++ b/src/include/execution/compiler/udf/udf_codegen.h @@ -44,6 +44,11 @@ class ForStmtAST; namespace compiler { namespace udf { +/** + * The UDFCodegen class implements a visitor for UDF AST nodes + * and encapsulates all of the logic required to generate code + * from the UDF abstract syntax tree. + */ class UDFCodegen : ast::udf::ASTNodeVisitor { public: /** diff --git a/src/include/execution/functions/function_context.h b/src/include/execution/functions/function_context.h index 6f1877e744..171f5c829e 100644 --- a/src/include/execution/functions/function_context.h +++ b/src/include/execution/functions/function_context.h @@ -24,12 +24,12 @@ class FunctionContext { * Creates a FunctionContext object * @param func_name Name of function * @param func_ret_type Return type of function - * @param args_type Vector of argument types + * @param arg_types Vector of argument types */ - FunctionContext(std::string func_name, type::TypeId func_ret_type, std::vector &&args_type) + FunctionContext(std::string func_name, type::TypeId func_ret_type, std::vector &&arg_types) : func_name_(std::move(func_name)), func_ret_type_(func_ret_type), - args_type_(std::move(args_type)), + arg_types_(std::move(arg_types)), is_builtin_{false}, is_exec_ctx_required_{false} {} @@ -37,15 +37,15 @@ class FunctionContext { * Creates a FunctionContext object for a builtin function * @param func_name Name of function * @param func_ret_type Return type of function - * @param args_type Vector of argument types + * @param arg_types Vector of argument types * @param builtin Which builtin this context refers to * @param is_exec_ctx_required true if this function requires an execution context var as its first argument */ - FunctionContext(std::string func_name, type::TypeId func_ret_type, std::vector &&args_type, + FunctionContext(std::string func_name, type::TypeId func_ret_type, std::vector &&arg_types, ast::Builtin builtin, bool is_exec_ctx_required = false) : func_name_(std::move(func_name)), func_ret_type_(func_ret_type), - args_type_(std::move(args_type)), + arg_types_(std::move(arg_types)), is_builtin_{true}, builtin_{builtin}, is_exec_ctx_required_{is_exec_ctx_required} {} @@ -61,12 +61,12 @@ class FunctionContext { * @param is_exec_ctx_required Flag indicating whether an * execution context is required for this function */ - FunctionContext(std::string func_name, type::TypeId func_ret_type, std::vector &&args_type, + FunctionContext(std::string func_name, type::TypeId func_ret_type, std::vector &&arg_types, std::unique_ptr ast_region, std::unique_ptr ast_context, ast::File *file, bool is_exec_ctx_required = true) : func_name_(std::move(func_name)), func_ret_type_(func_ret_type), - args_type_(std::move(args_type)), + arg_types_(std::move(arg_types)), is_builtin_{false}, is_exec_ctx_required_{is_exec_ctx_required}, ast_region_{std::move(ast_region)}, @@ -81,7 +81,7 @@ class FunctionContext { /** * @return The vector of type arguments of the function represented by this context object. */ - const std::vector &GetFunctionArgsType() const { return args_type_; } + const std::vector &GetFunctionArgTypes() const { return arg_types_; } /** * Gets the return type of the function represented by this object. @@ -136,15 +136,26 @@ class FunctionContext { } private: + /** The function name */ std::string func_name_; + /** The function return type */ type::TypeId func_ret_type_; - std::vector args_type_; + /** The types of the arguments to the function */ + std::vector arg_types_; + + /** `true` if this function is a builtin, `false` otherwise */ bool is_builtin_; + /** The builtin function, if applicable */ ast::Builtin builtin_; + /** `true` if an execution context is required for this function, `false` otherwise */ bool is_exec_ctx_required_; + /** The associated AST region */ std::unique_ptr ast_region_; + /** The associated AST context */ std::unique_ptr ast_context_; + + /** The associated file */ ast::File *file_; }; diff --git a/src/include/execution/sql/ddl_executors.h b/src/include/execution/sql/ddl_executors.h index 1300682ed6..30c2b8e2f5 100644 --- a/src/include/execution/sql/ddl_executors.h +++ b/src/include/execution/sql/ddl_executors.h @@ -49,7 +49,7 @@ class DDLExecutors { /** * @param node node to execute - * @param exec_ctx accessor to use for execution + * @param accessor accessor to use for execution * @return `true` if the operation succeeds, `false` otherwise */ static bool CreateFunctionExecutor(common::ManagedPointer node, diff --git a/src/include/parser/expression/constant_value_expression.h b/src/include/parser/expression/constant_value_expression.h index 49d479cd1d..b2c3436c95 100644 --- a/src/include/parser/expression/constant_value_expression.h +++ b/src/include/parser/expression/constant_value_expression.h @@ -109,12 +109,14 @@ class ConstantValueExpression : public AbstractExpression { return Copy(); } + /** Derive the name of the expression if it is not present */ void DeriveExpressionName() override { if (!this->GetAliasName().empty()) { this->SetExpressionName(this->GetAliasName()); } } + /** @return The expression value as a generic SQL value */ common::ManagedPointer GetVal() const { NOISEPAGE_ASSERT(std::holds_alternative(value_), "GetVal() bad variant access"); return common::ManagedPointer(&std::get(value_)); diff --git a/src/include/parser/udf/udf_parser.h b/src/include/parser/udf/udf_parser.h index 52c02b86a4..a6056d1e2c 100644 --- a/src/include/parser/udf/udf_parser.h +++ b/src/include/parser/udf/udf_parser.h @@ -14,45 +14,127 @@ namespace noisepage { -// Forward declaration namespace execution::ast::udf { class FunctionAST; } namespace parser { namespace udf { + /** * Namespace alias to make below more manageable. */ namespace udfexec = execution::ast::udf; +/** + * The PLpgSQLParser class parses source PL/pgSQL to an abstract syntax tree. + */ class PLpgSQLParser { public: + /** + * Construct a new PLpgSQLParser instance. + * @param udf_ast_context The AST context + * @param accessor The accessor to use during parsing + * @param db_oid The database OID + */ PLpgSQLParser(common::ManagedPointer udf_ast_context, const common::ManagedPointer accessor, catalog::db_oid_t db_oid) : udf_ast_context_(udf_ast_context), accessor_(accessor), db_oid_(db_oid) {} - std::unique_ptr ParsePLpgSQL(std::vector &¶m_names, - std::vector &¶m_types, - const std::string &func_body, - common::ManagedPointer ast_context); + + /** + * Parse source PL/pgSQL to an abstract syntax tree. + * @param param_names The names of the function parameters + * @param param_types The types of the function parameters + * @param func_body The input source for the function + * @param ast_context The AST context to use during parsing + * @return The abstract syntax tree for the source function + */ + std::unique_ptr Parse(std::vector &¶m_names, + std::vector &¶m_types, const std::string &func_body, + common::ManagedPointer ast_context); private: + /** + * Parse a block statement. + * @param block The input JSON object + * @return The AST for the block + */ std::unique_ptr ParseBlock(const nlohmann::json &block); - std::unique_ptr ParseFunction(const nlohmann::json &block); + + /** + * Parse a function statement. + * @param block The input JSON object + * @return The AST for the function + */ + std::unique_ptr ParseFunction(const nlohmann::json &function); + + /** + * Parse a declaration statement. + * @param decl The input JSON object + * @return The AST for the declaration + */ std::unique_ptr ParseDecl(const nlohmann::json &decl); + + /** + * Parse an if-statement. + * @param block The input JSON object + * @return The AST for the if-statement + */ std::unique_ptr ParseIf(const nlohmann::json &branch); + + /** + * Parse a while-statement. + * @param block The input JSON object + * @return The AST for the while-statement + */ std::unique_ptr ParseWhile(const nlohmann::json &loop); + + /** + * Parse a for-statement. + * @param block The input JSON object + * @return The AST for the for-statement + */ std::unique_ptr ParseFor(const nlohmann::json &loop); + + /** + * Parse a SQL statement. + * @param sql_stmt The input JSON object + * @return The AST for the SQL statement + */ std::unique_ptr ParseSQL(const nlohmann::json &sql_stmt); + + /** + * Parse a dynamic SQL statement. + * @param block The input JSON object + * @return The AST for the dynamic SQL statement + */ std::unique_ptr ParseDynamicSQL(const nlohmann::json &sql_stmt); + /** + * Parse a SQL expression. + * @param sql The SQL expression string + * @return The AST for the SQL expression + */ std::unique_ptr ParseExprSQL(const std::string &sql); + + /** + * Parse an expression. + * @param expr The expression + * @return The AST for the expression + */ std::unique_ptr ParseExpr(common::ManagedPointer expr); + private: + /** The UDF AST context */ common::ManagedPointer udf_ast_context_; + + /** The catalog accessor */ const common::ManagedPointer accessor_; + + /** The OID for the database with which the function is associated */ catalog::db_oid_t db_oid_; - // common::ManagedPointer sql_parser_; + + /** The function symbol table */ std::unordered_map symbol_table_; }; diff --git a/src/parser/udf/udf_parser.cpp b/src/parser/udf/udf_parser.cpp index 160f488ceb..108a8c3039 100644 --- a/src/parser/udf/udf_parser.cpp +++ b/src/parser/udf/udf_parser.cpp @@ -44,7 +44,7 @@ static constexpr const char K_NAME[] = "name"; static constexpr const char K_PLPGSQL_ROW[] = "PLpgSQL_row"; static constexpr const char K_PLPGSQL_STMT_DYNEXECUTE[] = "PLpgSQL_stmt_dynexecute"; -std::unique_ptr PLpgSQLParser::ParsePLpgSQL( +std::unique_ptr PLpgSQLParser::Parse( std::vector &¶m_names, std::vector &¶m_types, const std::string &func_body, common::ManagedPointer ast_context) { auto result = pg_query_parse_plpgsql(func_body.c_str()); @@ -77,9 +77,9 @@ std::unique_ptr PLpgSQLParser::ParsePLpgSQL( return function_ast; } -std::unique_ptr PLpgSQLParser::ParseFunction(const nlohmann::json &block) { - const auto decl_list = block[K_DATUMS]; - const auto function_body = block[K_ACTION][K_PLPGSQL_STMT_BLOCK][K_BODY]; +std::unique_ptr PLpgSQLParser::ParseFunction(const nlohmann::json &function) { + const auto decl_list = function[K_DATUMS]; + const auto function_body = function[K_ACTION][K_PLPGSQL_STMT_BLOCK][K_BODY]; std::vector> stmts{}; diff --git a/test/catalog/catalog_test.cpp b/test/catalog/catalog_test.cpp index bf44521f72..1a1fc1b059 100644 --- a/test/catalog/catalog_test.cpp +++ b/test/catalog/catalog_test.cpp @@ -166,7 +166,7 @@ TEST_F(CatalogTests, ProcTest) { EXPECT_TRUE(sin_context->IsBuiltin()); EXPECT_EQ(sin_context->GetBuiltin(), execution::ast::Builtin::Sin); EXPECT_EQ(sin_context->GetFunctionReturnType(), type::TypeId::REAL); - auto sin_args = sin_context->GetFunctionArgsType(); + auto sin_args = sin_context->GetFunctionArgTypes(); EXPECT_EQ(sin_args.size(), 1); EXPECT_EQ(sin_args.back(), type::TypeId::REAL); EXPECT_EQ(sin_context->GetFunctionName(), "sin"); From ea6f440f6b8d8594db82a6c9f9915da93284c9d7 Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Thu, 24 Jun 2021 14:56:46 -0400 Subject: [PATCH 046/139] fix censored errors --- src/execution/sql/ddl_executors.cpp | 1 - src/execution/vm/llvm_engine.cpp | 3 ++- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/execution/sql/ddl_executors.cpp b/src/execution/sql/ddl_executors.cpp index 06aba0c712..0b56f34cdd 100644 --- a/src/execution/sql/ddl_executors.cpp +++ b/src/execution/sql/ddl_executors.cpp @@ -112,7 +112,6 @@ bool DDLExecutors::CreateFunctionExecutor(const common::ManagedPointerReset(); if (type_check.Run(file)) { EXECUTION_LOG_ERROR("Errors: \n {}", type_check.GetErrorReporter()->SerializeErrors()); - execution::ast::AstPrettyPrint::Dump(std::cout, file); return false; } } diff --git a/src/execution/vm/llvm_engine.cpp b/src/execution/vm/llvm_engine.cpp index 3796a156f2..ca324af295 100644 --- a/src/execution/vm/llvm_engine.cpp +++ b/src/execution/vm/llvm_engine.cpp @@ -657,7 +657,8 @@ void LLVMEngine::CompiledModuleBuilder::BuildSimpleCFG(const FunctionInfo &func_ void LLVMEngine::CompiledModuleBuilder::DefineFunction(const FunctionInfo &func_info, llvm::IRBuilder<> *ir_builder) { llvm::LLVMContext &ctx = ir_builder->getContext(); llvm::Function *func = llvm_module_->getFunction(func_info.GetName()); - if (func->getName().str().find("inline") != std::string::npos) { + // The line below is flagged by `check-censored` target because of 'inline' + if (func->getName().str().find("inline") != std::string::npos) { // NOLINT func->setLinkage(llvm::Function::LinkOnceAnyLinkage); func->addFnAttr(llvm::Attribute::AlwaysInline); } From dc53b416a1e9ccef17ec190e0017e59d74c57c4e Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Fri, 25 Jun 2021 10:11:57 -0400 Subject: [PATCH 047/139] fix broken tpl test --- util/execution/tpl.cpp | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/util/execution/tpl.cpp b/util/execution/tpl.cpp index fb160d9528..a484b6ff88 100644 --- a/util/execution/tpl.cpp +++ b/util/execution/tpl.cpp @@ -98,13 +98,17 @@ static void CompileAndRun(const std::string &source, const std::string &name = " db_oid, common::ManagedPointer(txn), callback, output_schema, common::ManagedPointer(accessor), exec_settings, db_main->GetMetricsManager(), DISABLED, DISABLED}; // Add dummy parameters for tests - std::vector params; - params.emplace_back(type::TypeId::INTEGER, sql::Integer(37)); - params.emplace_back(type::TypeId::REAL, sql::Real(37.73)); - params.emplace_back(type::TypeId::DATE, sql::DateVal(sql::Date::FromYMD(1937, 3, 7))); + std::vector params_builder{}; + params_builder.emplace_back(type::TypeId::INTEGER, sql::Integer(37)); + params_builder.emplace_back(type::TypeId::REAL, sql::Real(37.73)); + params_builder.emplace_back(type::TypeId::DATE, sql::DateVal(sql::Date::FromYMD(1937, 3, 7))); auto string_val = sql::ValueUtil::CreateStringVal(std::string_view("37 Strings")); - params.emplace_back(type::TypeId::VARCHAR, string_val.first, std::move(string_val.second)); - exec_ctx.SetParams(common::ManagedPointer>(¶ms)); + params_builder.emplace_back(type::TypeId::VARCHAR, string_val.first, std::move(string_val.second)); + + std::vector> params{}; + std::transform(params_builder.cbegin(), params_builder.cend(), std::back_inserter(params), + [](const parser::ConstantValueExpression &expr) { return common::ManagedPointer{expr.SqlValue()}; }); + exec_ctx.SetParams(common::ManagedPointer{¶ms}); // Generate test tables sql::TableGenerator table_generator{&exec_ctx, db_main->GetStorageLayer()->GetBlockStore(), ns_oid}; From 35f8302a3c04dcd77e451b5e4c740ba39de0d4fa Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Fri, 25 Jun 2021 10:41:22 -0400 Subject: [PATCH 048/139] revisiting tpl tests, crashing on basic usage of tpl lambdas --- sample_tpl/call-lambda.tpl | 43 -------------------------------------- sample_tpl/lambda0.tpl | 9 ++++++++ sample_tpl/lambda1.tpl | 10 +++++++++ sample_tpl/tpl_tests.txt | 3 ++- 4 files changed, 21 insertions(+), 44 deletions(-) delete mode 100644 sample_tpl/call-lambda.tpl create mode 100644 sample_tpl/lambda0.tpl create mode 100644 sample_tpl/lambda1.tpl diff --git a/sample_tpl/call-lambda.tpl b/sample_tpl/call-lambda.tpl deleted file mode 100644 index 76f088d12a..0000000000 --- a/sample_tpl/call-lambda.tpl +++ /dev/null @@ -1,43 +0,0 @@ -// Expected output: 70 - -fun f(z : Date ) -> Date { return z } - -fun main(exec : *ExecutionContext) -> int32 { - var y = 11 - var lam = lambda [y] (z: Integer ) -> nil { - y = y + z - } - lam(10) - - - var d = @dateToSql(1999, 2, 11) - //f(lam, d) - var k : Date - //var h = &k - //*h = d - //k = f(d) - lam(d) - if(@datePart(y, @intToSql(21)) == @intToSql(1999)){ - // good - return 1 - } - return 0 -} - -fun pipeline1(QueryState *q) { - TableIterator tvi; - for(@tableIteratorAdvance(&tvi)){ - @hashTableInsert(q.join_ht, @getTupleValue(&tvi, 3))) - } -} - -fun pipeline2(QueryState *q) { - TableIterator tvi; - for(@tableIteratorAdvance(&tvi)){ - var o_custkey = @getTupleValue(&tvi, 1) - if(@hashTableKeyExists(q.join_ht, o_custkey)){ - var out = @outputBufferAlloc(q.output_buff) - out.col1 = o_custkey + 1 - } - } -} \ No newline at end of file diff --git a/sample_tpl/lambda0.tpl b/sample_tpl/lambda0.tpl new file mode 100644 index 0000000000..c05499dc9e --- /dev/null +++ b/sample_tpl/lambda0.tpl @@ -0,0 +1,9 @@ +// Expected output: 2 + +fun main(exec : *ExecutionContext) -> int32 { + // Lambda without capture + var addOne = lambda [] (x: int32) -> int32 { + return x + 1 + } + return addOne(1) +} \ No newline at end of file diff --git a/sample_tpl/lambda1.tpl b/sample_tpl/lambda1.tpl new file mode 100644 index 0000000000..fc3f3505de --- /dev/null +++ b/sample_tpl/lambda1.tpl @@ -0,0 +1,10 @@ +// Expected output: 3 + +fun main(exec : *ExecutionContext) -> int32 { + var x = 1 + var addValue = lambda [x] (y: int32) -> int32 { + x = x + y + } + addValue(2) + return x +} \ No newline at end of file diff --git a/sample_tpl/tpl_tests.txt b/sample_tpl/tpl_tests.txt index 53231cc416..c6a8f22d57 100644 --- a/sample_tpl/tpl_tests.txt +++ b/sample_tpl/tpl_tests.txt @@ -10,7 +10,8 @@ array.tpl,false,44 array-iterate.tpl,false,110 array-iterate-2.tpl,false,110 call.tpl,false,70 -#call-lambda.tpl,false,70 TODO(Kyle): Requires lambdas +#lambda0.tpl,false,2 +#lambda1.tpl,false,3 comments.tpl,false,46 compare.tpl,false,200 date-functions.tpl,false,0 From 3590422e0d07991d750854ad312fb7065b5b0661 Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Fri, 25 Jun 2021 10:52:58 -0400 Subject: [PATCH 049/139] fix broken benchmark runner --- benchmark/runner/execution_runners.cpp | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/benchmark/runner/execution_runners.cpp b/benchmark/runner/execution_runners.cpp index 11d1e1eb8d..d9069a881c 100644 --- a/benchmark/runner/execution_runners.cpp +++ b/benchmark/runner/execution_runners.cpp @@ -528,7 +528,7 @@ class ExecutionRunners : public benchmark::Fixture { execution::exec::ExecutionSettings *exec_settings_arg = nullptr) { transaction::TransactionContext *txn = nullptr; std::unique_ptr accessor = nullptr; - std::vector> param_ref = *params; + const auto ¶ms_ref = *params; execution::exec::NoOpResultConsumer consumer; execution::exec::OutputCallback callback = consumer; @@ -552,8 +552,12 @@ class ExecutionRunners : public benchmark::Fixture { metrics_manager, DISABLED, DISABLED); // Attach params to ExecutionContext - if (static_cast(i) < param_ref.size()) { - exec_ctx->SetParams(common::ManagedPointer>(¶m_ref[i])); + if (static_cast(i) < params_ref.size()) { + std::vector> p{}; + std::transform( + params_ref[i].cbegin(), params_ref[i].cend(), std::back_inserter(p), + [](const parser::ConstantValueExpression &expr) { return common::ManagedPointer{expr.SqlValue()}; }); + exec_ctx->SetParams(common::ManagedPointer{&p}); } exec_query->Run(common::ManagedPointer(exec_ctx), mode); From bc6df12ed3629573116c7a69294dcd7e3014b884 Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Fri, 25 Jun 2021 14:51:05 -0400 Subject: [PATCH 050/139] fix FunctionBuilder that had unread private member --- .../execution/compiler/function_builder.h | 23 ++++++++++--------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/src/include/execution/compiler/function_builder.h b/src/include/execution/compiler/function_builder.h index 5ae3f20513..1fe63aadc6 100644 --- a/src/include/execution/compiler/function_builder.h +++ b/src/include/execution/compiler/function_builder.h @@ -94,27 +94,28 @@ class FunctionBuilder { */ ast::LambdaExpr *GetConstructedLambda() const { return std::get(decl_); } - /** - * @return The code generator instance. - */ + /** @return The code generator instance. */ CodeGen *GetCodeGen() const { return codegen_; } + /** @return `true` if the function represents a lambda, `false` otherwise. */ + bool IsLambda() const { return is_lambda_; } + private: - // The code generation instance. + /** The code generation instance */ CodeGen *codegen_; - // The function's name. + /** The function's name */ ast::Identifier name_; - // The function's arguments. + /** The function's arguments */ util::RegionVector params_; - // The return type of the function. + /** The return type of the function */ ast::Expr *ret_type_; - // The start and stop position of statements in the function. + /** The start and stop position of statements in the function */ SourcePosition start_; - // The list of generated statements making up the function. + /** The list of generated statements making up the function */ ast::BlockStmt *statements_; - // `true` if this function is a lambda, `false` otherwise. + /** `true` if this function is a lambda, `false` otherwise */ bool is_lambda_; - // The cached function declaration. Constructed once in Finish(). + /** The cached function declaration. Constructed once in Finish() */ std::variant decl_; }; From 8de24e8c906698b0fd6930ceec2db1c79a8112cc Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Sat, 26 Jun 2021 07:22:30 -0400 Subject: [PATCH 051/139] add top-level command for CREATE FUNCTION feedback --- src/network/postgres/postgres_packet_writer.cpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/network/postgres/postgres_packet_writer.cpp b/src/network/postgres/postgres_packet_writer.cpp index abd5081965..cbd693b480 100644 --- a/src/network/postgres/postgres_packet_writer.cpp +++ b/src/network/postgres/postgres_packet_writer.cpp @@ -162,6 +162,9 @@ void PostgresPacketWriter::WriteCommandComplete(const QueryType query_type, cons case QueryType::QUERY_CREATE_SCHEMA: WriteCommandComplete("CREATE SCHEMA"); break; + case QueryType::QUERY_CREATE_FUNCTION: + WriteCommandComplete("CREATE FUNCTION"); + break; case QueryType::QUERY_DROP_DB: WriteCommandComplete("DROP DATABASE"); break; From f6751402322931f56ea6848ca77607e34313e681 Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Mon, 28 Jun 2021 17:35:53 -0400 Subject: [PATCH 052/139] wip refactoring execution context to resolve lifetime management for query parameters --- benchmark/runner/execution_runners.cpp | 119 +++- src/execution/exec/execution_context.cpp | 21 +- .../execution/compiler/executable_query.h | 2 +- .../execution/exec/execution_context.h | 587 +++++++++++++----- src/include/execution/exec/output.h | 3 + src/include/execution/vm/bytecode_handlers.h | 10 +- .../vm/{vm_defs.h => execution_mode.h} | 0 src/include/execution/vm/module.h | 2 +- src/include/traffic_cop/traffic_cop.h | 2 +- src/parser/udf/udf_parser.cpp | 10 +- src/util/query_exec_util.cpp | 2 +- test/execution/atomics_test.cpp | 2 +- test/test_util/tpch/workload.cpp | 35 +- 13 files changed, 571 insertions(+), 224 deletions(-) rename src/include/execution/vm/{vm_defs.h => execution_mode.h} (100%) diff --git a/benchmark/runner/execution_runners.cpp b/benchmark/runner/execution_runners.cpp index d9069a881c..7a2cc8a2d3 100644 --- a/benchmark/runner/execution_runners.cpp +++ b/benchmark/runner/execution_runners.cpp @@ -444,9 +444,18 @@ class ExecutionRunners : public benchmark::Fixture { exec_settings = *exec_settings_arg; } - auto exec_ctx = std::make_unique( - db_oid, common::ManagedPointer(txn), execution::exec::NoOpResultConsumer(), out_plan->GetOutputSchema().Get(), - common::ManagedPointer(accessor), exec_settings, metrics_manager_, DISABLED, DISABLED); + auto exec_ctx = execution::exec::ExecutionContextBuilder() + .WithDatabaseOID(db_oid) + .WithExecutionMode(ExecutionRunners::mode) + .WithExecutionSettings(exec_settings) + .WithTxnContext(common::ManagedPointer{txn}) + .WithOutputCallback(execution::exec::NoOpResultConsumer{}) + .WithOutputSchema(out_plan->GetOutputSchema()) + .WithCatalogAccessor(common::ManagedPointer{accessor}) + .WithMetricsManager(metrics_manager_) + .WithReplicationManager(DISABLED) + .WithRecoveryManager(DISABLED) + .Build(); execution::compiler::ExecutableQuery::query_identifier.store(ExecutionRunners::query_id++); auto exec_query = execution::compiler::CompilationContext::Compile(*out_plan, exec_settings, accessor.get(), @@ -489,9 +498,17 @@ class ExecutionRunners : public benchmark::Fixture { auto txn = txn_manager->BeginTransaction(); auto accessor = catalog->GetAccessor(common::ManagedPointer(txn), db_oid, DISABLED); auto exec_settings = ExecutionRunners::GetExecutionSettings(); - auto exec_ctx = std::make_unique( - db_oid, common::ManagedPointer(txn), nullptr, nullptr, common::ManagedPointer(accessor), exec_settings, - metrics_manager_, DISABLED, DISABLED); + + auto exec_ctx = execution::exec::ExecutionContextBuilder() + .WithDatabaseOID(db_oid) + .WithExecutionMode(ExecutionRunners::mode) + .WithExecutionSettings(exec_settings) + .WithTxnContext(common::ManagedPointer{txn}) + .WithCatalogAccessor(common::ManagedPointer{accessor}) + .WithMetricsManager(metrics_manager_) + .WithReplicationManager(DISABLED) + .WithRecoveryManager(DISABLED) + .Build(); execution::sql::TableGenerator table_generator(exec_ctx.get(), block_store, accessor->GetDefaultNamespace()); if (is_build) { @@ -547,18 +564,28 @@ class ExecutionRunners : public benchmark::Fixture { exec_settings = *exec_settings_arg; } - auto exec_ctx = std::make_unique( - db_oid, common::ManagedPointer(txn), callback, out_schema, common::ManagedPointer(accessor), exec_settings, - metrics_manager, DISABLED, DISABLED); - - // Attach params to ExecutionContext - if (static_cast(i) < params_ref.size()) { - std::vector> p{}; - std::transform( - params_ref[i].cbegin(), params_ref[i].cend(), std::back_inserter(p), - [](const parser::ConstantValueExpression &expr) { return common::ManagedPointer{expr.SqlValue()}; }); - exec_ctx->SetParams(common::ManagedPointer{&p}); + // auto exec_ctx = std::make_unique( + // db_oid, common::ManagedPointer(txn), callback, out_schema, common::ManagedPointer(accessor), exec_settings, + // metrics_manager, DISABLED, DISABLED); + + // TODO(Kyle): This makes an unnecessary copy of the query parameters + std::vector parameters{}; + if (static_cast(i) < params_ref.size()) { + std::copy(params_ref[i].cbegin(), params_ref[i].cend(), std::back_inserter(parameters)); } + auto exec_ctx = execution::exec::ExecutionContextBuilder() + .WithDatabaseOID(db_oid) + .WithQueryParametersFrom(parameters) + .WithExecutionMode(ExecutionRunners::mode) + .WithExecutionSettings(exec_settings) + .WithTxnContext(common::ManagedPointer{txn}) + .WithOutputSchema(common::ManagedPointer{out_schema}) + .WithOutputCallback(std::move(callback)) + .WithCatalogAccessor(common::ManagedPointer{accessor}) + .WithMetricsManager(metrics_manager_) + .WithReplicationManager(DISABLED) + .WithRecoveryManager(DISABLED) + .Build(); exec_query->Run(common::ManagedPointer(exec_ctx), mode); @@ -580,10 +607,16 @@ class ExecutionRunners : public benchmark::Fixture { auto txn = txn_manager_->BeginTransaction(); auto accessor = catalog_->GetAccessor(common::ManagedPointer(txn), db_oid, DISABLED); auto exec_settings = GetExecutionSettings(); - auto exec_ctx = std::make_unique( - db_oid, common::ManagedPointer(txn), nullptr, nullptr, common::ManagedPointer(accessor), exec_settings, - metrics_manager_, DISABLED, DISABLED); - exec_ctx->SetExecutionMode(static_cast(mode)); + auto exec_ctx = execution::exec::ExecutionContextBuilder() + .WithDatabaseOID(db_oid) + .WithExecutionMode(ExecutionRunners::mode) + .WithExecutionSettings(exec_settings) + .WithTxnContext(common::ManagedPointer{txn}) + .WithCatalogAccessor(common::ManagedPointer{accessor}) + .WithMetricsManager(metrics_manager_) + .WithReplicationManager(DISABLED) + .WithRecoveryManager(DISABLED) + .Build(); selfdriving::PipelineOperatingUnits units; selfdriving::ExecutionOperatingUnitFeatureVector pipe0_vec; @@ -942,9 +975,19 @@ BENCHMARK_DEFINE_F(ExecutionRunners, SEQ0_OutputRunners)(benchmark::State &state execution::compiler::ExecutableQuery::query_identifier.store(ExecutionRunners::query_id++); execution::exec::NoOpResultConsumer consumer; execution::exec::OutputCallback callback = consumer; - auto exec_ctx = std::make_unique( - db_oid, common::ManagedPointer(txn), callback, schema.get(), common::ManagedPointer(accessor), exec_settings, - metrics_manager_, DISABLED, DISABLED); + + auto exec_ctx = execution::exec::ExecutionContextBuilder() + .WithDatabaseOID(db_oid) + .WithExecutionMode(ExecutionRunners::mode) + .WithExecutionSettings(exec_settings) + .WithTxnContext(common::ManagedPointer{txn}) + .WithOutputSchema(common::ManagedPointer{schema}) + .WithOutputCallback(std::move(callback)) + .WithCatalogAccessor(common::ManagedPointer{accessor}) + .WithMetricsManager(metrics_manager_) + .WithReplicationManager(DISABLED) + .WithRecoveryManager(DISABLED) + .Build(); auto exec_query = execution::compiler::ExecutableQuery(output.str(), common::ManagedPointer(exec_ctx), false, 16, exec_settings); @@ -1008,9 +1051,18 @@ void ExecutionRunners::ExecuteIndexOperation(benchmark::State *state, bool is_in auto exec_settings = GetExecutionSettings(); execution::exec::NoOpResultConsumer consumer; execution::exec::OutputCallback callback = consumer; - auto exec_ctx = std::make_unique( - db_oid, common::ManagedPointer(txn), callback, nullptr, common::ManagedPointer(accessor), exec_settings, - metrics_manager, DISABLED, DISABLED); + + auto exec_ctx = execution::exec::ExecutionContextBuilder() + .WithDatabaseOID(db_oid) + .WithExecutionMode(ExecutionRunners::mode) + .WithExecutionSettings(exec_settings) + .WithTxnContext(common::ManagedPointer{txn}) + .WithOutputCallback(std::move(callback)) + .WithCatalogAccessor(common::ManagedPointer{accessor}) + .WithMetricsManager(metrics_manager_) + .WithReplicationManager(DISABLED) + .WithRecoveryManager(DISABLED) + .Build(); // A brief discussion of the features: // NUM_ROWS: size of the index @@ -2032,9 +2084,16 @@ void InitializeRunnersState() { // Load the database auto accessor = catalog->GetAccessor(common::ManagedPointer(txn), db_oid, DISABLED); auto exec_settings = ExecutionRunners::GetExecutionSettings(); - auto exec_ctx = std::make_unique( - db_oid, common::ManagedPointer(txn), nullptr, nullptr, common::ManagedPointer(accessor), exec_settings, - db_main->GetMetricsManager(), DISABLED, DISABLED); + auto exec_ctx = execution::exec::ExecutionContextBuilder() + .WithDatabaseOID(db_oid) + .WithExecutionMode(ExecutionRunners::mode) + .WithExecutionSettings(exec_settings) + .WithTxnContext(common::ManagedPointer{txn}) + .WithCatalogAccessor(common::ManagedPointer{accessor}) + .WithMetricsManager(db_main->GetMetricsManager()) + .WithReplicationManager(DISABLED) + .WithRecoveryManager(DISABLED) + .Build(); execution::sql::TableGenerator table_gen(exec_ctx.get(), block_store, accessor->GetDefaultNamespace()); table_gen.GenerateExecutionRunnersData(settings, config); diff --git a/src/execution/exec/execution_context.cpp b/src/execution/exec/execution_context.cpp index 3efcf58cf9..769874dee0 100644 --- a/src/execution/exec/execution_context.cpp +++ b/src/execution/exec/execution_context.cpp @@ -14,6 +14,18 @@ namespace noisepage::execution::exec { +std::unique_ptr ExecutionContextBuilder::Build() { + NOISEPAGE_ASSERT(db_oid_ != INVALID_DATABASE_OID, "Must specify database OID."); + NOISEPAGE_ASSERT(exec_mode_.has_value(), "Must specify execution mode."); + NOISEPAGE_ASSERT(exec_settings_.has_value(), "Must specify execution setting."); + NOISEPAGE_ASSERT(static_cast(catalog_accessor_), "Must specify catalog accessor."); + // MetricsManager, ReplicationManager, and RecoveryManaged may be DISABLED + return std::make_unique(db_oid_, std::move(parameters_), exec_mode_.value(), + std::move(exec_settings_.value()), txn_, output_schema_, + std::move(output_callback_.value()), catalog_accessor_, metrics_manager_, + replication_manager_, recovery_manager_); +} + OutputBuffer *ExecutionContext::OutputBufferNew() { if (schema_ == nullptr) { return nullptr; @@ -26,7 +38,7 @@ OutputBuffer *ExecutionContext::OutputBufferNew() { return buffer; } -uint32_t ExecutionContext::ComputeTupleSize(const planner::OutputSchema *schema) { +uint32_t ExecutionContext::ComputeTupleSize(common::ManagedPointer schema) { uint32_t tuple_size = 0; for (const auto &col : schema->GetColumns()) { auto alignment = sql::ValUtil::GetSqlAlignment(col.GetType()); @@ -106,7 +118,8 @@ void ExecutionContext::EndResourceTracker(const char *name, uint32_t len) { common::thread_context.resource_tracker_.Stop(); common::thread_context.resource_tracker_.SetMemory(mem_tracker_->GetAllocatedSize()); const auto &resource_metrics = common::thread_context.resource_tracker_.GetMetrics(); - common::thread_context.metrics_store_->RecordExecutionData(name, len, execution_mode_, resource_metrics); + common::thread_context.metrics_store_->RecordExecutionData(name, len, static_cast(execution_mode_), + resource_metrics); } } @@ -147,8 +160,8 @@ void ExecutionContext::EndPipelineTracker(query_id_t query_id, pipeline_id_t pip NOISEPAGE_ASSERT(pipeline_id == ouvec->pipeline_id_, "Incorrect feature vector pipeline id?"); selfdriving::ExecutionOperatingUnitFeatureVector features(ouvec->pipeline_features_->begin(), ouvec->pipeline_features_->end()); - common::thread_context.metrics_store_->RecordPipelineData(query_id, pipeline_id, execution_mode_, - std::move(features), resource_metrics); + common::thread_context.metrics_store_->RecordPipelineData( + query_id, pipeline_id, static_cast(execution_mode_), std::move(features), resource_metrics); } } diff --git a/src/include/execution/compiler/executable_query.h b/src/include/execution/compiler/executable_query.h index 605f56c97a..4e4a8e434a 100644 --- a/src/include/execution/compiler/executable_query.h +++ b/src/include/execution/compiler/executable_query.h @@ -9,7 +9,7 @@ #include "common/managed_pointer.h" #include "execution/ast/ast_fwd.h" #include "execution/exec_defs.h" -#include "execution/vm/vm_defs.h" +#include "execution/vm/execution_mode.h" namespace noisepage { namespace selfdriving { diff --git a/src/include/execution/exec/execution_context.h b/src/include/execution/exec/execution_context.h index b22a374220..582251c693 100644 --- a/src/include/execution/exec/execution_context.h +++ b/src/include/execution/exec/execution_context.h @@ -1,9 +1,12 @@ #pragma once #include +#include +#include #include #include +#include "catalog/catalog_defs.h" #include "common/managed_pointer.h" #include "execution/exec/execution_settings.h" #include "execution/exec/output.h" @@ -13,6 +16,7 @@ #include "execution/sql/thread_state_container.h" #include "execution/sql/value.h" #include "execution/util/region.h" +#include "execution/vm/execution_mode.h" #include "metrics/metrics_defs.h" #include "planner/plannodes/output_schema.h" #include "self_driving/modeling/operating_unit.h" @@ -43,10 +47,12 @@ class RecoveryManager; } // namespace noisepage::storage namespace noisepage::execution::exec { + class ExecutionSettings; +class ExecutionContextBuilder; + /** - * Execution Context: Stores information handed in by upper layers. - * TODO(Amadou): This class will change once we know exactly what we get from upper layers. + * The ExecutionContext class stores information handed in by upper layers. */ class EXPORT ExecutionContext { public: @@ -76,75 +82,93 @@ class EXPORT ExecutionContext { */ using HookFn = void (*)(void *, void *, void *); + /* -------------------------------------------------------------------------- + Getters / Setters + -------------------------------------------------------------------------- */ + + /** @return The identifier for the associated query */ + execution::query_id_t GetQueryId() { return query_id_; } + /** - * Constructor - * @param db_oid oid of the database - * @param txn transaction used by this query - * @param callback callback function for outputting - * @param schema the schema of the output - * @param accessor the catalog accessor of this query - * @param exec_settings The execution settings to run with. - * @param metrics_manager The metrics manager for recording metrics - * @param replication_manager The replication manager to handle communication between primary and replicas. - * @param recovery_manager The recovery manager that handles both recovery and application of replication records. - */ - ExecutionContext(catalog::db_oid_t db_oid, common::ManagedPointer txn, - const OutputCallback &callback, const planner::OutputSchema *schema, - const common::ManagedPointer accessor, - const exec::ExecutionSettings &exec_settings, - common::ManagedPointer metrics_manager, - common::ManagedPointer replication_manager, - common::ManagedPointer recovery_manager) - : exec_settings_(exec_settings), - db_oid_(db_oid), - txn_(txn), - mem_tracker_(std::make_unique()), - mem_pool_(std::make_unique(common::ManagedPointer(mem_tracker_))), - schema_(schema), - callback_(callback), - thread_state_container_(std::make_unique(mem_pool_.get())), - accessor_(accessor), - metrics_manager_(metrics_manager), - replication_manager_(replication_manager), - recovery_manager_(recovery_manager) {} - - /** - * @return the transaction used by this query + * Set the current executing query identifier. + * @param query_id The query identifier */ + void SetQueryId(execution::query_id_t query_id) { query_id_ = query_id; } + + /** @return The database OID. */ + catalog::db_oid_t DBOid() { return db_oid_; } + + /** @return The transaction associated with this execution context */ common::ManagedPointer GetTxn() { return txn_; } + /** @return The execution settings. */ + const exec::ExecutionSettings &GetExecutionSettings() const { return execution_settings_; } + + /** @return The catalog accessor associated with this execution context */ + catalog::CatalogAccessor *GetAccessor() { return accessor_.Get(); } + + /** @return The metrics manager associated with this execution context */ + common::ManagedPointer GetMetricsManager() { return metrics_manager_; } + + /** @return The memory pool for this execution context */ + sql::MemoryPool *GetMemoryPool() { return mem_pool_.get(); } + + /** @return The thread state container */ + sql::ThreadStateContainer *GetThreadStateContainer() { return thread_state_container_.get(); } + + /** @return The string allocator for this execution context */ + sql::VarlenHeap *GetStringAllocator() { return &string_allocator_; } + + /** @return The pipeline operating units for the execution context */ + common::ManagedPointer GetPipelineOperatingUnits() { + return pipeline_operating_units_; + } + /** - * Constructs a new Output Buffer for outputting query results to consumers - * @return newly created output buffer + * Set the pipeline operating units for the execution context. + * @param op pipeline operating units for executing the query */ - OutputBuffer *OutputBufferNew(); + void SetPipelineOperatingUnits(common::ManagedPointer op) { + pipeline_operating_units_ = op; + } + + /** @return The number of rows affected by the current execution, e.g., INSERT/DELETE/UPDATE. */ + uint32_t GetRowsAffected() const { return rows_affected_; } /** - * @return The thread state container. + * Increment or decrement the number of rows affected. + * @param num_rows The delta for the number of rows affected */ - sql::ThreadStateContainer *GetThreadStateContainer() { return thread_state_container_.get(); } + void AddRowsAffected(int64_t num_rows) { rows_affected_ += num_rows; } /** - * @return the memory pool + * Overrides recording from memory tracker. + * NOTE: This should never be used by parallel threads directly + * @param memory_use Correct memory value to record */ - sql::MemoryPool *GetMemoryPool() { return mem_pool_.get(); } + void SetMemoryUseOverride(uint32_t memory_use) { + memory_use_override_ = true; + memory_use_override_value_ = memory_use; + } /** - * @return the string allocator + * Sets the estimated concurrency of a parallel operation. + * This value is used when initializing an ExecOUFeatureVector + * + * @note this value is reset by setting it to 0. + * @param estimate Estimated number of concurrent tasks */ - sql::VarlenHeap *GetStringAllocator() { return &string_allocator_; } + void SetNumConcurrentEstimate(uint32_t estimate) { num_concurrent_estimate_ = estimate; } /** - * @param schema the schema of the output - * @return the size of tuple with this final_schema + * Sets the opaque query state pointer for the current query invocation. + * @param query_state QueryState */ - static uint32_t ComputeTupleSize(const planner::OutputSchema *schema); - - /** @return The catalog accessor. */ - catalog::CatalogAccessor *GetAccessor() { return accessor_.Get(); } + void SetQueryState(void *query_state) { query_state_ = query_state; } - /** @return The execution settings. */ - const exec::ExecutionSettings &GetExecutionSettings() const { return exec_settings_; } + /* -------------------------------------------------------------------------- + Resource Metrics Collection + -------------------------------------------------------------------------- */ /** Start the resource tracker. */ void StartResourceTracker(metrics::MetricsComponent component); @@ -184,76 +208,81 @@ class EXPORT ExecutionContext { */ void InitializeParallelOUFeatureVector(selfdriving::ExecOUFeatureVector *ouvec, pipeline_id_t pipeline_id); - /** Initialize the UDF parameter stack. */ - void StartParams() { udf_param_stack_.emplace_back(); } + /* -------------------------------------------------------------------------- + Runtime Parameters (User-Defined Functions) + -------------------------------------------------------------------------- */ - /** Remove an element from the UDF parameter stack. */ - void PopParams() { udf_param_stack_.pop_back(); } + /** Initialize a new, empty collection of parameters at the top of the parameter stack */ + void StartParams() { runtime_parameters_.emplace(); } + + /** Remove the topmost collection of parameters from the parameter stack */ + void FinishParams() { + NOISEPAGE_ASSERT(!runtime_parameters_.empty(), "Attempt to pop from empty runtime parameter stack."); + runtime_parameters_.pop(); + } /** - * Add a parameter to the set of parameters at the top of the UDF parameter stack. + * Add a runtime parameter to the "top-most" collection of runtime parameters. * @param val The parameter to be added */ void AddParam(common::ManagedPointer val) { - udf_param_stack_.back().push_back(val.CastManagedPointerTo()); + NOISEPAGE_ASSERT(!runtime_parameters_.empty(), "Must call StartParams() prior to adding runtime parameters."); + runtime_parameters_.top().push_back(val.CastManagedPointerTo()); } - /** @return The database OID. */ - catalog::db_oid_t DBOid() { return db_oid_; } - - /** - * Set the mode for this execution. - * This only records the mode and serves the metrics collection purpose, - * which does not have any impact on the actual execution. - * @param mode the integer value of the execution mode to record - */ - void SetExecutionMode(uint8_t mode) { execution_mode_ = mode; } - - /** - * Set the accessor - * @param accessor The catalog accessor. - */ - void SetAccessor(const common::ManagedPointer accessor) { accessor_ = accessor; } - /** - * Set the execution parameters. - * @param params The execution parameters + * Add a runtime parameter to the "top-most" collection of runtime parameters. + * @param val The parameter to be added */ - void SetParams(common::ManagedPointer>> params) { - params_ = params; + void AddParam(common::ManagedPointer val) { + NOISEPAGE_ASSERT(!runtime_parameters_.empty(), "Must call StartParams() prior to adding runtime parameters."); + runtime_parameters_.top().push_back(val); } + // /** + // * Set the execution parameters for the query. + // * @param params The execution parameters + // * + // * NOTE: The use of a ManagedPointer for this API denotes that + // * the ExecutionContext instance does not own the underlying + // * collection of query parameters; the caller is responsible for + // * ensuring the query parameters survive until query execution + // * is complete. + // */ + // void SetParams(common::ManagedPointer>> params) { + // NOISEPAGE_ASSERT(!static_cast(params_), "Attempt to set query execution parameters multiple times."); + // params_ = params; + // } + /** * Get the parameter at the specified index. - * @param param_idx index of parameter to access + * @param index index of parameter to access * @return An immutable point to the parameter at specified index */ - common::ManagedPointer GetParam(std::size_t param_idx) const { - // TODO(Kyle): This logic is confusing, why are we transparently - // switching between the "regular" parameters and the UDF parameters? - return udf_param_stack_.empty() ? (*params_)[param_idx] : udf_param_stack_.back()[param_idx]; + common::ManagedPointer GetParam(uint32_t index) const { + // Always get the query parameter from the "top-most" collection + // of parameters; if the runtime parameters stack is empty, default + // to the "base" set of parameters for the query, otherwise, grab + // the parameter at the specified index from the top of the runtime + // parameters stack. + if (runtime_parameters_.empty()) { + NOISEPAGE_ASSERT(index < parameters_.size(), "ExecutionContext::GetParam() index out of range"); + return parameters_[index]; + } else { + NOISEPAGE_ASSERT(index < runtime_parameters_.top().size(), "ExecutionContext::GetParam() index out of range."); + return runtime_parameters_.top()[index]; + } } - /** - * Set the PipelineOperatingUnits - * @param op PipelineOperatingUnits for executing the given query - */ - void SetPipelineOperatingUnits(common::ManagedPointer op) { - pipeline_operating_units_ = op; - } + /* -------------------------------------------------------------------------- + Other Functionality + -------------------------------------------------------------------------- */ /** - * @return PipelineOperatingUnits + * Constructs a new Output Buffer for outputting query results to consumers. + * @return The newly created output buffer */ - common::ManagedPointer GetPipelineOperatingUnits() { - return pipeline_operating_units_; - } - - /** @return The number of rows affected by the current execution, e.g., INSERT/DELETE/UPDATE. */ - uint32_t GetRowsAffected() const { return rows_affected_; } - - /** Increment or decrement the number of rows affected. */ - void AddRowsAffected(int64_t num_rows) { rows_affected_ += num_rows; } + OutputBuffer *OutputBufferNew(); /** * @return On the primary, returns the ID of the last txn sent. @@ -279,49 +308,33 @@ class EXPORT ExecutionContext { void AggregateMetricsThread(); /** - * Ensures that the trackers for the current thread are stopped + * Ensures that the trackers for the current thread are stopped. */ void EnsureTrackersStopped(); /** - * @return metrics manager used by execution context + * Compute the size of an output tuple based on the provided schema. + * @param schema The output schema + * @return The size of tuple in this schema */ - common::ManagedPointer GetMetricsManager() { return metrics_manager_; } + static uint32_t ComputeTupleSize(common::ManagedPointer schema); - /** - * @return query identifier - */ - execution::query_id_t GetQueryId() { return query_id_; } + /* -------------------------------------------------------------------------- + Hook Function Management + -------------------------------------------------------------------------- */ /** - * Set the current executing query identifier + * Initializes the set of hooks for the execution context to specified capacity. + * @param num_hooks The desired number of hooks */ - void SetQueryId(execution::query_id_t query_id) { query_id_ = query_id; } + void InitHooks(std::size_t num_hooks); /** - * Overrides recording from memory tracker - * This should never be used by parallel threads directly - * @param memory_use Correct memory value to record - */ - void SetMemoryUseOverride(uint32_t memory_use) { - memory_use_override_ = true; - memory_use_override_value_ = memory_use; - } - - /** - * Sets the opaque query state pointer for the current query invocation - * @param query_state QueryState - */ - void SetQueryState(void *query_state) { query_state_ = query_state; } - - /** - * Sets the estimated concurrency of a parallel operation. - * This value is used when initializing an ExecOUFeatureVector - * - * @note this value is reset by setting it to 0. - * @param estimate Estimated number of concurrent tasks + * Registers a hook function + * @param hook_idx Hook index to register function + * @param hook Function to register */ - void SetNumConcurrentEstimate(uint32_t estimate) { num_concurrent_estimate_ = estimate; } + void RegisterHook(std::size_t hook_idx, HookFn hook); /** * Invoke a hook function if a hook function is available @@ -329,60 +342,302 @@ class EXPORT ExecutionContext { * @param tls TLS argument * @param arg Opaque argument to pass */ - void InvokeHook(size_t hook_index, void *tls, void *arg); + void InvokeHook(std::size_t hook_index, void *tls, void *arg); /** - * Registers a hook function - * @param hook_idx Hook index to register function - * @param hook Function to register + * Clear the hooks for the execution context. */ - void RegisterHook(size_t hook_idx, HookFn hook); + void ClearHooks() { hooks_.clear(); } - /** - * Initializes hooks_ to a certain capacity - * @param num_hooks Number of hooks needed - */ - void InitHooks(size_t num_hooks); + private: + friend class ExecutionContextBuilder; /** - * Clears hooks_ + * Construct a new ExecutionContext instance. + * + * NOTE: Private access modifier forces use of ExecutionContextBuilder. + * + * @param db_oid The OID of the database + * @param parameters The query parameters + * @param execution_mode The query execution mode + * @param execution_settings The execution settings to run with + * @param txn The transaction used by this query + * @param output_schema The output schema + * @param output_callback The callback function for query output + * @param accessor The catalog accessor of this query + * @param metrics_manager The metrics manager for recording metrics + * @param replication_manager The replication manager to handle communication between primary and replicas. + * @param recovery_manager The recovery manager that handles both recovery and application of replication records. */ - void ClearHooks() { hooks_.clear(); } + ExecutionContext(const catalog::db_oid_t db_oid, std::vector> &¶meters, + vm::ExecutionMode execution_mode, exec::ExecutionSettings &&execution_settings, + const common::ManagedPointer txn, + const common::ManagedPointer output_schema, OutputCallback &&output_callback, + const common::ManagedPointer accessor, + const common::ManagedPointer metrics_manager, + const common::ManagedPointer replication_manager, + const common::ManagedPointer recovery_manager) + : db_oid_{db_oid}, + parameters_{std::move(parameters)}, + execution_mode_{execution_mode}, + execution_settings_{execution_settings}, + txn_{txn}, + schema_{output_schema}, + callback_{std::move(output_callback)}, + accessor_{accessor}, + metrics_manager_{metrics_manager}, + replication_manager_{replication_manager}, + recovery_manager_{recovery_manager}, + mem_tracker_{std::make_unique()}, + mem_pool_{std::make_unique(common::ManagedPointer(mem_tracker_))}, + thread_state_container_{std::make_unique(mem_pool_.get())} {} private: + /** + * The query identifier + * + * The query identifier is only used in certain situations and is + * set manually after construction of the ExecutionContext via the + * SetQueryId() member function. + */ query_id_t query_id_{execution::query_id_t(0)}; - exec::ExecutionSettings exec_settings_; - catalog::db_oid_t db_oid_; - common::ManagedPointer txn_; - std::unique_ptr mem_tracker_; - std::unique_ptr mem_pool_; - std::unique_ptr buffer_ = nullptr; - const planner::OutputSchema *schema_ = nullptr; + + /** The query execution mode */ + vm::ExecutionMode execution_mode_; + /** The query parameters */ + std::vector> parameters_; + + /** The OID of the database with which the query is associated */ + const catalog::db_oid_t db_oid_; + /** The execution setting for the query */ + const exec::ExecutionSettings execution_settings_; + + /** The associated transaction */ + const common::ManagedPointer txn_; + + /** The query output schema */ + common::ManagedPointer schema_{nullptr}; + /** The query output buffer */ + std::unique_ptr buffer_{nullptr}; + /** The query output callback */ const OutputCallback &callback_; - // Container for thread-local state. - // During parallel processing, execution threads access their thread-local state from this container. - std::unique_ptr thread_state_container_; - // TODO(WAN): EXEC PORT we used to push the memory tracker into the string allocator, do this - sql::VarlenHeap string_allocator_; - common::ManagedPointer pipeline_operating_units_{nullptr}; + /** The query catalog accessor */ common::ManagedPointer accessor_; + /** The query metrics manager */ common::ManagedPointer metrics_manager_; - common::ManagedPointer>> params_; - uint8_t execution_mode_; - uint32_t rows_affected_ = 0; - + /** The replication manager with which the query is associated */ common::ManagedPointer replication_manager_; + /** The recovery manager with which the query is associated */ common::ManagedPointer recovery_manager_; - // The stack of UDF parameters; each element in the stack - // is itself a (possibly-incomplete) set of parameters - std::vector>> udf_param_stack_; + /** The memory tracker */ + std::unique_ptr mem_tracker_; + /** The memory pool */ + std::unique_ptr mem_pool_; + /** The container for thread-local state */ + std::unique_ptr thread_state_container_; + + /** The allocator for strings */ + // TODO(WAN): EXEC PORT we used to push the memory tracker into the string allocator, do this + sql::VarlenHeap string_allocator_; + + /** The pipeline operating units for the query */ + common::ManagedPointer pipeline_operating_units_{nullptr}; + /** The number of rows affected by the query */ + uint32_t rows_affected_{0}; + + /** `true` if memory overrride is used */ bool memory_use_override_ = false; - uint32_t memory_use_override_value_ = 0; - uint32_t num_concurrent_estimate_ = 0; + /** The value to use for memory override */ + uint32_t memory_use_override_value_{0}; + /** The concurrency estimate for query execution */ + uint32_t num_concurrent_estimate_{0}; + + /** The hooks for the query */ std::vector hooks_{}; + + /** The query state object */ void *query_state_; + + /** The runtime parameter stack */ + std::stack>> runtime_parameters_; +}; + +/** + * The ExecutionContextBuilder class implements a builder for ExecutionContext. + */ +class ExecutionContextBuilder { + public: + /** + * Construct a new ExecutionContextBuilder. + */ + ExecutionContextBuilder() = default; + + /** @return The completed ExecutionContext instance */ + std::unique_ptr Build(); + + /** + * Set the execution mode for the execution context. + * @param mode The execution mode + * @return Builder reference for chaining + */ + ExecutionContextBuilder &WithExecutionMode(const vm::ExecutionMode mode) { + exec_mode_.emplace(mode); + return *this; + } + + /** + * Set the query parameters for the execution context. + * @param parameters The query parameters + * @return Builder reference for chaining + */ + ExecutionContextBuilder &WithQueryParameters(std::vector> &¶meters) { + parameters_ = std::move(parameters); + return *this; + } + + /** + * Set the query parameters for the execution context. + * @param param_expr The collection of expressions from which the query parameters are derived + * @return Builder reference for chaining + */ + ExecutionContextBuilder &WithQueryParametersFrom( + const std::vector ¶meter_exprs) { + NOISEPAGE_ASSERT(parameters_.empty(), "Attempt to initialize query parameters more than once."); + parameters_.reserve(parameter_exprs.size()); + std::transform(parameter_exprs.cbegin(), parameter_exprs.cend(), std::back_inserter(parameters_), + [](const parser::ConstantValueExpression &expr) -> common::ManagedPointer { + return common::ManagedPointer{expr.SqlValue()}; + }); + return *this; + } + + /** + * Set the database OID for the execution context. + * @param db_oid The database OID + * @return Builder reference for chaining + */ + ExecutionContextBuilder &WithDatabaseOID(const catalog::db_oid_t db_oid) { + db_oid_ = db_oid; + return *this; + } + + /** + * Set the transaction context for the execution context. + * @param txn The transaction context + * @return Builder reference for chaining + */ + ExecutionContextBuilder &WithTxnContext(common::ManagedPointer txn) { + txn_ = txn; + return *this; + } + + /** + * Set the output schema for the execution context. + * @param output_schema The output schema + * @return Builder reference for chaining + */ + ExecutionContextBuilder &WithOutputSchema(common::ManagedPointer output_schema) { + output_schema_ = output_schema_; + return *this; + } + + /** + * Set the output callback for the execution context. + * @param output_callback The output callback + * @return Builder reference for chaining + */ + ExecutionContextBuilder &WithOutputCallback(OutputCallback &&output_callback) { + output_callback_.emplace(std::move(output_callback)); + return *this; + } + + /** + * Set the catalog accessor for the execution context. + * @param accessor The catalog accessor + * @return Builder reference for chaining + */ + ExecutionContextBuilder &WithCatalogAccessor(common::ManagedPointer accessor) { + catalog_accessor_ = accessor; + return *this; + } + + /** + * Set the execution settings for the execution context. + * @param exec_settings The execution settings + * @return Builder reference for chaining + */ + ExecutionContextBuilder &WithExecutionSettings(exec::ExecutionSettings &&exec_settings) { + exec_settings_.emplace(std::move(exec_settings)); + return *this; + } + + /** + * Set the execution settings for the execution context. + * @param exec_settings The execution settings + * @return Builder reference for chaining + */ + ExecutionContextBuilder &WithExecutionSettings(exec::ExecutionSettings exec_settings) { + exec_settings_.emplace(std::move(exec_settings)); + return *this; + } + + /** + * Set the metrics manager for the execution context. + * @param metrics_manager The metrics manager + * @return Builder reference for chaining + */ + ExecutionContextBuilder &WithMetricsManager(common::ManagedPointer metrics_manager) { + metrics_manager_ = metrics_manager; + return *this; + } + + /** + * Set the replication manager for the execution context. + * @param replication_manager The replication manager + * @return Builder reference for chaining + */ + ExecutionContextBuilder &WithReplicationManager( + common::ManagedPointer replication_manager) { + replication_manager_ = replication_manager; + return *this; + } + + /** + * Set the recovery manager for the execution context. + * @param recovery_manager The recovery manager + * @return Builder reference for chaining + */ + ExecutionContextBuilder &WithRecoveryManager(common::ManagedPointer recovery_manager) { + recovery_manager_ = recovery_manager; + return *this; + } + + private: + /** The query execution mode */ + std::optional exec_mode_; + /** The query execution settings */ + std::optional exec_settings_; + /** The query parmeters */ + std::vector> parameters_; + /** The database OID */ + catalog::db_oid_t db_oid_{INVALID_DATABASE_OID}; + /** The associated transaction */ + common::ManagedPointer txn_; + /** The output callback */ + std::optional output_callback_{NULL_OUTPUT_CALLBACK}; + /** The output schema */ + common::ManagedPointer output_schema_{nullptr}; + /** The catalog accessor */ + common::ManagedPointer catalog_accessor_; + /** The metrics manager */ + common::ManagedPointer metrics_manager_; + /** The replication manager */ + common::ManagedPointer replication_manager_; + /** The recovery manager */ + common::ManagedPointer recovery_manager_; }; + } // namespace noisepage::execution::exec diff --git a/src/include/execution/exec/output.h b/src/include/execution/exec/output.h index 21742a7acd..2e9cc25f5a 100644 --- a/src/include/execution/exec/output.h +++ b/src/include/execution/exec/output.h @@ -28,6 +28,9 @@ namespace noisepage::execution::exec { // Params(): tuples, num_tuples, tuple_size; using OutputCallback = std::function; +/** An empty output callback */ +constexpr const OutputCallback NULL_OUTPUT_CALLBACK{nullptr}; + /** * A class that buffers the output and makes a callback for every batch. */ diff --git a/src/include/execution/vm/bytecode_handlers.h b/src/include/execution/vm/bytecode_handlers.h index 8da087b5a2..f4407caf72 100644 --- a/src/include/execution/vm/bytecode_handlers.h +++ b/src/include/execution/vm/bytecode_handlers.h @@ -2185,13 +2185,15 @@ VM_OP_WARM void OpAbortTxn(noisepage::execution::exec::ExecutionContext *exec_ct // Parameter Calls // --------------------------------- -// TODO(Kyle): this used to have a conditional check; was it safe to remove? #define GEN_SCALAR_PARAM_GET(Name, SqlType) \ VM_OP_HOT void OpGetParam##Name(noisepage::execution::sql::SqlType *ret, \ noisepage::execution::exec::ExecutionContext *exec_ctx, uint32_t param_idx) { \ - const auto &val = \ - *reinterpret_cast(exec_ctx->GetParam(param_idx).Get()); \ - *ret = val; \ + const auto &cve = exec_ctx->GetParam(param_idx); \ + if (cve.IsNull()) { \ + ret->is_null_ = true; \ + } else { \ + *ret = cve.Get##SqlType(); \ + } \ } GEN_SCALAR_PARAM_GET(Bool, BoolVal) diff --git a/src/include/execution/vm/vm_defs.h b/src/include/execution/vm/execution_mode.h similarity index 100% rename from src/include/execution/vm/vm_defs.h rename to src/include/execution/vm/execution_mode.h diff --git a/src/include/execution/vm/module.h b/src/include/execution/vm/module.h index 44ce8082d0..dd52ea2816 100644 --- a/src/include/execution/vm/module.h +++ b/src/include/execution/vm/module.h @@ -10,9 +10,9 @@ #include "execution/ast/type.h" #include "execution/vm/bytecode_module.h" +#include "execution/vm/execution_mode.h" #include "execution/vm/llvm_engine.h" #include "execution/vm/module_metadata.h" -#include "execution/vm/vm_defs.h" namespace noisepage::execution::vm { diff --git a/src/include/traffic_cop/traffic_cop.h b/src/include/traffic_cop/traffic_cop.h index f5fd70b67f..360c342fd4 100644 --- a/src/include/traffic_cop/traffic_cop.h +++ b/src/include/traffic_cop/traffic_cop.h @@ -7,7 +7,7 @@ #include "catalog/catalog_defs.h" #include "common/managed_pointer.h" -#include "execution/vm/vm_defs.h" +#include "execution/vm/execution_mode.h" #include "network/network_defs.h" #include "traffic_cop/traffic_cop_defs.h" diff --git a/src/parser/udf/udf_parser.cpp b/src/parser/udf/udf_parser.cpp index 108a8c3039..631f012dc2 100644 --- a/src/parser/udf/udf_parser.cpp +++ b/src/parser/udf/udf_parser.cpp @@ -204,11 +204,11 @@ std::unique_ptr PLpgSQLParser::ParseFor(const nloh } auto body_stmt = ParseBlock(loop[K_BODY]); auto var_array = loop[K_ROW][K_PLPGSQL_ROW][K_FIELDS]; - std::vector var_vec; - for (auto var : var_array) { - var_vec.push_back(var[K_NAME].get()); - } - return std::make_unique(std::move(var_vec), std::move(parse_result), + std::vector variables{}; + variables.reserve(var_array.size()); + std::transform(var_array.cbegin(), var_array.cend(), std::back_inserter(variables), + [](const nlohmann::json &var) { return var[K_NAME].get(); }); + return std::make_unique(std::move(variables), std::move(parse_result), std::move(body_stmt)); } diff --git a/src/util/query_exec_util.cpp b/src/util/query_exec_util.cpp index 72bc91ac00..472f4bc3a9 100644 --- a/src/util/query_exec_util.cpp +++ b/src/util/query_exec_util.cpp @@ -10,7 +10,7 @@ #include "execution/compiler/executable_query.h" #include "execution/exec/execution_context.h" #include "execution/sql/ddl_executors.h" -#include "execution/vm/vm_defs.h" +#include "execution/vm/execution_mode.h" #include "loggers/common_logger.h" #include "metrics/metrics_manager.h" #include "network/network_defs.h" diff --git a/test/execution/atomics_test.cpp b/test/execution/atomics_test.cpp index d3fd2d1d13..cfe5471dd8 100644 --- a/test/execution/atomics_test.cpp +++ b/test/execution/atomics_test.cpp @@ -8,9 +8,9 @@ #include "execution/compiler/compiler_settings.h" #include "execution/sema/error_reporter.h" #include "execution/util/region.h" +#include "execution/vm/execution_mode.h" #include "execution/vm/llvm_engine.h" #include "execution/vm/module.h" -#include "execution/vm/vm_defs.h" #include "spdlog/fmt/fmt.h" #include "test_util/fs_util.h" #include "test_util/multithread_test_util.h" diff --git a/test/test_util/tpch/workload.cpp b/test/test_util/tpch/workload.cpp index 12f83a91df..8b60e1b149 100644 --- a/test/test_util/tpch/workload.cpp +++ b/test/test_util/tpch/workload.cpp @@ -41,13 +41,19 @@ Workload::Workload(common::ManagedPointer db_main, const std::string &db exec_settings_.is_counters_enabled_ = true; // Make the execution context - auto exec_ctx = - execution::exec::ExecutionContext(db_oid_, common::ManagedPointer(txn), nullptr, - nullptr, common::ManagedPointer(accessor), - exec_settings_, db_main->GetMetricsManager(), DISABLED, DISABLED); + auto exec_ctx = execution::exec::ExecutionContextBuilder() + .WithDatabaseOID(db_oid_) + .WithExecutionMode(execution::vm::ExecutionMode::Interpret) + .WithExecutionSettings(exec_settings_) + .WithTxnContext(common::ManagedPointer{txn}) + .WithCatalogAccessor(common::ManagedPointer{accessor}) + .WithMetricsManager(db_main->GetMetricsManager()) + .WithReplicationManager(DISABLED) + .WithRecoveryManager(DISABLED) + .Build(); // create the TPCH database and compile the queries - GenerateTables(&exec_ctx, table_root, type); + GenerateTables(exec_ctx.get(), table_root, type); LoadQueries(accessor, type); txn_manager_->Commit(txn, transaction::TransactionUtil::EmptyCallback, nullptr); @@ -147,13 +153,22 @@ void Workload::Execute(int8_t worker_id, uint64_t execution_us_per_worker, uint6 // Uncomment this line and change output.cpp:90 to EXECUTION_LOG_INFO to print output // execution::exec::OutputPrinter printer(output_schema); execution::exec::NoOpResultConsumer printer; - auto exec_ctx = execution::exec::ExecutionContext( - db_oid_, common::ManagedPointer(txn), printer, output_schema, - common::ManagedPointer(accessor), exec_settings_, db_main_->GetMetricsManager(), - DISABLED, DISABLED); + + auto exec_ctx = execution::exec::ExecutionContextBuilder() + .WithDatabaseOID(db_oid_) + .WithExecutionMode(mode) + .WithExecutionSettings(exec_settings_) + .WithTxnContext(common::ManagedPointer{txn}) + .WithOutputSchema(common::ManagedPointer{output_schema}) + .WithOutputCallback(std::move(printer)) + .WithCatalogAccessor(common::ManagedPointer{accessor}) + .WithMetricsManager(db_main_->GetMetricsManager()) + .WithReplicationManager(DISABLED) + .WithRecoveryManager(DISABLED) + .Build(); std::get<0>(query_and_plan_[index[counter]]) - ->Run(common::ManagedPointer(&exec_ctx), mode); + ->Run(common::ManagedPointer(exec_ctx), mode); // Only execute up to query_num number of queries for this thread in round-robin counter = counter == query_num - 1 ? 0 : counter + 1; From 13ad86578a605374b46cbf6333fa5fb03f793f8a Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Tue, 29 Jun 2021 08:04:31 -0400 Subject: [PATCH 053/139] still working on refactor --- src/execution/compiler/executable_query.cpp | 3 +- src/execution/exec/execution_context.cpp | 19 ++++++-- .../execution/exec/execution_context.h | 46 ++++--------------- src/include/execution/exec/output.h | 3 -- src/include/execution/vm/bytecode_handlers.h | 12 ++--- src/traffic_cop/traffic_cop.cpp | 32 +++++++++---- 6 files changed, 54 insertions(+), 61 deletions(-) diff --git a/src/execution/compiler/executable_query.cpp b/src/execution/compiler/executable_query.cpp index 4522313472..707557feb4 100644 --- a/src/execution/compiler/executable_query.cpp +++ b/src/execution/compiler/executable_query.cpp @@ -179,9 +179,8 @@ void ExecutableQuery::Run(common::ManagedPointer exec_ct // First, allocate the query state and move the execution context into it. auto query_state = std::make_unique(query_state_size_); *reinterpret_cast(query_state.get()) = exec_ctx.Get(); + exec_ctx->SetQueryState(query_state.get()); - - exec_ctx->SetExecutionMode(static_cast(mode)); exec_ctx->SetPipelineOperatingUnits(GetPipelineOperatingUnits()); exec_ctx->SetQueryId(query_id_); diff --git a/src/execution/exec/execution_context.cpp b/src/execution/exec/execution_context.cpp index 769874dee0..28b0949528 100644 --- a/src/execution/exec/execution_context.cpp +++ b/src/execution/exec/execution_context.cpp @@ -15,15 +15,26 @@ namespace noisepage::execution::exec { std::unique_ptr ExecutionContextBuilder::Build() { - NOISEPAGE_ASSERT(db_oid_ != INVALID_DATABASE_OID, "Must specify database OID."); + NOISEPAGE_ASSERT(db_oid_ != catalog::INVALID_DATABASE_OID, "Must specify database OID."); NOISEPAGE_ASSERT(exec_mode_.has_value(), "Must specify execution mode."); - NOISEPAGE_ASSERT(exec_settings_.has_value(), "Must specify execution setting."); + NOISEPAGE_ASSERT(exec_settings_.has_value(), "Must specify execution settings."); NOISEPAGE_ASSERT(static_cast(catalog_accessor_), "Must specify catalog accessor."); // MetricsManager, ReplicationManager, and RecoveryManaged may be DISABLED return std::make_unique(db_oid_, std::move(parameters_), exec_mode_.value(), std::move(exec_settings_.value()), txn_, output_schema_, - std::move(output_callback_.value()), catalog_accessor_, metrics_manager_, - replication_manager_, recovery_manager_); + std::move(output_callback_.value_or(nullptr)), catalog_accessor_, + metrics_manager_, replication_manager_, recovery_manager_); +} + +ExecutionContextBuilder &ExecutionContextBuilder::WithQueryParametersFrom( + const std::vector ¶meter_exprs) { + NOISEPAGE_ASSERT(parameters_.empty(), "Attempt to initialize query parameters more than once."); + parameters_.reserve(parameter_exprs.size()); + std::transform(parameter_exprs.cbegin(), parameter_exprs.cend(), std::back_inserter(parameters_), + [](const parser::ConstantValueExpression &expr) -> common::ManagedPointer { + return common::ManagedPointer{expr.SqlValue()}; + }); + return *this; } OutputBuffer *ExecutionContext::OutputBufferNew() { diff --git a/src/include/execution/exec/execution_context.h b/src/include/execution/exec/execution_context.h index 582251c693..3ba2af71e9 100644 --- a/src/include/execution/exec/execution_context.h +++ b/src/include/execution/exec/execution_context.h @@ -49,7 +49,6 @@ class RecoveryManager; namespace noisepage::execution::exec { class ExecutionSettings; -class ExecutionContextBuilder; /** * The ExecutionContext class stores information handed in by upper layers. @@ -239,21 +238,6 @@ class EXPORT ExecutionContext { runtime_parameters_.top().push_back(val); } - // /** - // * Set the execution parameters for the query. - // * @param params The execution parameters - // * - // * NOTE: The use of a ManagedPointer for this API denotes that - // * the ExecutionContext instance does not own the underlying - // * collection of query parameters; the caller is responsible for - // * ensuring the query parameters survive until query execution - // * is complete. - // */ - // void SetParams(common::ManagedPointer>> params) { - // NOISEPAGE_ASSERT(!static_cast(params_), "Attempt to set query execution parameters multiple times."); - // params_ = params; - // } - /** * Get the parameter at the specified index. * @param index index of parameter to access @@ -349,9 +333,7 @@ class EXPORT ExecutionContext { */ void ClearHooks() { hooks_.clear(); } - private: - friend class ExecutionContextBuilder; - + public: /** * Construct a new ExecutionContext instance. * @@ -392,6 +374,7 @@ class EXPORT ExecutionContext { mem_pool_{std::make_unique(common::ManagedPointer(mem_tracker_))}, thread_state_container_{std::make_unique(mem_pool_.get())} {} + friend class ExecutionContextBuilder; private: /** * The query identifier @@ -402,13 +385,13 @@ class EXPORT ExecutionContext { */ query_id_t query_id_{execution::query_id_t(0)}; - /** The query execution mode */ - vm::ExecutionMode execution_mode_; - /** The query parameters */ - std::vector> parameters_; - /** The OID of the database with which the query is associated */ const catalog::db_oid_t db_oid_; + + /** The query parameters */ + std::vector> parameters_; + /** The query execution mode */ + vm::ExecutionMode execution_mode_; /** The execution setting for the query */ const exec::ExecutionSettings execution_settings_; @@ -503,16 +486,7 @@ class ExecutionContextBuilder { * @param param_expr The collection of expressions from which the query parameters are derived * @return Builder reference for chaining */ - ExecutionContextBuilder &WithQueryParametersFrom( - const std::vector ¶meter_exprs) { - NOISEPAGE_ASSERT(parameters_.empty(), "Attempt to initialize query parameters more than once."); - parameters_.reserve(parameter_exprs.size()); - std::transform(parameter_exprs.cbegin(), parameter_exprs.cend(), std::back_inserter(parameters_), - [](const parser::ConstantValueExpression &expr) -> common::ManagedPointer { - return common::ManagedPointer{expr.SqlValue()}; - }); - return *this; - } + ExecutionContextBuilder &WithQueryParametersFrom(const std::vector ¶meter_exprs); /** * Set the database OID for the execution context. @@ -623,11 +597,11 @@ class ExecutionContextBuilder { /** The query parmeters */ std::vector> parameters_; /** The database OID */ - catalog::db_oid_t db_oid_{INVALID_DATABASE_OID}; + catalog::db_oid_t db_oid_{catalog::INVALID_DATABASE_OID}; /** The associated transaction */ common::ManagedPointer txn_; /** The output callback */ - std::optional output_callback_{NULL_OUTPUT_CALLBACK}; + std::optional output_callback_; /** The output schema */ common::ManagedPointer output_schema_{nullptr}; /** The catalog accessor */ diff --git a/src/include/execution/exec/output.h b/src/include/execution/exec/output.h index 2e9cc25f5a..21742a7acd 100644 --- a/src/include/execution/exec/output.h +++ b/src/include/execution/exec/output.h @@ -28,9 +28,6 @@ namespace noisepage::execution::exec { // Params(): tuples, num_tuples, tuple_size; using OutputCallback = std::function; -/** An empty output callback */ -constexpr const OutputCallback NULL_OUTPUT_CALLBACK{nullptr}; - /** * A class that buffers the output and makes a callback for every batch. */ diff --git a/src/include/execution/vm/bytecode_handlers.h b/src/include/execution/vm/bytecode_handlers.h index f4407caf72..e211eb0d88 100644 --- a/src/include/execution/vm/bytecode_handlers.h +++ b/src/include/execution/vm/bytecode_handlers.h @@ -2185,15 +2185,13 @@ VM_OP_WARM void OpAbortTxn(noisepage::execution::exec::ExecutionContext *exec_ct // Parameter Calls // --------------------------------- +// TODO(Kyle): Is it ever the case that we pass a NULL CVE to call? #define GEN_SCALAR_PARAM_GET(Name, SqlType) \ VM_OP_HOT void OpGetParam##Name(noisepage::execution::sql::SqlType *ret, \ noisepage::execution::exec::ExecutionContext *exec_ctx, uint32_t param_idx) { \ - const auto &cve = exec_ctx->GetParam(param_idx); \ - if (cve.IsNull()) { \ - ret->is_null_ = true; \ - } else { \ - *ret = cve.Get##SqlType(); \ - } \ + const auto &val = \ + *reinterpret_cast(exec_ctx->GetParam(param_idx).Get()); \ + *ret = val; \ } GEN_SCALAR_PARAM_GET(Bool, BoolVal) @@ -2229,7 +2227,7 @@ GEN_SCALAR_PARAM_ADD(String, StringVal, VARCHAR) VM_OP_HOT void OpStartNewParams(noisepage::execution::exec::ExecutionContext *exec_ctx) { exec_ctx->StartParams(); } -VM_OP_HOT void OpFinishParams(noisepage::execution::exec::ExecutionContext *exec_ctx) { exec_ctx->PopParams(); } +VM_OP_HOT void OpFinishParams(noisepage::execution::exec::ExecutionContext *exec_ctx) { exec_ctx->FinishParams(); } // --------------------------------- // Replication functions diff --git a/src/traffic_cop/traffic_cop.cpp b/src/traffic_cop/traffic_cop.cpp index febd0a7fcc..1f2805a845 100644 --- a/src/traffic_cop/traffic_cop.cpp +++ b/src/traffic_cop/traffic_cop.cpp @@ -582,15 +582,29 @@ TrafficCopResult TrafficCop::RunExecutableQuery(const common::ManagedPointerMetricsManager(); } - auto exec_ctx = std::make_unique( - connection_ctx->GetDatabaseOid(), connection_ctx->Transaction(), callback, physical_plan->GetOutputSchema().Get(), - connection_ctx->Accessor(), exec_settings, metrics, replication_manager_, recovery_manager_); - - std::vector> params{}; - params.reserve(portal->Parameters()->size()); - std::transform(portal->Parameters()->cbegin(), portal->Parameters()->cend(), std::back_inserter(params), - [](const parser::ConstantValueExpression &cve) { return common::ManagedPointer{cve.SqlValue()}; }); - exec_ctx->SetParams(common::ManagedPointer(¶ms)); + auto exec_ctx = execution::exec::ExecutionContextBuilder() + .WithDatabaseOID(connection_ctx->GetDatabaseOid()) + .WithExecutionMode(execution_mode_) + .WithExecutionSettings(exec_settings) + .WithTxnContext(connection_ctx->Transaction()) + .WithOutputSchema(physical_plan->GetOutputSchema()) + .WithOutputCallback(std::move(callback)) + .WithCatalogAccessor(connection_ctx->Accessor()) + .WithMetricsManager(metrics) + .WithReplicationManager(replication_manager_) + .WithRecoveryManager(recovery_manager_) + .WithQueryParametersFrom(*portal->Parameters()) + .Build(); + + // auto exec_ctx = std::make_unique( + // connection_ctx->GetDatabaseOid(), connection_ctx->Transaction(), callback, physical_plan->GetOutputSchema().Get(), + // connection_ctx->Accessor(), exec_settings, metrics, replication_manager_, recovery_manager_); + + // std::vector> params{}; + // params.reserve(portal->Parameters()->size()); + // std::transform(portal->Parameters()->cbegin(), portal->Parameters()->cend(), std::back_inserter(params), + // [](const parser::ConstantValueExpression &cve) { return common::ManagedPointer{cve.SqlValue()}; }); + // exec_ctx->SetParams(common::ManagedPointer(¶ms)); const auto exec_query = portal->GetStatement()->GetExecutableQuery(); From 7c22e78f8358966cbedc6cd1ef32a7f8d8ca944b Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Tue, 29 Jun 2021 08:54:15 -0400 Subject: [PATCH 054/139] building again after execution context refactor --- src/execution/compiler/executable_query.cpp | 2 +- .../execution/exec/execution_context.h | 4 ++- src/include/util/query_exec_util.h | 26 +++++++++++--- src/traffic_cop/traffic_cop.cpp | 34 +++++++------------ src/util/query_exec_util.cpp | 24 ++++++++----- 5 files changed, 54 insertions(+), 36 deletions(-) diff --git a/src/execution/compiler/executable_query.cpp b/src/execution/compiler/executable_query.cpp index 707557feb4..30e4624add 100644 --- a/src/execution/compiler/executable_query.cpp +++ b/src/execution/compiler/executable_query.cpp @@ -179,7 +179,7 @@ void ExecutableQuery::Run(common::ManagedPointer exec_ct // First, allocate the query state and move the execution context into it. auto query_state = std::make_unique(query_state_size_); *reinterpret_cast(query_state.get()) = exec_ctx.Get(); - + exec_ctx->SetQueryState(query_state.get()); exec_ctx->SetPipelineOperatingUnits(GetPipelineOperatingUnits()); exec_ctx->SetQueryId(query_id_); diff --git a/src/include/execution/exec/execution_context.h b/src/include/execution/exec/execution_context.h index 3ba2af71e9..d894cd1b78 100644 --- a/src/include/execution/exec/execution_context.h +++ b/src/include/execution/exec/execution_context.h @@ -333,7 +333,10 @@ class EXPORT ExecutionContext { */ void ClearHooks() { hooks_.clear(); } + // TODO(Kyle): Why is this friend class declaration not working? public: + friend class ExecutionContextBuilder; + /** * Construct a new ExecutionContext instance. * @@ -374,7 +377,6 @@ class EXPORT ExecutionContext { mem_pool_{std::make_unique(common::ManagedPointer(mem_tracker_))}, thread_state_container_{std::make_unique(mem_pool_.get())} {} - friend class ExecutionContextBuilder; private: /** * The query identifier diff --git a/src/include/util/query_exec_util.h b/src/include/util/query_exec_util.h index 637332ccac..f8689ae429 100644 --- a/src/include/util/query_exec_util.h +++ b/src/include/util/query_exec_util.h @@ -11,6 +11,7 @@ #include "execution/compiler/executable_query.h" #include "execution/exec/execution_settings.h" #include "execution/exec_defs.h" +#include "execution/vm/execution_mode.h" #include "planner/plannodes/output_schema.h" #include "type/type_id.h" @@ -217,20 +218,39 @@ class QueryExecUtil { std::string GetError() { return error_msg_; } private: + /** Reset the error message stored by the instance */ void ResetError(); + + /** + * Set the database OID. + * @param db_oid The database OID + */ void SetDatabase(catalog::db_oid_t db_oid); + /** The transaction manager */ common::ManagedPointer txn_manager_; + /** The catalog accessor */ common::ManagedPointer catalog_; + /** The settings manager */ common::ManagedPointer settings_; + /** The statistics storage */ common::ManagedPointer stats_; + + /** The timeout for query optimizatio */ uint64_t optimizer_timeout_; - /** Database being accessed */ + /** Idenditifer for the database being accessed */ catalog::db_oid_t db_oid_{catalog::INVALID_DATABASE_OID}; + + /** `true` if the QueryExecUtil instance owns the transaction, `false` otherwise */ bool own_txn_ = false; + /** The transaction context */ transaction::TransactionContext *txn_ = nullptr; + /** The query execution mode */ + // TODO(Kyle): Need a way to not just hard-code this value + execution::vm::ExecutionMode execution_mode_{execution::vm::ExecutionMode::Interpret}; + /** * Information about cached executable queries * Assumes that the query string is a unique identifier. @@ -238,9 +258,7 @@ class QueryExecUtil { std::unordered_map> schemas_; std::unordered_map> exec_queries_; - /** - * Stores the most recently encountered error. - */ + /** Stores the most recently encountered error */ std::string error_msg_; }; diff --git a/src/traffic_cop/traffic_cop.cpp b/src/traffic_cop/traffic_cop.cpp index 1f2805a845..de60977f5b 100644 --- a/src/traffic_cop/traffic_cop.cpp +++ b/src/traffic_cop/traffic_cop.cpp @@ -583,28 +583,18 @@ TrafficCopResult TrafficCop::RunExecutableQuery(const common::ManagedPointerGetDatabaseOid()) - .WithExecutionMode(execution_mode_) - .WithExecutionSettings(exec_settings) - .WithTxnContext(connection_ctx->Transaction()) - .WithOutputSchema(physical_plan->GetOutputSchema()) - .WithOutputCallback(std::move(callback)) - .WithCatalogAccessor(connection_ctx->Accessor()) - .WithMetricsManager(metrics) - .WithReplicationManager(replication_manager_) - .WithRecoveryManager(recovery_manager_) - .WithQueryParametersFrom(*portal->Parameters()) - .Build(); - - // auto exec_ctx = std::make_unique( - // connection_ctx->GetDatabaseOid(), connection_ctx->Transaction(), callback, physical_plan->GetOutputSchema().Get(), - // connection_ctx->Accessor(), exec_settings, metrics, replication_manager_, recovery_manager_); - - // std::vector> params{}; - // params.reserve(portal->Parameters()->size()); - // std::transform(portal->Parameters()->cbegin(), portal->Parameters()->cend(), std::back_inserter(params), - // [](const parser::ConstantValueExpression &cve) { return common::ManagedPointer{cve.SqlValue()}; }); - // exec_ctx->SetParams(common::ManagedPointer(¶ms)); + .WithDatabaseOID(connection_ctx->GetDatabaseOid()) + .WithExecutionMode(execution_mode_) + .WithExecutionSettings(exec_settings) + .WithTxnContext(connection_ctx->Transaction()) + .WithOutputSchema(physical_plan->GetOutputSchema()) + .WithOutputCallback(std::move(callback)) + .WithCatalogAccessor(connection_ctx->Accessor()) + .WithMetricsManager(metrics) + .WithReplicationManager(replication_manager_) + .WithRecoveryManager(recovery_manager_) + .WithQueryParametersFrom(*portal->Parameters()) + .Build(); const auto exec_query = portal->GetStatement()->GetExecutableQuery(); diff --git a/src/util/query_exec_util.cpp b/src/util/query_exec_util.cpp index 472f4bc3a9..efa9f8929c 100644 --- a/src/util/query_exec_util.cpp +++ b/src/util/query_exec_util.cpp @@ -281,17 +281,25 @@ bool QueryExecUtil::ExecuteQuery(const std::string &statement, TupleFunction tup // TODO(wz2): May want to thread the replication manager or recovery manager through execution::exec::OutputCallback callback = consumer; auto accessor = catalog_->GetAccessor(txn, db_oid_, DISABLED); - auto exec_ctx = std::make_unique( - db_oid_, txn, callback, schema, common::ManagedPointer(accessor), exec_settings, metrics, DISABLED, DISABLED); - // Must translate the ConstantValueExpressions to opaque sql::Val - std::vector> value_params{}; - value_params.reserve(params->size()); - std::transform(params->cbegin(), params->cend(), std::back_inserter(value_params), - [](const parser::ConstantValueExpression &cve) { return common::ManagedPointer{cve.SqlValue()}; }); - exec_ctx->SetParams(common::ManagedPointer(&value_params)); + auto exec_ctx = execution::exec::ExecutionContextBuilder() + .WithDatabaseOID(db_oid_) + .WithExecutionMode(execution_mode_) + .WithExecutionSettings(exec_settings) + .WithQueryParametersFrom(*params) + .WithTxnContext(txn) + .WithOutputSchema(common::ManagedPointer{schema}) + .WithOutputCallback(std::move(callback)) + .WithCatalogAccessor(common::ManagedPointer{accessor}) + .WithMetricsManager(metrics) + .WithReplicationManager(DISABLED) + .WithRecoveryManager(DISABLED) + .Build(); NOISEPAGE_ASSERT(!txn->MustAbort(), "Transaction should not be in must-abort state prior to executing"); + // TODO(Kyle): Right now it looks like the QueryExecUtil always runs queries in interpreted + // execution mode, regardless of how the setting is updated throughout the rest of the system, + // is this the intended behavior..? (unlikely) exec_queries_[statement]->Run(common::ManagedPointer(exec_ctx), execution::vm::ExecutionMode::Interpret); if (txn->MustAbort()) { // Return false to indicate that the query encountered a runtime error. From cff6bd0a768b61d3506dab97ac57c5d44bf25ab7 Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Tue, 29 Jun 2021 10:28:39 -0400 Subject: [PATCH 055/139] fix access modifier in execution context --- src/execution/exec/execution_context.cpp | 8 ++++---- src/include/execution/exec/execution_context.h | 5 ++--- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/src/execution/exec/execution_context.cpp b/src/execution/exec/execution_context.cpp index 28b0949528..08dda6c836 100644 --- a/src/execution/exec/execution_context.cpp +++ b/src/execution/exec/execution_context.cpp @@ -20,10 +20,10 @@ std::unique_ptr ExecutionContextBuilder::Build() { NOISEPAGE_ASSERT(exec_settings_.has_value(), "Must specify execution settings."); NOISEPAGE_ASSERT(static_cast(catalog_accessor_), "Must specify catalog accessor."); // MetricsManager, ReplicationManager, and RecoveryManaged may be DISABLED - return std::make_unique(db_oid_, std::move(parameters_), exec_mode_.value(), - std::move(exec_settings_.value()), txn_, output_schema_, - std::move(output_callback_.value_or(nullptr)), catalog_accessor_, - metrics_manager_, replication_manager_, recovery_manager_); + return std::unique_ptr{ + new ExecutionContext{db_oid_, std::move(parameters_), exec_mode_.value(), std::move(exec_settings_.value()), txn_, + output_schema_, std::move(output_callback_.value_or(nullptr)), catalog_accessor_, + metrics_manager_, replication_manager_, recovery_manager_}}; } ExecutionContextBuilder &ExecutionContextBuilder::WithQueryParametersFrom( diff --git a/src/include/execution/exec/execution_context.h b/src/include/execution/exec/execution_context.h index d894cd1b78..94e1fcfe53 100644 --- a/src/include/execution/exec/execution_context.h +++ b/src/include/execution/exec/execution_context.h @@ -333,8 +333,7 @@ class EXPORT ExecutionContext { */ void ClearHooks() { hooks_.clear(); } - // TODO(Kyle): Why is this friend class declaration not working? - public: + private: friend class ExecutionContextBuilder; /** @@ -516,7 +515,7 @@ class ExecutionContextBuilder { * @return Builder reference for chaining */ ExecutionContextBuilder &WithOutputSchema(common::ManagedPointer output_schema) { - output_schema_ = output_schema_; + output_schema_ = output_schema; return *this; } From 4b069ee5abf79541b14c5d068027afe5fa9f45fb Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Tue, 29 Jun 2021 11:16:18 -0400 Subject: [PATCH 056/139] refactoring again to account for the fact that we don't always know the execution mode when constructing the execution context --- benchmark/runner/execution_runners.cpp | 16 ++--- src/execution/exec/execution_context.cpp | 42 ++++++++++--- .../execution/exec/execution_context.h | 62 ++++++++----------- src/traffic_cop/traffic_cop.cpp | 1 - src/util/query_exec_util.cpp | 3 +- test/execution/execution_context_test.cpp | 11 ++++ test/test_util/tpcc/workload_cached.cpp | 36 +++++++---- test/test_util/tpch/workload.cpp | 4 +- 8 files changed, 105 insertions(+), 70 deletions(-) create mode 100644 test/execution/execution_context_test.cpp diff --git a/benchmark/runner/execution_runners.cpp b/benchmark/runner/execution_runners.cpp index 7a2cc8a2d3..a81837abda 100644 --- a/benchmark/runner/execution_runners.cpp +++ b/benchmark/runner/execution_runners.cpp @@ -446,11 +446,10 @@ class ExecutionRunners : public benchmark::Fixture { auto exec_ctx = execution::exec::ExecutionContextBuilder() .WithDatabaseOID(db_oid) - .WithExecutionMode(ExecutionRunners::mode) .WithExecutionSettings(exec_settings) .WithTxnContext(common::ManagedPointer{txn}) - .WithOutputCallback(execution::exec::NoOpResultConsumer{}) .WithOutputSchema(out_plan->GetOutputSchema()) + .WithOutputCallback(execution::exec::NoOpResultConsumer{}) .WithCatalogAccessor(common::ManagedPointer{accessor}) .WithMetricsManager(metrics_manager_) .WithReplicationManager(DISABLED) @@ -501,9 +500,10 @@ class ExecutionRunners : public benchmark::Fixture { auto exec_ctx = execution::exec::ExecutionContextBuilder() .WithDatabaseOID(db_oid) - .WithExecutionMode(ExecutionRunners::mode) .WithExecutionSettings(exec_settings) .WithTxnContext(common::ManagedPointer{txn}) + .WithOutputSchema(execution::exec::ExecutionContext::NULL_OUTPUT_SCHEMA) + .WithOutputCallback(execution::exec::ExecutionContext::NULL_OUTPUT_CALLBACK) .WithCatalogAccessor(common::ManagedPointer{accessor}) .WithMetricsManager(metrics_manager_) .WithReplicationManager(DISABLED) @@ -576,7 +576,6 @@ class ExecutionRunners : public benchmark::Fixture { auto exec_ctx = execution::exec::ExecutionContextBuilder() .WithDatabaseOID(db_oid) .WithQueryParametersFrom(parameters) - .WithExecutionMode(ExecutionRunners::mode) .WithExecutionSettings(exec_settings) .WithTxnContext(common::ManagedPointer{txn}) .WithOutputSchema(common::ManagedPointer{out_schema}) @@ -609,9 +608,10 @@ class ExecutionRunners : public benchmark::Fixture { auto exec_settings = GetExecutionSettings(); auto exec_ctx = execution::exec::ExecutionContextBuilder() .WithDatabaseOID(db_oid) - .WithExecutionMode(ExecutionRunners::mode) .WithExecutionSettings(exec_settings) .WithTxnContext(common::ManagedPointer{txn}) + .WithOutputSchema(execution::exec::ExecutionContext::NULL_OUTPUT_SCHEMA) + .WithOutputCallback(execution::exec::ExecutionContext::NULL_OUTPUT_CALLBACK) .WithCatalogAccessor(common::ManagedPointer{accessor}) .WithMetricsManager(metrics_manager_) .WithReplicationManager(DISABLED) @@ -978,7 +978,6 @@ BENCHMARK_DEFINE_F(ExecutionRunners, SEQ0_OutputRunners)(benchmark::State &state auto exec_ctx = execution::exec::ExecutionContextBuilder() .WithDatabaseOID(db_oid) - .WithExecutionMode(ExecutionRunners::mode) .WithExecutionSettings(exec_settings) .WithTxnContext(common::ManagedPointer{txn}) .WithOutputSchema(common::ManagedPointer{schema}) @@ -1054,9 +1053,9 @@ void ExecutionRunners::ExecuteIndexOperation(benchmark::State *state, bool is_in auto exec_ctx = execution::exec::ExecutionContextBuilder() .WithDatabaseOID(db_oid) - .WithExecutionMode(ExecutionRunners::mode) .WithExecutionSettings(exec_settings) .WithTxnContext(common::ManagedPointer{txn}) + .WithOutputSchema(execution::exec::ExecutionContext::NULL_OUTPUT_SCHEMA) .WithOutputCallback(std::move(callback)) .WithCatalogAccessor(common::ManagedPointer{accessor}) .WithMetricsManager(metrics_manager_) @@ -2086,9 +2085,10 @@ void InitializeRunnersState() { auto exec_settings = ExecutionRunners::GetExecutionSettings(); auto exec_ctx = execution::exec::ExecutionContextBuilder() .WithDatabaseOID(db_oid) - .WithExecutionMode(ExecutionRunners::mode) .WithExecutionSettings(exec_settings) .WithTxnContext(common::ManagedPointer{txn}) + .WithOutputSchema(execution::exec::ExecutionContext::NULL_OUTPUT_SCHEMA) + .WithOutputCallback(execution::exec::ExecutionContext::NULL_OUTPUT_CALLBACK) .WithCatalogAccessor(common::ManagedPointer{accessor}) .WithMetricsManager(db_main->GetMetricsManager()) .WithReplicationManager(DISABLED) diff --git a/src/execution/exec/execution_context.cpp b/src/execution/exec/execution_context.cpp index 08dda6c836..434199f519 100644 --- a/src/execution/exec/execution_context.cpp +++ b/src/execution/exec/execution_context.cpp @@ -15,15 +15,41 @@ namespace noisepage::execution::exec { std::unique_ptr ExecutionContextBuilder::Build() { - NOISEPAGE_ASSERT(db_oid_ != catalog::INVALID_DATABASE_OID, "Must specify database OID."); - NOISEPAGE_ASSERT(exec_mode_.has_value(), "Must specify execution mode."); - NOISEPAGE_ASSERT(exec_settings_.has_value(), "Must specify execution settings."); - NOISEPAGE_ASSERT(static_cast(catalog_accessor_), "Must specify catalog accessor."); - // MetricsManager, ReplicationManager, and RecoveryManaged may be DISABLED + if (db_oid_ == catalog::INVALID_DATABASE_OID) { + throw EXECUTION_EXCEPTION("Must specify database OID.", common::ErrorCode::ERRCODE_INTERNAL_ERROR); + } + if (!exec_settings_.has_value()) { + throw EXECUTION_EXCEPTION("Must specify exection settings.", common::ErrorCode::ERRCODE_INTERNAL_ERROR); + } + if (!txn_.has_value()) { + throw EXECUTION_EXCEPTION("Must specify a transaction context.", common::ErrorCode::ERRCODE_INTERNAL_ERROR); + } + if (!output_schema_.has_value()) { + throw EXECUTION_EXCEPTION("Must specify output schema.", common::ErrorCode::ERRCODE_INTERNAL_ERROR); + } + if (!output_callback_.has_value()) { + throw EXECUTION_EXCEPTION("Must specify output callback.", common::ErrorCode::ERRCODE_INTERNAL_ERROR); + } + if (!catalog_accessor_.has_value()) { + throw EXECUTION_EXCEPTION("Must specify catalog accessor.", common::ErrorCode::ERRCODE_INTERNAL_ERROR); + } + if (!metrics_manager_.has_value()) { + throw EXECUTION_EXCEPTION("Must specify metrics manager.", common::ErrorCode::ERRCODE_INTERNAL_ERROR); + } + if (!replication_manager_.has_value()) { + throw EXECUTION_EXCEPTION("Must specify replication manager.", common::ErrorCode::ERRCODE_INTERNAL_ERROR); + } + if (!recovery_manager_.has_value()) { + throw EXECUTION_EXCEPTION("Must specify recovery manager.", common::ErrorCode::ERRCODE_INTERNAL_ERROR); + } + + // Query parameters (parameters_) is not validated because + // this defaults to an empty collection + return std::unique_ptr{ - new ExecutionContext{db_oid_, std::move(parameters_), exec_mode_.value(), std::move(exec_settings_.value()), txn_, - output_schema_, std::move(output_callback_.value_or(nullptr)), catalog_accessor_, - metrics_manager_, replication_manager_, recovery_manager_}}; + new ExecutionContext{db_oid_, std::move(parameters_), std::move(exec_settings_.value()), txn_.value(), + output_schema_.value(), std::move(output_callback_.value()), catalog_accessor_.value(), + metrics_manager_.value(), replication_manager_.value(), recovery_manager_.value()}}; } ExecutionContextBuilder &ExecutionContextBuilder::WithQueryParametersFrom( diff --git a/src/include/execution/exec/execution_context.h b/src/include/execution/exec/execution_context.h index 94e1fcfe53..d9d49328bb 100644 --- a/src/include/execution/exec/execution_context.h +++ b/src/include/execution/exec/execution_context.h @@ -100,6 +100,15 @@ class EXPORT ExecutionContext { /** @return The transaction associated with this execution context */ common::ManagedPointer GetTxn() { return txn_; } + /** @return The execution mode for the execution context */ + vm::ExecutionMode GetExecutionMode() const { return execution_mode_; } + + /** + * Set the execution mode for the execution context. + * @param execution_mode The desired execution mode + */ + void SetExecutionMode(const vm::ExecutionMode execution_mode) { execution_mode_ = execution_mode; } + /** @return The execution settings. */ const exec::ExecutionSettings &GetExecutionSettings() const { return execution_settings_; } @@ -249,13 +258,12 @@ class EXPORT ExecutionContext { // to the "base" set of parameters for the query, otherwise, grab // the parameter at the specified index from the top of the runtime // parameters stack. - if (runtime_parameters_.empty()) { - NOISEPAGE_ASSERT(index < parameters_.size(), "ExecutionContext::GetParam() index out of range"); - return parameters_[index]; - } else { + if (!runtime_parameters_.empty()) { NOISEPAGE_ASSERT(index < runtime_parameters_.top().size(), "ExecutionContext::GetParam() index out of range."); return runtime_parameters_.top()[index]; } + NOISEPAGE_ASSERT(index < parameters_.size(), "ExecutionContext::GetParam() index out of range"); + return parameters_[index]; } /* -------------------------------------------------------------------------- @@ -333,6 +341,12 @@ class EXPORT ExecutionContext { */ void ClearHooks() { hooks_.clear(); } + public: + /** An empty output schema */ + constexpr static const std::nullptr_t NULL_OUTPUT_SCHEMA{nullptr}; + /** An empty output callback */ + constexpr static const std::nullptr_t NULL_OUTPUT_CALLBACK{nullptr}; + private: friend class ExecutionContextBuilder; @@ -343,7 +357,6 @@ class EXPORT ExecutionContext { * * @param db_oid The OID of the database * @param parameters The query parameters - * @param execution_mode The query execution mode * @param execution_settings The execution settings to run with * @param txn The transaction used by this query * @param output_schema The output schema @@ -354,7 +367,7 @@ class EXPORT ExecutionContext { * @param recovery_manager The recovery manager that handles both recovery and application of replication records. */ ExecutionContext(const catalog::db_oid_t db_oid, std::vector> &¶meters, - vm::ExecutionMode execution_mode, exec::ExecutionSettings &&execution_settings, + exec::ExecutionSettings &&execution_settings, const common::ManagedPointer txn, const common::ManagedPointer output_schema, OutputCallback &&output_callback, const common::ManagedPointer accessor, @@ -363,7 +376,6 @@ class EXPORT ExecutionContext { const common::ManagedPointer recovery_manager) : db_oid_{db_oid}, parameters_{std::move(parameters)}, - execution_mode_{execution_mode}, execution_settings_{execution_settings}, txn_{txn}, schema_{output_schema}, @@ -462,16 +474,6 @@ class ExecutionContextBuilder { /** @return The completed ExecutionContext instance */ std::unique_ptr Build(); - /** - * Set the execution mode for the execution context. - * @param mode The execution mode - * @return Builder reference for chaining - */ - ExecutionContextBuilder &WithExecutionMode(const vm::ExecutionMode mode) { - exec_mode_.emplace(mode); - return *this; - } - /** * Set the query parameters for the execution context. * @param parameters The query parameters @@ -539,23 +541,13 @@ class ExecutionContextBuilder { return *this; } - /** - * Set the execution settings for the execution context. - * @param exec_settings The execution settings - * @return Builder reference for chaining - */ - ExecutionContextBuilder &WithExecutionSettings(exec::ExecutionSettings &&exec_settings) { - exec_settings_.emplace(std::move(exec_settings)); - return *this; - } - /** * Set the execution settings for the execution context. * @param exec_settings The execution settings * @return Builder reference for chaining */ ExecutionContextBuilder &WithExecutionSettings(exec::ExecutionSettings exec_settings) { - exec_settings_.emplace(std::move(exec_settings)); + exec_settings_.emplace(exec_settings); return *this; } @@ -591,8 +583,6 @@ class ExecutionContextBuilder { } private: - /** The query execution mode */ - std::optional exec_mode_; /** The query execution settings */ std::optional exec_settings_; /** The query parmeters */ @@ -600,19 +590,19 @@ class ExecutionContextBuilder { /** The database OID */ catalog::db_oid_t db_oid_{catalog::INVALID_DATABASE_OID}; /** The associated transaction */ - common::ManagedPointer txn_; + std::optional> txn_; /** The output callback */ std::optional output_callback_; /** The output schema */ - common::ManagedPointer output_schema_{nullptr}; + std::optional> output_schema_{nullptr}; /** The catalog accessor */ - common::ManagedPointer catalog_accessor_; + std::optional> catalog_accessor_; /** The metrics manager */ - common::ManagedPointer metrics_manager_; + std::optional> metrics_manager_; /** The replication manager */ - common::ManagedPointer replication_manager_; + std::optional> replication_manager_; /** The recovery manager */ - common::ManagedPointer recovery_manager_; + std::optional> recovery_manager_; }; } // namespace noisepage::execution::exec diff --git a/src/traffic_cop/traffic_cop.cpp b/src/traffic_cop/traffic_cop.cpp index de60977f5b..8f1f355abe 100644 --- a/src/traffic_cop/traffic_cop.cpp +++ b/src/traffic_cop/traffic_cop.cpp @@ -584,7 +584,6 @@ TrafficCopResult TrafficCop::RunExecutableQuery(const common::ManagedPointerGetDatabaseOid()) - .WithExecutionMode(execution_mode_) .WithExecutionSettings(exec_settings) .WithTxnContext(connection_ctx->Transaction()) .WithOutputSchema(physical_plan->GetOutputSchema()) diff --git a/src/util/query_exec_util.cpp b/src/util/query_exec_util.cpp index efa9f8929c..90b881c852 100644 --- a/src/util/query_exec_util.cpp +++ b/src/util/query_exec_util.cpp @@ -284,9 +284,7 @@ bool QueryExecUtil::ExecuteQuery(const std::string &statement, TupleFunction tup auto exec_ctx = execution::exec::ExecutionContextBuilder() .WithDatabaseOID(db_oid_) - .WithExecutionMode(execution_mode_) .WithExecutionSettings(exec_settings) - .WithQueryParametersFrom(*params) .WithTxnContext(txn) .WithOutputSchema(common::ManagedPointer{schema}) .WithOutputCallback(std::move(callback)) @@ -294,6 +292,7 @@ bool QueryExecUtil::ExecuteQuery(const std::string &statement, TupleFunction tup .WithMetricsManager(metrics) .WithReplicationManager(DISABLED) .WithRecoveryManager(DISABLED) + .WithQueryParametersFrom(*params) .Build(); NOISEPAGE_ASSERT(!txn->MustAbort(), "Transaction should not be in must-abort state prior to executing"); diff --git a/test/execution/execution_context_test.cpp b/test/execution/execution_context_test.cpp new file mode 100644 index 0000000000..d764401229 --- /dev/null +++ b/test/execution/execution_context_test.cpp @@ -0,0 +1,11 @@ +#include "execution/exec/execution_context.h" + +#include "execution/compiled_tpl_test.h" + +namespace noisepage::execution::test { + +class ExecutionContextTest : public TplTest {}; + +TEST_F(ExecutionContextTest, ItWorks) { EXPECT_TRUE(true); } + +} // namespace noisepage::execution::test diff --git a/test/test_util/tpcc/workload_cached.cpp b/test/test_util/tpcc/workload_cached.cpp index 0d93088a95..25150af99c 100644 --- a/test/test_util/tpcc/workload_cached.cpp +++ b/test/test_util/tpcc/workload_cached.cpp @@ -78,9 +78,17 @@ void WorkloadCached::LoadTPCCQueries(const std::vector &txn_names) nullptr) ->TakePlanNodeOwnership(); - auto exec_ctx = std::make_unique( - db_oid_, common::ManagedPointer(txn), nullptr, nullptr, common::ManagedPointer(accessor), exec_settings_, - db_main_->GetMetricsManager(), DISABLED, DISABLED); + auto exec_ctx = execution::exec::ExecutionContextBuilder() + .WithDatabaseOID(db_oid_) + .WithExecutionSettings(exec_settings_) + .WithTxnContext(common::ManagedPointer{txn}) + .WithOutputSchema(execution::exec::ExecutionContext::NULL_OUTPUT_SCHEMA) + .WithOutputCallback(execution::exec::ExecutionContext::NULL_OUTPUT_CALLBACK) + .WithCatalogAccessor(common::ManagedPointer{accessor}) + .WithMetricsManager(db_main_->GetMetricsManager()) + .WithReplicationManager(DISABLED) + .WithRecoveryManager(DISABLED) + .Build(); // generate executable query and emplace it into the vector; break down here auto exec_query = @@ -114,16 +122,18 @@ void WorkloadCached::Execute(int8_t worker_id, uint32_t num_precomputed_txns_per auto accessor = catalog_->GetAccessor(common::ManagedPointer(txn), db_oid_, DISABLED); for (const auto &query : queries_.find(txn_names_[index[counter]])->second) { - execution::exec::ExecutionContext exec_ctx{db_oid_, - common::ManagedPointer(txn), - nullptr, - nullptr, // FIXME: Get the correct output later - common::ManagedPointer(accessor), - exec_settings_, - db_main_->GetMetricsManager(), - DISABLED, - DISABLED}; - query->Run(common::ManagedPointer(&exec_ctx), mode); + auto exec_ctx = execution::exec::ExecutionContextBuilder() + .WithDatabaseOID(db_oid_) + .WithExecutionSettings(exec_settings_) + .WithTxnContext(common::ManagedPointer{txn}) + .WithOutputSchema(execution::exec::ExecutionContext::NULL_OUTPUT_SCHEMA) + .WithOutputCallback(execution::exec::ExecutionContext::NULL_OUTPUT_CALLBACK) + .WithCatalogAccessor(common::ManagedPointer{accessor}) + .WithMetricsManager(db_main_->GetMetricsManager()) + .WithReplicationManager(DISABLED) + .WithRecoveryManager(DISABLED) + .Build(); + query->Run(common::ManagedPointer{exec_ctx}, mode); } counter = counter == num_queries - 1 ? 0 : counter + 1; txn_manager_->Commit(txn, transaction::TransactionUtil::EmptyCallback, nullptr); diff --git a/test/test_util/tpch/workload.cpp b/test/test_util/tpch/workload.cpp index 8b60e1b149..5a5bd0b3e7 100644 --- a/test/test_util/tpch/workload.cpp +++ b/test/test_util/tpch/workload.cpp @@ -43,9 +43,10 @@ Workload::Workload(common::ManagedPointer db_main, const std::string &db // Make the execution context auto exec_ctx = execution::exec::ExecutionContextBuilder() .WithDatabaseOID(db_oid_) - .WithExecutionMode(execution::vm::ExecutionMode::Interpret) .WithExecutionSettings(exec_settings_) .WithTxnContext(common::ManagedPointer{txn}) + .WithOutputSchema(execution::exec::ExecutionContext::NULL_OUTPUT_SCHEMA) + .WithOutputCallback(execution::exec::ExecutionContext::NULL_OUTPUT_CALLBACK) .WithCatalogAccessor(common::ManagedPointer{accessor}) .WithMetricsManager(db_main->GetMetricsManager()) .WithReplicationManager(DISABLED) @@ -156,7 +157,6 @@ void Workload::Execute(int8_t worker_id, uint64_t execution_us_per_worker, uint6 auto exec_ctx = execution::exec::ExecutionContextBuilder() .WithDatabaseOID(db_oid_) - .WithExecutionMode(mode) .WithExecutionSettings(exec_settings_) .WithTxnContext(common::ManagedPointer{txn}) .WithOutputSchema(common::ManagedPointer{output_schema}) From 69a7004674b0e6be4c054b830c6a4d091917204f Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Tue, 29 Jun 2021 15:16:43 -0400 Subject: [PATCH 057/139] builds, lints, and passes clang tidy after refactor --- benchmark/runner/execution_runners.cpp | 14 +-- src/execution/compiler/executable_query.cpp | 1 + src/execution/exec/execution_context.cpp | 14 +-- .../execution/exec/execution_context.h | 27 ++--- src/include/execution/exec/output.h | 7 +- .../planner/plannodes/abstract_plan_node.h | 4 +- src/util/query_exec_util.cpp | 7 +- test/execution/compiler_test.cpp | 99 +++++++------------ test/include/execution/sql_test.h | 64 +++++++++++- test/test_util/tpch/workload.cpp | 2 +- util/execution/tpl.cpp | 46 +++++---- 11 files changed, 166 insertions(+), 119 deletions(-) diff --git a/benchmark/runner/execution_runners.cpp b/benchmark/runner/execution_runners.cpp index a81837abda..8f00493e30 100644 --- a/benchmark/runner/execution_runners.cpp +++ b/benchmark/runner/execution_runners.cpp @@ -540,7 +540,7 @@ class ExecutionRunners : public benchmark::Fixture { } void BenchmarkExecQuery(int64_t num_iters, execution::compiler::ExecutableQuery *exec_query, - planner::OutputSchema *out_schema, bool commit, + const planner::OutputSchema *out_schema, bool commit, std::vector> *params = &empty_params, execution::exec::ExecutionSettings *exec_settings_arg = nullptr) { transaction::TransactionContext *txn = nullptr; @@ -564,10 +564,6 @@ class ExecutionRunners : public benchmark::Fixture { exec_settings = *exec_settings_arg; } - // auto exec_ctx = std::make_unique( - // db_oid, common::ManagedPointer(txn), callback, out_schema, common::ManagedPointer(accessor), exec_settings, - // metrics_manager, DISABLED, DISABLED); - // TODO(Kyle): This makes an unnecessary copy of the query parameters std::vector parameters{}; if (static_cast(i) < params_ref.size()) { @@ -579,7 +575,7 @@ class ExecutionRunners : public benchmark::Fixture { .WithExecutionSettings(exec_settings) .WithTxnContext(common::ManagedPointer{txn}) .WithOutputSchema(common::ManagedPointer{out_schema}) - .WithOutputCallback(std::move(callback)) + .WithOutputCallback(callback) .WithCatalogAccessor(common::ManagedPointer{accessor}) .WithMetricsManager(metrics_manager_) .WithReplicationManager(DISABLED) @@ -969,7 +965,7 @@ BENCHMARK_DEFINE_F(ExecutionRunners, SEQ0_OutputRunners)(benchmark::State &state auto txn = txn_manager_->BeginTransaction(); auto accessor = catalog_->GetAccessor(common::ManagedPointer(txn), db_oid, DISABLED); - auto schema = std::make_unique(std::move(cols)); + auto schema = std::make_unique(std::move(cols)); auto exec_settings = GetExecutionSettings(); execution::compiler::ExecutableQuery::query_identifier.store(ExecutionRunners::query_id++); @@ -981,7 +977,7 @@ BENCHMARK_DEFINE_F(ExecutionRunners, SEQ0_OutputRunners)(benchmark::State &state .WithExecutionSettings(exec_settings) .WithTxnContext(common::ManagedPointer{txn}) .WithOutputSchema(common::ManagedPointer{schema}) - .WithOutputCallback(std::move(callback)) + .WithOutputCallback(callback) .WithCatalogAccessor(common::ManagedPointer{accessor}) .WithMetricsManager(metrics_manager_) .WithReplicationManager(DISABLED) @@ -1056,7 +1052,7 @@ void ExecutionRunners::ExecuteIndexOperation(benchmark::State *state, bool is_in .WithExecutionSettings(exec_settings) .WithTxnContext(common::ManagedPointer{txn}) .WithOutputSchema(execution::exec::ExecutionContext::NULL_OUTPUT_SCHEMA) - .WithOutputCallback(std::move(callback)) + .WithOutputCallback(callback) .WithCatalogAccessor(common::ManagedPointer{accessor}) .WithMetricsManager(metrics_manager_) .WithReplicationManager(DISABLED) diff --git a/src/execution/compiler/executable_query.cpp b/src/execution/compiler/executable_query.cpp index 30e4624add..2ecbafada1 100644 --- a/src/execution/compiler/executable_query.cpp +++ b/src/execution/compiler/executable_query.cpp @@ -180,6 +180,7 @@ void ExecutableQuery::Run(common::ManagedPointer exec_ct auto query_state = std::make_unique(query_state_size_); *reinterpret_cast(query_state.get()) = exec_ctx.Get(); + exec_ctx->SetExecutionMode(mode); exec_ctx->SetQueryState(query_state.get()); exec_ctx->SetPipelineOperatingUnits(GetPipelineOperatingUnits()); exec_ctx->SetQueryId(query_id_); diff --git a/src/execution/exec/execution_context.cpp b/src/execution/exec/execution_context.cpp index 434199f519..2fd1af0c69 100644 --- a/src/execution/exec/execution_context.cpp +++ b/src/execution/exec/execution_context.cpp @@ -43,11 +43,10 @@ std::unique_ptr ExecutionContextBuilder::Build() { throw EXECUTION_EXCEPTION("Must specify recovery manager.", common::ErrorCode::ERRCODE_INTERNAL_ERROR); } - // Query parameters (parameters_) is not validated because - // this defaults to an empty collection - + // Query parameters (parameters_) is not validated because default is empty collection + // ExecutionSettings exec_settings = exec_settings_.value(); return std::unique_ptr{ - new ExecutionContext{db_oid_, std::move(parameters_), std::move(exec_settings_.value()), txn_.value(), + new ExecutionContext{db_oid_, std::move(parameters_), exec_settings_.value(), txn_.value(), output_schema_.value(), std::move(output_callback_.value()), catalog_accessor_.value(), metrics_manager_.value(), replication_manager_.value(), recovery_manager_.value()}}; } @@ -64,18 +63,19 @@ ExecutionContextBuilder &ExecutionContextBuilder::WithQueryParametersFrom( } OutputBuffer *ExecutionContext::OutputBufferNew() { - if (schema_ == nullptr) { + if (output_schema_ == nullptr) { return nullptr; } // Use C++ placement new auto size = sizeof(OutputBuffer); auto *buffer = reinterpret_cast(mem_pool_->Allocate(size)); - new (buffer) OutputBuffer(mem_pool_.get(), schema_->GetColumns().size(), ComputeTupleSize(schema_), callback_); + new (buffer) OutputBuffer(mem_pool_.get(), output_schema_->GetColumns().size(), ComputeTupleSize(output_schema_), + output_callback_); return buffer; } -uint32_t ExecutionContext::ComputeTupleSize(common::ManagedPointer schema) { +uint32_t ExecutionContext::ComputeTupleSize(common::ManagedPointer schema) { uint32_t tuple_size = 0; for (const auto &col : schema->GetColumns()) { auto alignment = sql::ValUtil::GetSqlAlignment(col.GetType()); diff --git a/src/include/execution/exec/execution_context.h b/src/include/execution/exec/execution_context.h index d9d49328bb..f6725c5426 100644 --- a/src/include/execution/exec/execution_context.h +++ b/src/include/execution/exec/execution_context.h @@ -106,6 +106,11 @@ class EXPORT ExecutionContext { /** * Set the execution mode for the execution context. * @param execution_mode The desired execution mode + * + * NOTE: Most of the time one should avoid calling this + * function directly; the execution mode for the ExecutionContext + * instance is automatically set in ExecutableQuery::Run() to + * the execution mode in which the query is executed. */ void SetExecutionMode(const vm::ExecutionMode execution_mode) { execution_mode_ = execution_mode; } @@ -309,7 +314,7 @@ class EXPORT ExecutionContext { * @param schema The output schema * @return The size of tuple in this schema */ - static uint32_t ComputeTupleSize(common::ManagedPointer schema); + static uint32_t ComputeTupleSize(common::ManagedPointer schema); /* -------------------------------------------------------------------------- Hook Function Management @@ -367,10 +372,10 @@ class EXPORT ExecutionContext { * @param recovery_manager The recovery manager that handles both recovery and application of replication records. */ ExecutionContext(const catalog::db_oid_t db_oid, std::vector> &¶meters, - exec::ExecutionSettings &&execution_settings, + exec::ExecutionSettings execution_settings, const common::ManagedPointer txn, - const common::ManagedPointer output_schema, OutputCallback &&output_callback, - const common::ManagedPointer accessor, + const common::ManagedPointer output_schema, + OutputCallback &&output_callback, const common::ManagedPointer accessor, const common::ManagedPointer metrics_manager, const common::ManagedPointer replication_manager, const common::ManagedPointer recovery_manager) @@ -378,8 +383,8 @@ class EXPORT ExecutionContext { parameters_{std::move(parameters)}, execution_settings_{execution_settings}, txn_{txn}, - schema_{output_schema}, - callback_{std::move(output_callback)}, + output_schema_{output_schema}, + output_callback_{std::move(output_callback)}, accessor_{accessor}, metrics_manager_{metrics_manager}, replication_manager_{replication_manager}, @@ -412,11 +417,11 @@ class EXPORT ExecutionContext { const common::ManagedPointer txn_; /** The query output schema */ - common::ManagedPointer schema_{nullptr}; + common::ManagedPointer output_schema_{nullptr}; /** The query output buffer */ std::unique_ptr buffer_{nullptr}; /** The query output callback */ - const OutputCallback &callback_; + OutputCallback output_callback_; /** The query catalog accessor */ common::ManagedPointer accessor_; @@ -516,7 +521,7 @@ class ExecutionContextBuilder { * @param output_schema The output schema * @return Builder reference for chaining */ - ExecutionContextBuilder &WithOutputSchema(common::ManagedPointer output_schema) { + ExecutionContextBuilder &WithOutputSchema(common::ManagedPointer output_schema) { output_schema_ = output_schema; return *this; } @@ -526,7 +531,7 @@ class ExecutionContextBuilder { * @param output_callback The output callback * @return Builder reference for chaining */ - ExecutionContextBuilder &WithOutputCallback(OutputCallback &&output_callback) { + ExecutionContextBuilder &WithOutputCallback(OutputCallback output_callback) { output_callback_.emplace(std::move(output_callback)); return *this; } @@ -594,7 +599,7 @@ class ExecutionContextBuilder { /** The output callback */ std::optional output_callback_; /** The output schema */ - std::optional> output_schema_{nullptr}; + std::optional> output_schema_{nullptr}; /** The catalog accessor */ std::optional> catalog_accessor_; /** The metrics manager */ diff --git a/src/include/execution/exec/output.h b/src/include/execution/exec/output.h index 21742a7acd..dbb16d4591 100644 --- a/src/include/execution/exec/output.h +++ b/src/include/execution/exec/output.h @@ -139,7 +139,7 @@ class OutputWriter { * @param out packet writer to use * @param field_formats reference to the field formats for this query */ - OutputWriter(const common::ManagedPointer schema, + OutputWriter(const common::ManagedPointer schema, const common::ManagedPointer out, const std::vector &field_formats) : schema_(schema), out_(out), field_formats_(field_formats) {} @@ -166,8 +166,11 @@ class OutputWriter { * (parallel scan) */ std::mutex output_synchronization_; - const common::ManagedPointer schema_; + /** The output schema */ + const common::ManagedPointer schema_; + /** The output writer */ const common::ManagedPointer out_; + /** The field formats */ const std::vector &field_formats_; }; diff --git a/src/include/planner/plannodes/abstract_plan_node.h b/src/include/planner/plannodes/abstract_plan_node.h index 44b487434d..a07c20476f 100644 --- a/src/include/planner/plannodes/abstract_plan_node.h +++ b/src/include/planner/plannodes/abstract_plan_node.h @@ -166,7 +166,9 @@ class AbstractPlanNode { * @return output schema for the node. The output schema contains information on columns of the output of the plan * node operator */ - common::ManagedPointer GetOutputSchema() const { return common::ManagedPointer(output_schema_); } + common::ManagedPointer GetOutputSchema() const { + return common::ManagedPointer(output_schema_.get()); + } //===--------------------------------------------------------------------===// // Add child diff --git a/src/util/query_exec_util.cpp b/src/util/query_exec_util.cpp index 90b881c852..e61de1633e 100644 --- a/src/util/query_exec_util.cpp +++ b/src/util/query_exec_util.cpp @@ -189,7 +189,7 @@ bool QueryExecUtil::ExecuteDDL(const std::string &query, bool what_if) { // has run. We can't compile the query before the CreateIndexExecutor because codegen would have // no idea which index to insert into. execution::exec::ExecutionSettings settings{}; - common::ManagedPointer schema = out_plan->GetOutputSchema(); + const auto schema = out_plan->GetOutputSchema(); auto exec_query = execution::compiler::CompilationContext::Compile( *out_plan, settings, accessor.get(), execution::compiler::CompilationMode::OneShot, std::nullopt, statement->OptimizeResult()->GetPlanMetaData()); @@ -235,8 +235,7 @@ bool QueryExecUtil::CompileQuery(const std::string &statement, const common::ManagedPointer out_plan = result->OptimizeResult()->GetPlanNode(); NOISEPAGE_ASSERT(network::NetworkUtil::DMLQueryType(result->GetQueryType()), "ExecuteDML expects DML"); - common::ManagedPointer schema = out_plan->GetOutputSchema(); - + const auto schema = out_plan->GetOutputSchema(); auto exec_query = execution::compiler::CompilationContext::Compile( *out_plan, exec_settings, accessor.get(), execution::compiler::CompilationMode::OneShot, override_qid, result->OptimizeResult()->GetPlanMetaData()); @@ -253,7 +252,7 @@ bool QueryExecUtil::ExecuteQuery(const std::string &statement, TupleFunction tup NOISEPAGE_ASSERT(txn_ != nullptr, "Requires BeginTransaction() or UseTransaction()"); ResetError(); auto txn = common::ManagedPointer(txn_); - planner::OutputSchema *schema = schemas_[statement].get(); + const planner::OutputSchema *schema = schemas_[statement].get(); std::mutex sync_mutex; auto consumer = [&tuple_fn, &sync_mutex, schema](byte *tuples, uint32_t num_tuples, uint32_t tuple_size) { diff --git a/test/execution/compiler_test.cpp b/test/execution/compiler_test.cpp index d7fa17bb02..ad0e7e6de0 100644 --- a/test/execution/compiler_test.cpp +++ b/test/execution/compiler_test.cpp @@ -70,27 +70,6 @@ class CompilerTest : public SqlBasedTest { static constexpr vm::ExecutionMode MODE = vm::ExecutionMode::Interpret; }; -/** - * Transform the parameters vector supplied to an executable query. - * - * TODO(Kyle): This function is a hack that results from a refactor of - * the API for executable queries. Eventually, when we actually get - * around to refactoring the compiler tests, we should remove this and - * just fix the API itself. - * - * @param parameters The input parameters collection - * @return A non-owning collection of parameters in the format - * expected by the ExecutableQuery API - */ -static std::unique_ptr>> TransformParameters( - const std::vector ¶meters) { - auto params = std::make_unique>>(); - params->reserve(parameters.size()); - std::transform(parameters.cbegin(), parameters.cend(), std::back_inserter(*params), - [](const parser::ConstantValueExpression &cve) { return common::ManagedPointer{cve.SqlValue()}; }); - return params; -} - // NOLINTNEXTLINE TEST_F(CompilerTest, CompileFromSource) { util::Region region{"compiler_test"}; @@ -437,14 +416,13 @@ TEST_F(CompilerTest, SimpleSeqScanWithParamsTest) { exec::OutputPrinter printer(seq_scan->GetOutputSchema().Get()); MultiOutputCallback callback{std::vector{store, printer}}; exec::OutputCallback callback_fn = callback.ConstructOutputCallback(); - auto exec_ctx = MakeExecCtx(&callback_fn, seq_scan->GetOutputSchema().Get()); - std::vector param_builder{}; - param_builder.emplace_back(type::TypeId::INTEGER, execution::sql::Integer(100)); - param_builder.emplace_back(type::TypeId::INTEGER, execution::sql::Integer(500)); - param_builder.emplace_back(type::TypeId::INTEGER, execution::sql::Integer(3)); - auto params = TransformParameters(param_builder); - exec_ctx->SetParams(common::ManagedPointer(params)); + std::vector params{}; + params.emplace_back(type::TypeId::INTEGER, execution::sql::Integer(100)); + params.emplace_back(type::TypeId::INTEGER, execution::sql::Integer(500)); + params.emplace_back(type::TypeId::INTEGER, execution::sql::Integer(3)); + + auto exec_ctx = MakeExecCtxWithParameters(params, &callback_fn, seq_scan->GetOutputSchema().Get()); // Run & Check auto executable = execution::compiler::CompilationContext::Compile(*seq_scan, exec_ctx->GetExecutionSettings(), @@ -3096,12 +3074,12 @@ TEST_F(CompilerTest, InsertIntoSelectWithParamTest) { // Make Exec Ctx MultiOutputCallback callback{std::vector{}}; exec::OutputCallback callback_fn = callback.ConstructOutputCallback(); - auto exec_ctx = MakeExecCtx(&callback_fn, insert->GetOutputSchema().Get()); - std::vector params_builder{}; - params_builder.emplace_back(type::TypeId::INTEGER, execution::sql::Integer(495)); - params_builder.emplace_back(type::TypeId::INTEGER, execution::sql::Integer(505)); - auto params = TransformParameters(params_builder); - exec_ctx->SetParams(common::ManagedPointer(params)); + + std::vector params{}; + params.emplace_back(type::TypeId::INTEGER, execution::sql::Integer(495)); + params.emplace_back(type::TypeId::INTEGER, execution::sql::Integer(505)); + + auto exec_ctx = MakeExecCtxWithParameters(params, &callback_fn, insert->GetOutputSchema().Get()); auto executable = execution::compiler::CompilationContext::Compile(*insert, exec_ctx->GetExecutionSettings(), exec_ctx->GetAccessor()); @@ -3321,33 +3299,31 @@ TEST_F(CompilerTest, SimpleInsertWithParamsTest) { // Make Exec Ctx MultiOutputCallback callback{std::vector{}}; exec::OutputCallback callback_fn = callback.ConstructOutputCallback(); - auto exec_ctx = MakeExecCtx(&callback_fn, insert->GetOutputSchema().Get()); - std::vector params_builder{}; + std::vector params{}; // First parameter list auto str1_val = sql::ValueUtil::CreateStringVal(str1); - params_builder.emplace_back(type::TypeId::VARCHAR, str1_val.first, std::move(str1_val.second)); - params_builder.emplace_back(type::TypeId::DATE, sql::DateVal(date1.val_)); - params_builder.emplace_back(type::TypeId::REAL, sql::Real(real1)); - params_builder.emplace_back(type::TypeId::BOOLEAN, sql::BoolVal(bool1)); - params_builder.emplace_back(type::TypeId::TINYINT, sql::Integer(tinyint1)); - params_builder.emplace_back(type::TypeId::SMALLINT, sql::Integer(smallint1)); - params_builder.emplace_back(type::TypeId::INTEGER, sql::Integer(int1)); - params_builder.emplace_back(type::TypeId::BIGINT, sql::Integer(bigint1)); + params.emplace_back(type::TypeId::VARCHAR, str1_val.first, std::move(str1_val.second)); + params.emplace_back(type::TypeId::DATE, sql::DateVal(date1.val_)); + params.emplace_back(type::TypeId::REAL, sql::Real(real1)); + params.emplace_back(type::TypeId::BOOLEAN, sql::BoolVal(bool1)); + params.emplace_back(type::TypeId::TINYINT, sql::Integer(tinyint1)); + params.emplace_back(type::TypeId::SMALLINT, sql::Integer(smallint1)); + params.emplace_back(type::TypeId::INTEGER, sql::Integer(int1)); + params.emplace_back(type::TypeId::BIGINT, sql::Integer(bigint1)); // Second parameter list auto str2_val = sql::ValueUtil::CreateStringVal(str2); - params_builder.emplace_back(type::TypeId::VARCHAR, str2_val.first, std::move(str2_val.second)); - params_builder.emplace_back(type::TypeId::DATE, sql::DateVal(date2.val_)); - params_builder.emplace_back(type::TypeId::REAL, sql::Real(real2)); - params_builder.emplace_back(type::TypeId::BOOLEAN, sql::BoolVal(bool2)); - params_builder.emplace_back(type::TypeId::TINYINT, sql::Integer(tinyint2)); - params_builder.emplace_back(type::TypeId::SMALLINT, sql::Integer(smallint2)); - params_builder.emplace_back(type::TypeId::INTEGER, sql::Integer(int2)); - params_builder.emplace_back(type::TypeId::BIGINT, sql::Integer(bigint2)); - - auto params = TransformParameters(params_builder); - exec_ctx->SetParams(common::ManagedPointer(params)); + params.emplace_back(type::TypeId::VARCHAR, str2_val.first, std::move(str2_val.second)); + params.emplace_back(type::TypeId::DATE, sql::DateVal(date2.val_)); + params.emplace_back(type::TypeId::REAL, sql::Real(real2)); + params.emplace_back(type::TypeId::BOOLEAN, sql::BoolVal(bool2)); + params.emplace_back(type::TypeId::TINYINT, sql::Integer(tinyint2)); + params.emplace_back(type::TypeId::SMALLINT, sql::Integer(smallint2)); + params.emplace_back(type::TypeId::INTEGER, sql::Integer(int2)); + params.emplace_back(type::TypeId::BIGINT, sql::Integer(bigint2)); + + auto exec_ctx = MakeExecCtxWithParameters(params, &callback_fn, insert->GetOutputSchema().Get()); auto executable = execution::compiler::CompilationContext::Compile(*insert, exec_ctx->GetExecutionSettings(), exec_ctx->GetAccessor()); @@ -3527,14 +3503,15 @@ TEST_F(CompilerTest, SimpleInsertWithParamsTest) { exec::OutputPrinter printer(index_scan->GetOutputSchema().Get()); MultiOutputCallback callback{std::vector{store, printer}}; exec::OutputCallback callback_fn = callback.ConstructOutputCallback(); - auto exec_ctx = MakeExecCtx(&callback_fn, index_scan->GetOutputSchema().Get()); - std::vector params_builder{}; + + std::vector params{}; auto str1_val = sql::ValueUtil::CreateStringVal(str1); auto str2_val = sql::ValueUtil::CreateStringVal(str2); - params_builder.emplace_back(type::TypeId::VARCHAR, str1_val.first, std::move(str1_val.second)); - params_builder.emplace_back(type::TypeId::VARCHAR, str2_val.first, std::move(str2_val.second)); - auto params = TransformParameters(params_builder); - exec_ctx->SetParams(common::ManagedPointer(params)); + params.emplace_back(type::TypeId::VARCHAR, str1_val.first, std::move(str1_val.second)); + params.emplace_back(type::TypeId::VARCHAR, str2_val.first, std::move(str2_val.second)); + + auto exec_ctx = MakeExecCtxWithParameters(params, &callback_fn, index_scan->GetOutputSchema().Get()); + auto executable = execution::compiler::CompilationContext::Compile(*index_scan, exec_ctx->GetExecutionSettings(), exec_ctx->GetAccessor()); executable->Run(common::ManagedPointer(exec_ctx), MODE); diff --git a/test/include/execution/sql_test.h b/test/include/execution/sql_test.h index da7e0a0a7b..988a38f80e 100644 --- a/test/include/execution/sql_test.h +++ b/test/include/execution/sql_test.h @@ -53,41 +53,99 @@ class SqlBasedTest : public TplTest { ~SqlBasedTest() override { txn_manager_->Commit(test_txn_, transaction::TransactionUtil::EmptyCallback, nullptr); } + /** @return The namespace OID */ catalog::namespace_oid_t NSOid() { return test_ns_oid_; } + /** @return The block store */ common::ManagedPointer BlockStore() { return block_store_; } + /** + * Construct and return an execution context. + * @param callback[optional] The output callback + * @param schema[optional] the output schema + * @return The execution context + */ std::unique_ptr MakeExecCtx(exec::OutputCallback *callback = nullptr, const planner::OutputSchema *schema = nullptr) { exec::OutputCallback empty = nullptr; const auto &callback_ref = (callback == nullptr) ? empty : *callback; - return std::make_unique(test_db_oid_, common::ManagedPointer(test_txn_), callback_ref, - schema, common::ManagedPointer(accessor_), *exec_settings_, - metrics_manager_, DISABLED, DISABLED); + return exec::ExecutionContextBuilder() + .WithDatabaseOID(test_db_oid_) + .WithExecutionSettings(*exec_settings_) + .WithTxnContext(common::ManagedPointer{test_txn_}) + .WithOutputSchema(common::ManagedPointer{schema}) + .WithOutputCallback(callback_ref) + .WithCatalogAccessor(common::ManagedPointer{accessor_}) + .WithMetricsManager(metrics_manager_) + .WithReplicationManager(DISABLED) + .WithRecoveryManager(DISABLED) + .Build(); } + /** + * Construct and return an execution context. + * @param parameters The query execution parameters + * @param callback[optional] The output callback + * @param schema[optional] The output schema + * @return The execution context + */ + std::unique_ptr MakeExecCtxWithParameters( + const std::vector ¶meters, exec::OutputCallback *callback = nullptr, + const planner::OutputSchema *schema = nullptr) { + exec::OutputCallback empty = nullptr; + const auto &callback_ref = (callback == nullptr) ? empty : *callback; + return exec::ExecutionContextBuilder() + .WithDatabaseOID(test_db_oid_) + .WithExecutionSettings(*exec_settings_) + .WithTxnContext(common::ManagedPointer{test_txn_}) + .WithOutputSchema(common::ManagedPointer{schema}) + .WithOutputCallback(callback_ref) + .WithCatalogAccessor(common::ManagedPointer{accessor_}) + .WithMetricsManager(metrics_manager_) + .WithReplicationManager(DISABLED) + .WithRecoveryManager(DISABLED) + .WithQueryParametersFrom(parameters) + .Build(); + } + + /** + * Generate the test tables for SQL tests. + * @param exec_ctx The execution context to use for table generation. + */ void GenerateTestTables(exec::ExecutionContext *exec_ctx) { sql::TableGenerator table_generator{exec_ctx, block_store_, test_ns_oid_}; table_generator.GenerateTestTables(); } + /** @return A new, owned catalog accessor */ std::unique_ptr MakeAccessor() { return catalog_->GetAccessor(common::ManagedPointer(test_txn_), test_db_oid_, DISABLED); } protected: + /** The catalog accessor */ std::unique_ptr accessor_; + /** The identifier for the test database */ catalog::db_oid_t test_db_oid_{0}; + /** The statistics storage */ common::ManagedPointer stats_storage_; + /** The test transaction context */ transaction::TransactionContext *test_txn_; + /** The transaction manager */ common::ManagedPointer txn_manager_; private: + /** The database instance */ std::unique_ptr db_main_; + /** The metrics manager instance */ common::ManagedPointer metrics_manager_; + /** The block store */ common::ManagedPointer block_store_; + /** The catalog instance */ common::ManagedPointer catalog_; + /** The identifier for the test namespace */ catalog::namespace_oid_t test_ns_oid_; + /** The execution settings instance */ std::unique_ptr exec_settings_; }; diff --git a/test/test_util/tpch/workload.cpp b/test/test_util/tpch/workload.cpp index 5a5bd0b3e7..41ca3483a0 100644 --- a/test/test_util/tpch/workload.cpp +++ b/test/test_util/tpch/workload.cpp @@ -160,7 +160,7 @@ void Workload::Execute(int8_t worker_id, uint64_t execution_us_per_worker, uint6 .WithExecutionSettings(exec_settings_) .WithTxnContext(common::ManagedPointer{txn}) .WithOutputSchema(common::ManagedPointer{output_schema}) - .WithOutputCallback(std::move(printer)) + .WithOutputCallback(printer) .WithCatalogAccessor(common::ManagedPointer{accessor}) .WithMetricsManager(db_main_->GetMetricsManager()) .WithReplicationManager(DISABLED) diff --git a/util/execution/tpl.cpp b/util/execution/tpl.cpp index a484b6ff88..162d60f3e5 100644 --- a/util/execution/tpl.cpp +++ b/util/execution/tpl.cpp @@ -94,24 +94,30 @@ static void CompileAndRun(const std::string &source, const std::string &name = " exec::ExecutionSettings exec_settings{}; exec::OutputPrinter printer(output_schema); exec::OutputCallback callback = printer; - exec::ExecutionContext exec_ctx{ - db_oid, common::ManagedPointer(txn), callback, output_schema, common::ManagedPointer(accessor), - exec_settings, db_main->GetMetricsManager(), DISABLED, DISABLED}; + // Add dummy parameters for tests - std::vector params_builder{}; - params_builder.emplace_back(type::TypeId::INTEGER, sql::Integer(37)); - params_builder.emplace_back(type::TypeId::REAL, sql::Real(37.73)); - params_builder.emplace_back(type::TypeId::DATE, sql::DateVal(sql::Date::FromYMD(1937, 3, 7))); + std::vector params{}; + params.emplace_back(type::TypeId::INTEGER, sql::Integer(37)); + params.emplace_back(type::TypeId::REAL, sql::Real(37.73)); + params.emplace_back(type::TypeId::DATE, sql::DateVal(sql::Date::FromYMD(1937, 3, 7))); auto string_val = sql::ValueUtil::CreateStringVal(std::string_view("37 Strings")); - params_builder.emplace_back(type::TypeId::VARCHAR, string_val.first, std::move(string_val.second)); - - std::vector> params{}; - std::transform(params_builder.cbegin(), params_builder.cend(), std::back_inserter(params), - [](const parser::ConstantValueExpression &expr) { return common::ManagedPointer{expr.SqlValue()}; }); - exec_ctx.SetParams(common::ManagedPointer{¶ms}); + params.emplace_back(type::TypeId::VARCHAR, string_val.first, std::move(string_val.second)); + + auto exec_ctx = exec::ExecutionContextBuilder() + .WithDatabaseOID(db_oid) + .WithExecutionSettings(exec_settings) + .WithTxnContext(common::ManagedPointer{txn}) + .WithOutputSchema(common::ManagedPointer{output_schema}) + .WithOutputCallback(callback) + .WithCatalogAccessor(common::ManagedPointer{accessor}) + .WithMetricsManager(db_main->GetMetricsManager()) + .WithReplicationManager(DISABLED) + .WithRecoveryManager(DISABLED) + .WithQueryParametersFrom(params) + .Build(); // Generate test tables - sql::TableGenerator table_generator{&exec_ctx, db_main->GetStorageLayer()->GetBlockStore(), ns_oid}; + sql::TableGenerator table_generator{exec_ctx.get(), db_main->GetStorageLayer()->GetBlockStore(), ns_oid}; table_generator.GenerateTestTables(); // Comment out to make more tables available at runtime // table_generator.GenerateTPCHTables(); @@ -197,7 +203,7 @@ static void CompileAndRun(const std::string &source, const std::string &name = " // { - exec_ctx.SetExecutionMode(static_cast(vm::ExecutionMode::Interpret)); + exec_ctx->SetExecutionMode(vm::ExecutionMode::Interpret); util::ScopedTimer timer(&interp_exec_ms); if (IS_SQL) { @@ -206,7 +212,7 @@ static void CompileAndRun(const std::string &source, const std::string &name = " EXECUTION_LOG_ERROR("Missing 'main' entry function with signature (*ExecutionContext)->int32"); return; } - EXECUTION_LOG_INFO("VM main() returned: {}", main(&exec_ctx)); + EXECUTION_LOG_INFO("VM main() returned: {}", main(exec_ctx.get())); } else { std::function main; if (!module->GetFunction("main", vm::ExecutionMode::Interpret, &main)) { @@ -221,7 +227,7 @@ static void CompileAndRun(const std::string &source, const std::string &name = " // Adaptive // - exec_ctx.SetExecutionMode(static_cast(vm::ExecutionMode::Adaptive)); + exec_ctx->SetExecutionMode(vm::ExecutionMode::Adaptive); util::ScopedTimer timer(&adaptive_exec_ms); if (IS_SQL) { @@ -230,7 +236,7 @@ static void CompileAndRun(const std::string &source, const std::string &name = " EXECUTION_LOG_ERROR("Missing 'main' entry function with signature (*ExecutionContext)->int32"); return; } - EXECUTION_LOG_INFO("ADAPTIVE main() returned: {}", main(&exec_ctx)); + EXECUTION_LOG_INFO("ADAPTIVE main() returned: {}", main(exec_ctx.get())); } else { std::function main; if (!module->GetFunction("main", vm::ExecutionMode::Adaptive, &main)) { @@ -244,7 +250,7 @@ static void CompileAndRun(const std::string &source, const std::string &name = " // JIT // { - exec_ctx.SetExecutionMode(static_cast(vm::ExecutionMode::Compiled)); + exec_ctx->SetExecutionMode(vm::ExecutionMode::Compiled); util::ScopedTimer timer(&jit_exec_ms); if (IS_SQL) { @@ -255,7 +261,7 @@ static void CompileAndRun(const std::string &source, const std::string &name = " } util::Timer x; x.Start(); - EXECUTION_LOG_INFO("JIT main() returned: {}", main(&exec_ctx)); + EXECUTION_LOG_INFO("JIT main() returned: {}", main(exec_ctx.get())); x.Stop(); EXECUTION_LOG_INFO("Jit exec: {} ms", x.GetElapsed()); } else { From c375a9b580c48a617cd9337f2db6b72e4865f502 Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Fri, 2 Jul 2021 11:27:17 -0400 Subject: [PATCH 058/139] fix doxygen failure --- src/include/execution/exec/execution_context.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/include/execution/exec/execution_context.h b/src/include/execution/exec/execution_context.h index f6725c5426..862c02eb19 100644 --- a/src/include/execution/exec/execution_context.h +++ b/src/include/execution/exec/execution_context.h @@ -491,7 +491,7 @@ class ExecutionContextBuilder { /** * Set the query parameters for the execution context. - * @param param_expr The collection of expressions from which the query parameters are derived + * @param parameter_exprs The collection of expressions from which the query parameters are derived * @return Builder reference for chaining */ ExecutionContextBuilder &WithQueryParametersFrom(const std::vector ¶meter_exprs); From 45a17ef926dcdada2ee5c1f20b3f6a23cb542927 Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Fri, 2 Jul 2021 15:45:33 -0400 Subject: [PATCH 059/139] refactor to separate execution context builder --- benchmark/runner/execution_runners.cpp | 1 + src/execution/exec/execution_context.cpp | 49 ----- .../exec/execution_context_builder.cpp | 58 ++++++ .../execution/exec/execution_context.h | 145 ------------- .../exec/execution_context_builder.h | 193 ++++++++++++++++++ src/traffic_cop/traffic_cop.cpp | 2 +- src/util/query_exec_util.cpp | 2 +- test/execution/compiler_test.cpp | 2 +- .../execution_context_builder_test.cpp | 169 +++++++++++++++ test/execution/execution_context_test.cpp | 11 - test/include/execution/sql_test.h | 1 + test/test_util/tpcc/workload_cached.cpp | 2 +- test/test_util/tpch/workload.cpp | 2 +- util/execution/tpl.cpp | 2 +- 14 files changed, 428 insertions(+), 211 deletions(-) create mode 100644 src/execution/exec/execution_context_builder.cpp create mode 100644 src/include/execution/exec/execution_context_builder.h create mode 100644 test/execution/execution_context_builder_test.cpp delete mode 100644 test/execution/execution_context_test.cpp diff --git a/benchmark/runner/execution_runners.cpp b/benchmark/runner/execution_runners.cpp index 8f00493e30..ab50883773 100644 --- a/benchmark/runner/execution_runners.cpp +++ b/benchmark/runner/execution_runners.cpp @@ -11,6 +11,7 @@ #include "common/scoped_timer.h" #include "execution/compiler/compilation_context.h" #include "execution/compiler/executable_query.h" +#include "execution/exec/execution_context_builder.h" #include "execution/exec/execution_settings.h" #include "execution/execution_util.h" #include "execution/sql/ddl_executors.h" diff --git a/src/execution/exec/execution_context.cpp b/src/execution/exec/execution_context.cpp index 2fd1af0c69..1ea02b98c2 100644 --- a/src/execution/exec/execution_context.cpp +++ b/src/execution/exec/execution_context.cpp @@ -5,7 +5,6 @@ #include "execution/sql/value.h" #include "metrics/metrics_manager.h" #include "metrics/metrics_store.h" -#include "parser/expression/constant_value_expression.h" #include "replication/primary_replication_manager.h" #include "self_driving/modeling/operating_unit.h" #include "self_driving/modeling/operating_unit_util.h" @@ -14,54 +13,6 @@ namespace noisepage::execution::exec { -std::unique_ptr ExecutionContextBuilder::Build() { - if (db_oid_ == catalog::INVALID_DATABASE_OID) { - throw EXECUTION_EXCEPTION("Must specify database OID.", common::ErrorCode::ERRCODE_INTERNAL_ERROR); - } - if (!exec_settings_.has_value()) { - throw EXECUTION_EXCEPTION("Must specify exection settings.", common::ErrorCode::ERRCODE_INTERNAL_ERROR); - } - if (!txn_.has_value()) { - throw EXECUTION_EXCEPTION("Must specify a transaction context.", common::ErrorCode::ERRCODE_INTERNAL_ERROR); - } - if (!output_schema_.has_value()) { - throw EXECUTION_EXCEPTION("Must specify output schema.", common::ErrorCode::ERRCODE_INTERNAL_ERROR); - } - if (!output_callback_.has_value()) { - throw EXECUTION_EXCEPTION("Must specify output callback.", common::ErrorCode::ERRCODE_INTERNAL_ERROR); - } - if (!catalog_accessor_.has_value()) { - throw EXECUTION_EXCEPTION("Must specify catalog accessor.", common::ErrorCode::ERRCODE_INTERNAL_ERROR); - } - if (!metrics_manager_.has_value()) { - throw EXECUTION_EXCEPTION("Must specify metrics manager.", common::ErrorCode::ERRCODE_INTERNAL_ERROR); - } - if (!replication_manager_.has_value()) { - throw EXECUTION_EXCEPTION("Must specify replication manager.", common::ErrorCode::ERRCODE_INTERNAL_ERROR); - } - if (!recovery_manager_.has_value()) { - throw EXECUTION_EXCEPTION("Must specify recovery manager.", common::ErrorCode::ERRCODE_INTERNAL_ERROR); - } - - // Query parameters (parameters_) is not validated because default is empty collection - // ExecutionSettings exec_settings = exec_settings_.value(); - return std::unique_ptr{ - new ExecutionContext{db_oid_, std::move(parameters_), exec_settings_.value(), txn_.value(), - output_schema_.value(), std::move(output_callback_.value()), catalog_accessor_.value(), - metrics_manager_.value(), replication_manager_.value(), recovery_manager_.value()}}; -} - -ExecutionContextBuilder &ExecutionContextBuilder::WithQueryParametersFrom( - const std::vector ¶meter_exprs) { - NOISEPAGE_ASSERT(parameters_.empty(), "Attempt to initialize query parameters more than once."); - parameters_.reserve(parameter_exprs.size()); - std::transform(parameter_exprs.cbegin(), parameter_exprs.cend(), std::back_inserter(parameters_), - [](const parser::ConstantValueExpression &expr) -> common::ManagedPointer { - return common::ManagedPointer{expr.SqlValue()}; - }); - return *this; -} - OutputBuffer *ExecutionContext::OutputBufferNew() { if (output_schema_ == nullptr) { return nullptr; diff --git a/src/execution/exec/execution_context_builder.cpp b/src/execution/exec/execution_context_builder.cpp new file mode 100644 index 0000000000..318af2f4c5 --- /dev/null +++ b/src/execution/exec/execution_context_builder.cpp @@ -0,0 +1,58 @@ +#include "execution/exec/execution_context_builder.h" + +#include "common/error/error_code.h" +#include "common/error/exception.h" +#include "common/macros.h" +#include "execution/exec/execution_context.h" +#include "parser/expression/constant_value_expression.h" + +namespace noisepage::execution::exec { + +std::unique_ptr ExecutionContextBuilder::Build() { + if (db_oid_ == catalog::INVALID_DATABASE_OID) { + throw EXECUTION_EXCEPTION("Must specify database OID.", common::ErrorCode::ERRCODE_INTERNAL_ERROR); + } + if (!exec_settings_.has_value()) { + throw EXECUTION_EXCEPTION("Must specify exection settings.", common::ErrorCode::ERRCODE_INTERNAL_ERROR); + } + if (!txn_.has_value()) { + throw EXECUTION_EXCEPTION("Must specify a transaction context.", common::ErrorCode::ERRCODE_INTERNAL_ERROR); + } + if (!output_schema_.has_value()) { + throw EXECUTION_EXCEPTION("Must specify output schema.", common::ErrorCode::ERRCODE_INTERNAL_ERROR); + } + if (!output_callback_.has_value()) { + throw EXECUTION_EXCEPTION("Must specify output callback.", common::ErrorCode::ERRCODE_INTERNAL_ERROR); + } + if (!catalog_accessor_.has_value()) { + throw EXECUTION_EXCEPTION("Must specify catalog accessor.", common::ErrorCode::ERRCODE_INTERNAL_ERROR); + } + if (!metrics_manager_.has_value()) { + throw EXECUTION_EXCEPTION("Must specify metrics manager.", common::ErrorCode::ERRCODE_INTERNAL_ERROR); + } + if (!replication_manager_.has_value()) { + throw EXECUTION_EXCEPTION("Must specify replication manager.", common::ErrorCode::ERRCODE_INTERNAL_ERROR); + } + if (!recovery_manager_.has_value()) { + throw EXECUTION_EXCEPTION("Must specify recovery manager.", common::ErrorCode::ERRCODE_INTERNAL_ERROR); + } + + // Query parameters (parameters_) is not validated because default is empty collection + // ExecutionSettings exec_settings = exec_settings_.value(); + return std::unique_ptr{ + new ExecutionContext{db_oid_, std::move(parameters_), exec_settings_.value(), txn_.value(), + output_schema_.value(), std::move(output_callback_.value()), catalog_accessor_.value(), + metrics_manager_.value(), replication_manager_.value(), recovery_manager_.value()}}; +} + +ExecutionContextBuilder &ExecutionContextBuilder::WithQueryParametersFrom( + const std::vector ¶meter_exprs) { + NOISEPAGE_ASSERT(parameters_.empty(), "Attempt to initialize query parameters more than once."); + parameters_.reserve(parameter_exprs.size()); + std::transform(parameter_exprs.cbegin(), parameter_exprs.cend(), std::back_inserter(parameters_), + [](const parser::ConstantValueExpression &expr) -> common::ManagedPointer { + return common::ManagedPointer{expr.SqlValue()}; + }); + return *this; +} +} // namespace noisepage::execution::exec diff --git a/src/include/execution/exec/execution_context.h b/src/include/execution/exec/execution_context.h index 862c02eb19..8ea9869b82 100644 --- a/src/include/execution/exec/execution_context.h +++ b/src/include/execution/exec/execution_context.h @@ -465,149 +465,4 @@ class EXPORT ExecutionContext { /** The runtime parameter stack */ std::stack>> runtime_parameters_; }; - -/** - * The ExecutionContextBuilder class implements a builder for ExecutionContext. - */ -class ExecutionContextBuilder { - public: - /** - * Construct a new ExecutionContextBuilder. - */ - ExecutionContextBuilder() = default; - - /** @return The completed ExecutionContext instance */ - std::unique_ptr Build(); - - /** - * Set the query parameters for the execution context. - * @param parameters The query parameters - * @return Builder reference for chaining - */ - ExecutionContextBuilder &WithQueryParameters(std::vector> &¶meters) { - parameters_ = std::move(parameters); - return *this; - } - - /** - * Set the query parameters for the execution context. - * @param parameter_exprs The collection of expressions from which the query parameters are derived - * @return Builder reference for chaining - */ - ExecutionContextBuilder &WithQueryParametersFrom(const std::vector ¶meter_exprs); - - /** - * Set the database OID for the execution context. - * @param db_oid The database OID - * @return Builder reference for chaining - */ - ExecutionContextBuilder &WithDatabaseOID(const catalog::db_oid_t db_oid) { - db_oid_ = db_oid; - return *this; - } - - /** - * Set the transaction context for the execution context. - * @param txn The transaction context - * @return Builder reference for chaining - */ - ExecutionContextBuilder &WithTxnContext(common::ManagedPointer txn) { - txn_ = txn; - return *this; - } - - /** - * Set the output schema for the execution context. - * @param output_schema The output schema - * @return Builder reference for chaining - */ - ExecutionContextBuilder &WithOutputSchema(common::ManagedPointer output_schema) { - output_schema_ = output_schema; - return *this; - } - - /** - * Set the output callback for the execution context. - * @param output_callback The output callback - * @return Builder reference for chaining - */ - ExecutionContextBuilder &WithOutputCallback(OutputCallback output_callback) { - output_callback_.emplace(std::move(output_callback)); - return *this; - } - - /** - * Set the catalog accessor for the execution context. - * @param accessor The catalog accessor - * @return Builder reference for chaining - */ - ExecutionContextBuilder &WithCatalogAccessor(common::ManagedPointer accessor) { - catalog_accessor_ = accessor; - return *this; - } - - /** - * Set the execution settings for the execution context. - * @param exec_settings The execution settings - * @return Builder reference for chaining - */ - ExecutionContextBuilder &WithExecutionSettings(exec::ExecutionSettings exec_settings) { - exec_settings_.emplace(exec_settings); - return *this; - } - - /** - * Set the metrics manager for the execution context. - * @param metrics_manager The metrics manager - * @return Builder reference for chaining - */ - ExecutionContextBuilder &WithMetricsManager(common::ManagedPointer metrics_manager) { - metrics_manager_ = metrics_manager; - return *this; - } - - /** - * Set the replication manager for the execution context. - * @param replication_manager The replication manager - * @return Builder reference for chaining - */ - ExecutionContextBuilder &WithReplicationManager( - common::ManagedPointer replication_manager) { - replication_manager_ = replication_manager; - return *this; - } - - /** - * Set the recovery manager for the execution context. - * @param recovery_manager The recovery manager - * @return Builder reference for chaining - */ - ExecutionContextBuilder &WithRecoveryManager(common::ManagedPointer recovery_manager) { - recovery_manager_ = recovery_manager; - return *this; - } - - private: - /** The query execution settings */ - std::optional exec_settings_; - /** The query parmeters */ - std::vector> parameters_; - /** The database OID */ - catalog::db_oid_t db_oid_{catalog::INVALID_DATABASE_OID}; - /** The associated transaction */ - std::optional> txn_; - /** The output callback */ - std::optional output_callback_; - /** The output schema */ - std::optional> output_schema_{nullptr}; - /** The catalog accessor */ - std::optional> catalog_accessor_; - /** The metrics manager */ - std::optional> metrics_manager_; - /** The replication manager */ - std::optional> replication_manager_; - /** The recovery manager */ - std::optional> recovery_manager_; -}; - } // namespace noisepage::execution::exec diff --git a/src/include/execution/exec/execution_context_builder.h b/src/include/execution/exec/execution_context_builder.h new file mode 100644 index 0000000000..d522dd1510 --- /dev/null +++ b/src/include/execution/exec/execution_context_builder.h @@ -0,0 +1,193 @@ +#pragma once + +#include +#include +#include +#include + +#include "common/managed_pointer.h" +#include "execution/exec/execution_settings.h" +#include "execution/exec/output.h" + +namespace noisepage::parser { +class ConstantValueExpression; +} // namespace noisepage::parser + +namespace noisepage::planner { +class OutputSchema; +} // namespace noisepage::planner + +namespace noisepage::catalog { +class CatalogAccessor; +} // namespace noisepage::catalog + +namespace noisepage::metrics { +class MetricsManager; +} // namespace noisepage::metrics + +namespace noisepage::replication { +class ReplicationManager; +} // namespace noisepage::replication + +namespace noisepage::storage { +class RecoveryManager; +} // namespace noisepage::storage + +namespace noisepage::execution::sql { +struct Val; +} // namespace noisepage::execution::sql + +namespace noisepage::transaction { +class TransactionContext; +} // namespace noisepage::transaction + +namespace noisepage::execution::exec { + +class ExecutionContext; +class ExecutionSettings; + +/** + * The ExecutionContextBuilder class implements a builder for ExecutionContext. + */ +class ExecutionContextBuilder { + public: + /** + * Construct a new ExecutionContextBuilder. + */ + ExecutionContextBuilder() = default; + + /** @return The completed ExecutionContext instance */ + std::unique_ptr Build(); + + /** + * Set the query parameters for the execution context. + * @param parameters The query parameters + * @return Builder reference for chaining + */ + ExecutionContextBuilder &WithQueryParameters(std::vector> &¶meters) { + parameters_ = std::move(parameters); + return *this; + } + + /** + * Set the query parameters for the execution context. + * @param parameter_exprs The collection of expressions from which the query parameters are derived + * @return Builder reference for chaining + */ + ExecutionContextBuilder &WithQueryParametersFrom(const std::vector ¶meter_exprs); + + /** + * Set the database OID for the execution context. + * @param db_oid The database OID + * @return Builder reference for chaining + */ + ExecutionContextBuilder &WithDatabaseOID(const catalog::db_oid_t db_oid) { + db_oid_ = db_oid; + return *this; + } + + /** + * Set the transaction context for the execution context. + * @param txn The transaction context + * @return Builder reference for chaining + */ + ExecutionContextBuilder &WithTxnContext(common::ManagedPointer txn) { + txn_ = txn; + return *this; + } + + /** + * Set the output schema for the execution context. + * @param output_schema The output schema + * @return Builder reference for chaining + */ + ExecutionContextBuilder &WithOutputSchema(common::ManagedPointer output_schema) { + output_schema_ = output_schema; + return *this; + } + + /** + * Set the output callback for the execution context. + * @param output_callback The output callback + * @return Builder reference for chaining + */ + ExecutionContextBuilder &WithOutputCallback(OutputCallback output_callback) { + output_callback_.emplace(std::move(output_callback)); + return *this; + } + + /** + * Set the catalog accessor for the execution context. + * @param accessor The catalog accessor + * @return Builder reference for chaining + */ + ExecutionContextBuilder &WithCatalogAccessor(common::ManagedPointer accessor) { + catalog_accessor_ = accessor; + return *this; + } + + /** + * Set the execution settings for the execution context. + * @param exec_settings The execution settings + * @return Builder reference for chaining + */ + ExecutionContextBuilder &WithExecutionSettings(exec::ExecutionSettings exec_settings) { + exec_settings_.emplace(exec_settings); + return *this; + } + + /** + * Set the metrics manager for the execution context. + * @param metrics_manager The metrics manager + * @return Builder reference for chaining + */ + ExecutionContextBuilder &WithMetricsManager(common::ManagedPointer metrics_manager) { + metrics_manager_ = metrics_manager; + return *this; + } + + /** + * Set the replication manager for the execution context. + * @param replication_manager The replication manager + * @return Builder reference for chaining + */ + ExecutionContextBuilder &WithReplicationManager( + common::ManagedPointer replication_manager) { + replication_manager_ = replication_manager; + return *this; + } + + /** + * Set the recovery manager for the execution context. + * @param recovery_manager The recovery manager + * @return Builder reference for chaining + */ + ExecutionContextBuilder &WithRecoveryManager(common::ManagedPointer recovery_manager) { + recovery_manager_ = recovery_manager; + return *this; + } + + private: + /** The query execution settings */ + std::optional exec_settings_; + /** The query parmeters */ + std::vector> parameters_; + /** The database OID */ + catalog::db_oid_t db_oid_{catalog::INVALID_DATABASE_OID}; + /** The associated transaction */ + std::optional> txn_; + /** The output callback */ + std::optional output_callback_; + /** The output schema */ + std::optional> output_schema_; + /** The catalog accessor */ + std::optional> catalog_accessor_; + /** The metrics manager */ + std::optional> metrics_manager_; + /** The replication manager */ + std::optional> replication_manager_; + /** The recovery manager */ + std::optional> recovery_manager_; +}; + +} // namespace noisepage::execution::exec diff --git a/src/traffic_cop/traffic_cop.cpp b/src/traffic_cop/traffic_cop.cpp index 8f1f355abe..d2f36b2026 100644 --- a/src/traffic_cop/traffic_cop.cpp +++ b/src/traffic_cop/traffic_cop.cpp @@ -14,7 +14,7 @@ #include "common/error/exception.h" #include "common/thread_context.h" #include "execution/compiler/compilation_context.h" -#include "execution/exec/execution_context.h" +#include "execution/exec/execution_context_builder.h" #include "execution/exec/execution_settings.h" #include "execution/exec/output.h" #include "execution/sql/ddl_executors.h" diff --git a/src/util/query_exec_util.cpp b/src/util/query_exec_util.cpp index e61de1633e..aef55cbe94 100644 --- a/src/util/query_exec_util.cpp +++ b/src/util/query_exec_util.cpp @@ -8,7 +8,7 @@ #include "catalog/catalog_accessor.h" #include "execution/compiler/compilation_context.h" #include "execution/compiler/executable_query.h" -#include "execution/exec/execution_context.h" +#include "execution/exec/execution_context_builder.h" #include "execution/sql/ddl_executors.h" #include "execution/vm/execution_mode.h" #include "loggers/common_logger.h" diff --git a/test/execution/compiler_test.cpp b/test/execution/compiler_test.cpp index ad0e7e6de0..38b2dcc1e4 100644 --- a/test/execution/compiler_test.cpp +++ b/test/execution/compiler_test.cpp @@ -15,7 +15,7 @@ #include "execution/compiler/expression_maker.h" #include "execution/compiler/output_checker.h" #include "execution/compiler/output_schema_util.h" -#include "execution/exec/execution_context.h" +#include "execution/exec/execution_context_builder.h" #include "execution/exec/output.h" #include "execution/execution_util.h" #include "execution/sema/sema.h" diff --git a/test/execution/execution_context_builder_test.cpp b/test/execution/execution_context_builder_test.cpp new file mode 100644 index 0000000000..b39c71295d --- /dev/null +++ b/test/execution/execution_context_builder_test.cpp @@ -0,0 +1,169 @@ +#include "execution/exec/execution_context_builder.h" + +#include "execution/compiled_tpl_test.h" +#include "execution/exec/execution_context.h" + +/** A dummy from which we can constuct null ManagedPointers */ +#define DUMMY nullptr + +namespace noisepage::execution::test { + +class ExecutionContextBuilderTest : public TplTest { + /** The OID with which the database OID is initialized */ + constexpr static const uint32_t DB_OID = 15721; + + public: + ExecutionContextBuilderTest() : db_oid_{DB_OID}, output_callback_{[](byte *, uint32_t, uint32_t) {}} {} // NOLINT + + /** @return The dummy database OID */ + catalog::db_oid_t GetDatabaseOID() const { return db_oid_; } + + /** @return The dummy execution settings */ + const exec::ExecutionSettings &GetExecutionSettings() const { return execution_settings_; } + + /** @return The dummy output callback */ + const exec::OutputCallback &GetOutputCallback() const { return output_callback_; } + + private: + /** A dummy database OID */ + catalog::db_oid_t db_oid_; + /** A dummy ExecutionSettings instance */ + exec::ExecutionSettings execution_settings_{}; + /** A dummy output callback */ + const exec::OutputCallback output_callback_; +}; + +TEST_F(ExecutionContextBuilderTest, DoesNotThrowWithAllConfigurationSpecified) { + auto builder = exec::ExecutionContextBuilder() + .WithDatabaseOID(GetDatabaseOID()) + .WithTxnContext(DUMMY) + .WithExecutionSettings(GetExecutionSettings()) + .WithOutputSchema(DUMMY) + .WithOutputCallback(GetOutputCallback()) + .WithCatalogAccessor(DUMMY) + .WithMetricsManager(DUMMY) + .WithReplicationManager(DUMMY) + .WithRecoveryManager(DUMMY); + EXPECT_NO_THROW(builder.Build()); +} + +TEST_F(ExecutionContextBuilderTest, ThrowsOnMissingDatabaseOID) { + auto builder = exec::ExecutionContextBuilder() + .WithTxnContext(DUMMY) + .WithExecutionSettings(GetExecutionSettings()) + .WithOutputSchema(DUMMY) + .WithOutputCallback(GetOutputCallback()) + .WithCatalogAccessor(DUMMY) + .WithMetricsManager(DUMMY) + .WithReplicationManager(DUMMY) + .WithRecoveryManager(DUMMY); + EXPECT_THROW(builder.Build(), ExecutionException); +} + +TEST_F(ExecutionContextBuilderTest, ThrowsOnMissingTransactionContext) { + auto builder = exec::ExecutionContextBuilder() + .WithDatabaseOID(GetDatabaseOID()) + .WithExecutionSettings(GetExecutionSettings()) + .WithOutputSchema(DUMMY) + .WithOutputCallback(GetOutputCallback()) + .WithCatalogAccessor(DUMMY) + .WithMetricsManager(DUMMY) + .WithReplicationManager(DUMMY) + .WithRecoveryManager(DUMMY); + EXPECT_THROW(builder.Build(), ExecutionException); +} + +TEST_F(ExecutionContextBuilderTest, ThrowsOnMissingExecutionSettings) { + auto builder = exec::ExecutionContextBuilder() + .WithDatabaseOID(GetDatabaseOID()) + .WithTxnContext(DUMMY) + .WithOutputSchema(DUMMY) + .WithOutputCallback(GetOutputCallback()) + .WithCatalogAccessor(DUMMY) + .WithMetricsManager(DUMMY) + .WithReplicationManager(DUMMY) + .WithRecoveryManager(DUMMY); + EXPECT_THROW(builder.Build(), ExecutionException); +} + +TEST_F(ExecutionContextBuilderTest, ThrowsOnMissingOutputSchema) { + auto builder = exec::ExecutionContextBuilder() + .WithDatabaseOID(GetDatabaseOID()) + .WithTxnContext(DUMMY) + .WithExecutionSettings(GetExecutionSettings()) + .WithOutputCallback(GetOutputCallback()) + .WithCatalogAccessor(DUMMY) + .WithMetricsManager(DUMMY) + .WithReplicationManager(DUMMY) + .WithRecoveryManager(DUMMY); + EXPECT_THROW(builder.Build(), ExecutionException); +} + +TEST_F(ExecutionContextBuilderTest, ThrowsOnMissingOutputCallback) { + auto builder = exec::ExecutionContextBuilder() + .WithDatabaseOID(GetDatabaseOID()) + .WithTxnContext(DUMMY) + .WithExecutionSettings(GetExecutionSettings()) + .WithOutputSchema(DUMMY) + .WithCatalogAccessor(DUMMY) + .WithMetricsManager(DUMMY) + .WithReplicationManager(DUMMY) + .WithRecoveryManager(DUMMY); + EXPECT_THROW(builder.Build(), ExecutionException); +} + +TEST_F(ExecutionContextBuilderTest, ThrowsOnMissingCatalogAccessor) { + auto builder = exec::ExecutionContextBuilder() + .WithDatabaseOID(GetDatabaseOID()) + .WithTxnContext(DUMMY) + .WithExecutionSettings(GetExecutionSettings()) + .WithOutputSchema(DUMMY) + .WithOutputCallback(GetOutputCallback()) + .WithMetricsManager(DUMMY) + .WithReplicationManager(DUMMY) + .WithRecoveryManager(DUMMY); + EXPECT_THROW(builder.Build(), ExecutionException); +} + +TEST_F(ExecutionContextBuilderTest, ThrowsOnMissingMetricsManager) { + auto builder = exec::ExecutionContextBuilder() + .WithDatabaseOID(GetDatabaseOID()) + .WithTxnContext(DUMMY) + .WithExecutionSettings(GetExecutionSettings()) + .WithOutputSchema(DUMMY) + .WithOutputCallback(GetOutputCallback()) + .WithCatalogAccessor(DUMMY) + .WithReplicationManager(DUMMY) + .WithRecoveryManager(DUMMY); + EXPECT_THROW(builder.Build(), ExecutionException); +} + +TEST_F(ExecutionContextBuilderTest, ThrowsOnMissingReplicationManager) { + auto builder = exec::ExecutionContextBuilder() + .WithDatabaseOID(GetDatabaseOID()) + .WithTxnContext(DUMMY) + .WithExecutionSettings(GetExecutionSettings()) + .WithOutputSchema(DUMMY) + .WithOutputCallback(GetOutputCallback()) + .WithCatalogAccessor(DUMMY) + .WithMetricsManager(DUMMY) + .WithRecoveryManager(DUMMY); + EXPECT_THROW(builder.Build(), ExecutionException); +} + +TEST_F(ExecutionContextBuilderTest, ThrowsOnMissingRecoveryManager) { + auto builder = exec::ExecutionContextBuilder() + .WithDatabaseOID(GetDatabaseOID()) + .WithTxnContext(DUMMY) + .WithExecutionSettings(GetExecutionSettings()) + .WithOutputSchema(DUMMY) + .WithOutputCallback(GetOutputCallback()) + .WithCatalogAccessor(DUMMY) + .WithMetricsManager(DUMMY) + .WithReplicationManager(DUMMY); + EXPECT_THROW(builder.Build(), ExecutionException); +} + +#undef DUMMY + +} // namespace noisepage::execution::test diff --git a/test/execution/execution_context_test.cpp b/test/execution/execution_context_test.cpp deleted file mode 100644 index d764401229..0000000000 --- a/test/execution/execution_context_test.cpp +++ /dev/null @@ -1,11 +0,0 @@ -#include "execution/exec/execution_context.h" - -#include "execution/compiled_tpl_test.h" - -namespace noisepage::execution::test { - -class ExecutionContextTest : public TplTest {}; - -TEST_F(ExecutionContextTest, ItWorks) { EXPECT_TRUE(true); } - -} // namespace noisepage::execution::test diff --git a/test/include/execution/sql_test.h b/test/include/execution/sql_test.h index 988a38f80e..1928824416 100644 --- a/test/include/execution/sql_test.h +++ b/test/include/execution/sql_test.h @@ -5,6 +5,7 @@ #include #include "execution/exec/execution_context.h" +#include "execution/exec/execution_context_builder.h" #include "execution/exec/execution_settings.h" #include "execution/sql/sql.h" #include "execution/sql/vector.h" diff --git a/test/test_util/tpcc/workload_cached.cpp b/test/test_util/tpcc/workload_cached.cpp index 25150af99c..78caae385a 100644 --- a/test/test_util/tpcc/workload_cached.cpp +++ b/test/test_util/tpcc/workload_cached.cpp @@ -5,7 +5,7 @@ #include "binder/bind_node_visitor.h" #include "execution/compiler/executable_query.h" -#include "execution/exec/execution_context.h" +#include "execution/exec/execution_context_builder.h" #include "main/db_main.h" #include "optimizer/cost_model/trivial_cost_model.h" #include "parser/expression/derived_value_expression.h" diff --git a/test/test_util/tpch/workload.cpp b/test/test_util/tpch/workload.cpp index 41ca3483a0..1244c55e11 100644 --- a/test/test_util/tpch/workload.cpp +++ b/test/test_util/tpch/workload.cpp @@ -5,7 +5,7 @@ #include "common/managed_pointer.h" #include "execution/compiler/output_schema_util.h" -#include "execution/exec/execution_context.h" +#include "execution/exec/execution_context_builder.h" #include "execution/sql/value_util.h" #include "execution/table_generator/table_generator.h" #include "main/db_main.h" diff --git a/util/execution/tpl.cpp b/util/execution/tpl.cpp index 162d60f3e5..9b5392d59f 100644 --- a/util/execution/tpl.cpp +++ b/util/execution/tpl.cpp @@ -16,7 +16,7 @@ #include "execution/ast/ast_dump.h" #include "execution/ast/ast_pretty_print.h" -#include "execution/exec/execution_context.h" +#include "execution/exec/execution_context_builder.h" #include "execution/exec/execution_settings.h" #include "execution/parsing/parser.h" #include "execution/parsing/scanner.h" From c6cb4354cdc2122e57a54db41e3aae945862c1ec Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Fri, 2 Jul 2021 16:05:47 -0400 Subject: [PATCH 060/139] fix bad reference in query execution utility --- src/util/query_exec_util.cpp | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/util/query_exec_util.cpp b/src/util/query_exec_util.cpp index aef55cbe94..11ba59a5e7 100644 --- a/src/util/query_exec_util.cpp +++ b/src/util/query_exec_util.cpp @@ -281,6 +281,9 @@ bool QueryExecUtil::ExecuteQuery(const std::string &statement, TupleFunction tup execution::exec::OutputCallback callback = consumer; auto accessor = catalog_->GetAccessor(txn, db_oid_, DISABLED); + // TODO(Kyle): Making this copy is far from ideal... + const std::vector query_parameters = + static_cast(params) ? *params : std::vector{}; auto exec_ctx = execution::exec::ExecutionContextBuilder() .WithDatabaseOID(db_oid_) .WithExecutionSettings(exec_settings) @@ -291,7 +294,7 @@ bool QueryExecUtil::ExecuteQuery(const std::string &statement, TupleFunction tup .WithMetricsManager(metrics) .WithReplicationManager(DISABLED) .WithRecoveryManager(DISABLED) - .WithQueryParametersFrom(*params) + .WithQueryParametersFrom(query_parameters) .Build(); NOISEPAGE_ASSERT(!txn->MustAbort(), "Transaction should not be in must-abort state prior to executing"); From 5f4189a5af7d7a81e7e36d6bd07c47dc98dcc3f8 Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Thu, 8 Jul 2021 17:36:19 -0400 Subject: [PATCH 061/139] resolve memory corruption bug in TPL lambda code generation, still failing to produce expected results however --- sample_tpl/lambda0.tpl | 2 +- src/execution/vm/bytecode_generator.cpp | 42 +++-- src/execution/vm/bytecode_module.cpp | 5 +- src/execution/vm/llvm_engine.cpp | 24 +-- src/execution/vm/module.cpp | 10 +- src/include/execution/vm/bytecode_generator.h | 26 ++- src/include/execution/vm/bytecode_module.h | 19 +- util/execution/tpl.cpp | 164 +++++++++--------- 8 files changed, 166 insertions(+), 126 deletions(-) diff --git a/sample_tpl/lambda0.tpl b/sample_tpl/lambda0.tpl index c05499dc9e..9929041941 100644 --- a/sample_tpl/lambda0.tpl +++ b/sample_tpl/lambda0.tpl @@ -1,6 +1,6 @@ // Expected output: 2 -fun main(exec : *ExecutionContext) -> int32 { +fun main() -> int32 { // Lambda without capture var addOne = lambda [] (x: int32) -> int32 { return x + 1 diff --git a/src/execution/vm/bytecode_generator.cpp b/src/execution/vm/bytecode_generator.cpp index 52e095e215..b3dcd8e887 100644 --- a/src/execution/vm/bytecode_generator.cpp +++ b/src/execution/vm/bytecode_generator.cpp @@ -206,7 +206,7 @@ void BytecodeGenerator::VisitFunctionDecl(ast::FunctionDecl *node) { auto *func_type = node->TypeRepr()->GetType()->As(); // Allocate the function - auto *func_info = AllocateFunc(node->Name().GetData(), func_type); + auto *func_info = AllocateFunction(node->Name().GetData(), func_type); EnterFunction(func_info->GetId()); { @@ -218,7 +218,10 @@ void BytecodeGenerator::VisitFunctionDecl(ast::FunctionDecl *node) { Visit(node->Function()); } - // Execute the deferred actions for the function + // Execute the deferred actions for the function; + // in the current implementation, the only functionality + // that relies on deferred actions during code generation + // are TPL lambda expressions that generate closures for (auto &f : func_info->actions_) { f(); } @@ -232,10 +235,11 @@ void BytecodeGenerator::VisitLambdaExpr(ast::LambdaExpr *node) { if (!GetExecutionResult()->HasDestination()) { return; } + auto captures = GetCurrentFunction()->NewLocal(node->GetCaptureStructType(), node->GetName().GetString() + "captures"); auto fields = node->GetCaptureStructType()->As()->GetFieldsWithoutPadding(); - for (size_t i = 0; i < fields.size() - 1; i++) { + for (std::size_t i = 0; i < fields.size() - 1; i++) { auto field = fields[i]; ast::IdentifierExpr ident(node->Position(), field.name_); ident.SetType(field.type_->GetPointeeType()); @@ -248,10 +252,12 @@ void BytecodeGenerator::VisitLambdaExpr(ast::LambdaExpr *node) { } GetEmitter()->EmitAssign(Bytecode::Assign8, GetExecutionResult()->GetDestination(), captures.AddressOf()); - FunctionInfo *func_info = AllocateFunc(node->GetName().GetString(), func_type); + FunctionInfo *func_info = + AllocateFunction(node->GetName().GetString(), func_type, captures, node->GetCaptureStructType()); + (void)func_info; // Create a new deferred action for the current function - // that visits the body of the lambda; this actions is subsequently + // that visits the body of the lambda; this action is subsequently // executed when the function declaration itself is visited GetCurrentFunction()->DeferAction([=]() { func_info->captures_ = captures; @@ -4021,19 +4027,20 @@ void BytecodeGenerator::VisitMapTypeRepr(ast::MapTypeRepr *node) { NOISEPAGE_ASSERT(false, "Should not visit type-representation nodes!"); } -FunctionInfo *BytecodeGenerator::AllocateFunc(const std::string &func_name, ast::FunctionType *const func_type) { +FunctionInfo *BytecodeGenerator::AllocateFunction(const std::string &function_name, + ast::FunctionType *const function_type) { // Allocate function const auto func_id = static_cast(functions_.size()); - functions_.emplace_back(func_id, std::string(func_name), func_type); - FunctionInfo *func = &functions_.back(); + functions_.push_back(std::make_unique(func_id, function_name, function_type)); + FunctionInfo *func = functions_.back().get(); // Register return type - if (auto *return_type = func_type->GetReturnType(); !return_type->IsNilType()) { + if (auto *return_type = function_type->GetReturnType(); !return_type->IsNilType()) { func->NewParameterLocal(return_type->PointerTo(), "hiddenRv"); } // Register parameters - for (const auto ¶m : func_type->GetParams()) { + for (const auto ¶m : function_type->GetParams()) { if (param.type_->IsSqlValueType()) { func->NewParameterLocal(param.type_->PointerTo(), param.name_.GetData()); } else { @@ -4050,15 +4057,16 @@ FunctionInfo *BytecodeGenerator::AllocateFunc(const std::string &func_name, ast: return func; } -FunctionInfo *BytecodeGenerator::AllocateFunc(const std::string &func_name, ast::FunctionType *func_type, - LocalVar captures, ast::Type *capture_type) { +FunctionInfo *BytecodeGenerator::AllocateFunction(const std::string &function_name, + ast::FunctionType *const function_type, LocalVar captures, + ast::Type *capture_type) { // Allocate function const auto func_id = static_cast(functions_.size()); - functions_.emplace_back(func_id, func_name, func_type); - FunctionInfo *func = &functions_.back(); + functions_.push_back(std::make_unique(func_id, function_name, function_type)); + FunctionInfo *func = functions_.back().get(); // Register return type - if (auto *return_type = func_type->GetReturnType(); !return_type->IsNilType()) { + if (auto *return_type = function_type->GetReturnType(); !return_type->IsNilType()) { func->NewParameterLocal(return_type->PointerTo(), "hiddenRv"); } @@ -4066,7 +4074,8 @@ FunctionInfo *BytecodeGenerator::AllocateFunc(const std::string &func_name, ast: func->NewParameterLocal(capture_type->PointerTo(), "hiddenCaptures"); // Register parameters - for (const auto ¶m : func_type->GetParams()) { + for (const auto ¶m : function_type->GetParams()) { + // TODO(Kyle): Why do we never check for SQL value types here? func->NewParameterLocal(param.type_, param.name_.GetData()); } @@ -4184,6 +4193,7 @@ std::unique_ptr BytecodeGenerator::Compile(ast::AstNode *root, c return std::make_unique(name, std::move(generator.code_), std::move(generator.data_), std::move(generator.functions_), std::move(generator.static_locals_)); } + void BytecodeGenerator::VisitBuiltinCteScanCall(ast::CallExpr *call, ast::Builtin builtin) { LocalVar iterator = VisitExpressionForRValue(call->Arguments()[0]); switch (builtin) { diff --git a/src/execution/vm/bytecode_module.cpp b/src/execution/vm/bytecode_module.cpp index 8a97ed02b5..c307df5fd8 100644 --- a/src/execution/vm/bytecode_module.cpp +++ b/src/execution/vm/bytecode_module.cpp @@ -13,7 +13,8 @@ namespace noisepage::execution::vm { BytecodeModule::BytecodeModule(std::string name, std::vector &&code, std::vector &&data, - std::vector &&functions, std::vector &&static_locals) + std::vector> &&functions, + std::vector &&static_locals) : name_(std::move(name)), code_(std::move(code)), data_(std::move(data)), @@ -209,7 +210,7 @@ void BytecodeModule::Dump(std::ostream &os) const { // Functions for (const auto &func : functions_) { - PrettyPrintFunc(os, *this, func); + PrettyPrintFunc(os, *this, *func); } } diff --git a/src/execution/vm/llvm_engine.cpp b/src/execution/vm/llvm_engine.cpp index ca324af295..9803819c6a 100644 --- a/src/execution/vm/llvm_engine.cpp +++ b/src/execution/vm/llvm_engine.cpp @@ -201,7 +201,8 @@ llvm::Type *LLVMEngine::TypeMap::GetLLVMType(const ast::Type *type) { break; } case ast::Type::TypeId::MapType: { - // TODO(pmenon): me + // TODO(Kyle): Implement this + throw NOT_IMPLEMENTED_EXCEPTION("MapType Not Implemented"); break; } case ast::Type::TypeId::StructType: { @@ -213,6 +214,7 @@ llvm::Type *LLVMEngine::TypeMap::GetLLVMType(const ast::Type *type) { break; } case ast::Type::TypeId::LambdaType: { + // TODO(Kyle): Implement this throw NOT_IMPLEMENTED_EXCEPTION("LambdaType Not Implemented"); break; } @@ -227,9 +229,7 @@ llvm::Type *LLVMEngine::TypeMap::GetLLVMType(const ast::Type *type) { // NOISEPAGE_ASSERT(llvm_type != nullptr, "No LLVM type found!"); - iter->second = llvm_type; - return llvm_type; } @@ -569,9 +569,9 @@ void LLVMEngine::CompiledModuleBuilder::DeclareStaticLocals() { } void LLVMEngine::CompiledModuleBuilder::DeclareFunctions() { - for (const auto &func_info : tpl_module_.GetFunctionsInfo()) { - auto *func_type = llvm::cast(type_map_->GetLLVMType(func_info.GetFuncType())); - llvm_module_->getOrInsertFunction(func_info.GetName(), func_type); + for (const auto *func_info : tpl_module_.GetFunctionsInfo()) { + auto *func_type = llvm::cast(type_map_->GetLLVMType(func_info->GetFuncType())); + llvm_module_->getOrInsertFunction(func_info->GetName(), func_type); } } @@ -962,8 +962,8 @@ void LLVMEngine::CompiledModuleBuilder::DefineFunction(const FunctionInfo &func_ void LLVMEngine::CompiledModuleBuilder::DefineFunctions() { llvm::IRBuilder<> ir_builder(*context_); - for (const auto &func_info : tpl_module_.GetFunctionsInfo()) { - DefineFunction(func_info, &ir_builder); + for (const auto *func_info : tpl_module_.GetFunctionsInfo()) { + DefineFunction(*func_info, &ir_builder); } } @@ -1166,13 +1166,13 @@ void LLVMEngine::CompiledModule::Load(const BytecodeModule &module) { // all module functions into a handy cache. // - for (const auto &func : module.GetFunctionsInfo()) { - auto symbol = loader.getSymbol(func.GetName()); + for (const auto *func : module.GetFunctionsInfo()) { + auto symbol = loader.getSymbol(func->GetName()); if (symbol.getAddress() == 0) { // for Mac portability - symbol = loader.getSymbol("_" + func.GetName()); + symbol = loader.getSymbol("_" + func->GetName()); } - functions_[func.GetName()] = reinterpret_cast(symbol.getAddress()); + functions_[func->GetName()] = reinterpret_cast(symbol.getAddress()); NOISEPAGE_ASSERT(symbol.getAddress() != 0, "symbol came out to be badly defined or missing"); } diff --git a/src/execution/vm/module.cpp b/src/execution/vm/module.cpp index 5152a1e4ff..e8c5c8fa6e 100644 --- a/src/execution/vm/module.cpp +++ b/src/execution/vm/module.cpp @@ -51,8 +51,8 @@ Module::Module(std::unique_ptr bytecode_module, std::unique_ptr< bytecode_trampolines_(std::make_unique(bytecode_module_->GetFunctionCount())), metadata_(std::move(metadata)) { // Create the trampolines for all bytecode functions - for (const auto &func : bytecode_module_->GetFunctionsInfo()) { - CreateFunctionTrampoline(func.GetId()); + for (const auto *func : bytecode_module_->GetFunctionsInfo()) { + CreateFunctionTrampoline(func->GetId()); } // If a compiled module wasn't provided, all internal function stubs point to @@ -280,10 +280,10 @@ void Module::CompileToMachineCode() { // JIT completed successfully. For each function in the module, pull out its // compiled implementation into the function cache, atomically replacing any // previous implementation. - for (const auto &func_info : bytecode_module_->GetFunctionsInfo()) { - auto *jit_function = jit_module_->GetFunctionPointer(func_info.GetName()); + for (const auto *func_info : bytecode_module_->GetFunctionsInfo()) { + auto *jit_function = jit_module_->GetFunctionPointer(func_info->GetName()); NOISEPAGE_ASSERT(jit_function != nullptr, "Function not found!"); - functions_[func_info.GetId()].store(jit_function, std::memory_order_relaxed); + functions_[func_info->GetId()].store(jit_function, std::memory_order_relaxed); } }); } diff --git a/src/include/execution/vm/bytecode_generator.h b/src/include/execution/vm/bytecode_generator.h index 8e4ab8b41c..d46acf48fa 100644 --- a/src/include/execution/vm/bytecode_generator.h +++ b/src/include/execution/vm/bytecode_generator.h @@ -67,12 +67,24 @@ class BytecodeGenerator final : public ast::AstVisitor { class RValueResultScope; class BytecodePositionScope; - // Allocate a new function ID - FunctionInfo *AllocateFunc(const std::string &func_name, ast::FunctionType *func_type); + /** + * Allocate a new function. + * @param function_name The function name + * @param function_type The function type + * @return A non-owning pointer to the allocated function + */ + FunctionInfo *AllocateFunction(const std::string &function_name, ast::FunctionType *function_type); - // Allocate a new function ID with captures. - FunctionInfo *AllocateFunc(const std::string &func_name, ast::FunctionType *func_type, LocalVar captures, - ast::Type *capture_type); + /** + * Allocate a new function with captures (for lambda expressions). + * @param function_name The function name + * @param function_type The function type + * @param captures The local variable for the captures structure + * @param capture_type The type of the captures structure + * @return A non-owning pointer to the allocated function + */ + FunctionInfo *AllocateFunction(const std::string &function_name, ast::FunctionType *function_type, LocalVar captures, + ast::Type *capture_type); void VisitAbortTxn(ast::CallExpr *call); @@ -194,7 +206,7 @@ class BytecodeGenerator final : public ast::AstVisitor { void SetExecutionResult(ExpressionResultScope *exec_result) { execution_result_ = exec_result; } // Access the current function that's being generated. May be NULL. - FunctionInfo *GetCurrentFunction() { return &functions_[current_fn_]; } + FunctionInfo *GetCurrentFunction() { return functions_[current_fn_].get(); } void EnterFunction(FunctionId id) { current_fn_ = id; } @@ -211,7 +223,7 @@ class BytecodeGenerator final : public ast::AstVisitor { std::unordered_map static_string_cache_; // Information about all generated functions - std::vector functions_; + std::vector> functions_; // The ID of the current function. FunctionId current_fn_{0}; diff --git a/src/include/execution/vm/bytecode_module.h b/src/include/execution/vm/bytecode_module.h index 3689dc825f..c1d38cad77 100644 --- a/src/include/execution/vm/bytecode_module.h +++ b/src/include/execution/vm/bytecode_module.h @@ -2,6 +2,7 @@ #include #include +#include #include #include @@ -28,7 +29,7 @@ class BytecodeModule { * @param static_locals All statically allocated variables in the data section. */ BytecodeModule(std::string name, std::vector &&code, std::vector &&data, - std::vector &&functions, std::vector &&static_locals); + std::vector> &&functions, std::vector &&static_locals); /** * This class cannot be copied or moved. @@ -42,7 +43,7 @@ class BytecodeModule { const FunctionInfo *GetFuncInfoById(const FunctionId func_id) const { // Function IDs are dense, so the given ID must be in the range [0, # functions) NOISEPAGE_ASSERT(func_id < GetFunctionCount(), "Invalid function"); - return &functions_[func_id]; + return functions_[func_id].get(); } /** @@ -50,7 +51,8 @@ class BytecodeModule { * no such function exists, a NULL pointer is returned. */ const FunctionInfo *LookupFuncInfoByName(const std::string &name) const { - for (const FunctionInfo &info : functions_) { + for (const auto &function : functions_) { + const FunctionInfo &info = *function; if (info.GetName() == name) { return &info; } @@ -92,7 +94,14 @@ class BytecodeModule { /** * @return A const-view of the metadata for all functions in this module. */ - const std::vector &GetFunctionsInfo() const { return functions_; } + std::vector GetFunctionsInfo() const { + // TODO(Kyle): Cache these results? + std::vector functions{}; + functions.reserve(functions_.size()); + std::transform(functions_.cbegin(), functions_.cend(), std::back_inserter(functions), + [](const std::unique_ptr &f) { return f.get(); }); + return functions; + } /** * @return A const-view of the metadata for all static-locals in this module. @@ -156,7 +165,7 @@ class BytecodeModule { // The raw static data for ALL static data stored contiguously. const std::vector data_; // Metadata for all functions. - const std::vector functions_; + const std::vector> functions_; // Metadata for all static data. const std::vector static_locals_; }; diff --git a/util/execution/tpl.cpp b/util/execution/tpl.cpp index 9b5392d59f..fb3e22ca77 100644 --- a/util/execution/tpl.cpp +++ b/util/execution/tpl.cpp @@ -47,8 +47,17 @@ // CLI options // --------------------------------------------------------- +/** Enumeration for requested execution modes */ +enum ExecuteOn { VM, JIT, ADAPTIVE, ALL }; + // clang-format off llvm::cl::OptionCategory TPL_OPTIONS_CATEGORY("TPL Compiler Options", "Options for controlling the TPL compilation process."); // NOLINT +llvm::cl::opt EXECUTE_ON("execute-on", llvm::cl::desc("The execution mode"), llvm::cl::values( // NOLINT + clEnumVal(VM, ""), + clEnumVal(JIT, ""), + clEnumVal(ADAPTIVE, ""), + clEnumVal(ALL, "") +), llvm::cl::init(ALL), llvm::cl::cat(TPL_OPTIONS_CATEGORY)); llvm::cl::opt PRINT_AST("print-ast", llvm::cl::desc("Print the programs AST"), llvm::cl::cat(TPL_OPTIONS_CATEGORY)); // NOLINT llvm::cl::opt PRINT_TBC("print-tbc", llvm::cl::desc("Print the generated TPL Bytecode"), llvm::cl::cat(TPL_OPTIONS_CATEGORY)); // NOLINT llvm::cl::opt PRETTY_PRINT("pretty-print", llvm::cl::desc("Pretty-print the source from the parsed AST"), llvm::cl::cat(TPL_OPTIONS_CATEGORY)); // NOLINT @@ -66,6 +75,67 @@ namespace noisepage::execution { static constexpr const char *K_EXIT_KEYWORD = ".exit"; +/** + * + */ +static bool ShouldExecuteInMode(vm::ExecutionMode mode) { + auto mode_requested = [mode]() -> bool { + switch (mode) { + case vm::ExecutionMode::Interpret: + return EXECUTE_ON == VM; + case vm::ExecutionMode::Compiled: + return EXECUTE_ON == JIT; + case vm::ExecutionMode::Adaptive: + return EXECUTE_ON == ADAPTIVE; + default: + return false; + } + }; + return EXECUTE_ON == ALL || mode_requested(); +} + +/** + * Execute + */ +static double ExecuteInMode(vm::Module *module, vm::ExecutionMode mode, exec::ExecutionContext *exec_ctx) { + const char *mode_identifier = [mode]() { + switch (mode) { + case vm::ExecutionMode::Interpret: + return "VM"; + case vm::ExecutionMode::Compiled: + return "JIT"; + case vm::ExecutionMode::Adaptive: + return "ADAPTIVE"; + default: + UNREACHABLE("Unknown Execution Mode"); + } + }(); + + double exec_ms{}; + exec_ctx->SetExecutionMode(mode); + { + util::ScopedTimer timer(&exec_ms); + + if (IS_SQL) { + std::function main; + if (!module->GetFunction("main", mode, &main)) { + EXECUTION_LOG_ERROR("Missing 'main' entry function with signature (*ExecutionContext) - >int32"); + return 0.0; + } + EXECUTION_LOG_INFO("{} main() returned: {}", mode_identifier, main(exec_ctx)); + } else { + std::function main; + if (!module->GetFunction("main", mode, &main)) { + EXECUTION_LOG_ERROR("Missing 'main' entry function with signature () -> int32"); + return 0.0; + } + EXECUTION_LOG_INFO("{} main() returned: {}", mode_identifier, main()); + } + } + + return exec_ms; +} + /** * Compile TPL source code contained in @em source and execute it in all execution modes once. * @param source The TPL source. @@ -132,12 +202,9 @@ static void CompileAndRun(const std::string &source, const std::string &name = " parsing::Scanner scanner(source.data(), source.length()); parsing::Parser parser(&scanner, &context); - double parse_ms = 0.0, // Time to parse the source - typecheck_ms = 0.0, // Time to perform semantic analysis - codegen_ms = 0.0, // Time to generate TBC - interp_exec_ms = 0.0, // Time to execute the program in fully interpreted mode - adaptive_exec_ms = 0.0, // Time to execute the program in adaptive mode - jit_exec_ms = 0.0; // Time to execute the program in JIT excluding compilation time + double parse_ms = 0.0; // Time to parse the source + double typecheck_ms = 0.0; // Time to perform semantic analysis + double codegen_ms = 0.0; // Time to generate TBC // // Parse @@ -199,86 +266,27 @@ static void CompileAndRun(const std::string &source, const std::string &name = " auto module = std::make_unique(std::move(bytecode_module), std::move(module_metadata)); // - // Interpret + // Execution // - { - exec_ctx->SetExecutionMode(vm::ExecutionMode::Interpret); - util::ScopedTimer timer(&interp_exec_ms); - - if (IS_SQL) { - std::function main; - if (!module->GetFunction("main", vm::ExecutionMode::Interpret, &main)) { - EXECUTION_LOG_ERROR("Missing 'main' entry function with signature (*ExecutionContext)->int32"); - return; - } - EXECUTION_LOG_INFO("VM main() returned: {}", main(exec_ctx.get())); - } else { - std::function main; - if (!module->GetFunction("main", vm::ExecutionMode::Interpret, &main)) { - EXECUTION_LOG_ERROR("Missing 'main' entry function with signature ()->int32"); - return; - } - EXECUTION_LOG_INFO("VM main() returned: {}", main()); - } - } - - // - // Adaptive - // - - exec_ctx->SetExecutionMode(vm::ExecutionMode::Adaptive); - util::ScopedTimer timer(&adaptive_exec_ms); - - if (IS_SQL) { - std::function main; - if (!module->GetFunction("main", vm::ExecutionMode::Adaptive, &main)) { - EXECUTION_LOG_ERROR("Missing 'main' entry function with signature (*ExecutionContext)->int32"); - return; - } - EXECUTION_LOG_INFO("ADAPTIVE main() returned: {}", main(exec_ctx.get())); - } else { - std::function main; - if (!module->GetFunction("main", vm::ExecutionMode::Adaptive, &main)) { - EXECUTION_LOG_ERROR("Missing 'main' entry function with signature ()->int32"); - return; - } - EXECUTION_LOG_INFO("ADAPTIVE main() returned: {}", main()); - } + const double vm_ms = ShouldExecuteInMode(vm::ExecutionMode::Interpret) + ? ExecuteInMode(module.get(), vm::ExecutionMode::Interpret, exec_ctx.get()) + : 0.0; + const double jit_ms = ShouldExecuteInMode(vm::ExecutionMode::Compiled) + ? ExecuteInMode(module.get(), vm::ExecutionMode::Compiled, exec_ctx.get()) + : 0.0; + const double adaptive_ms = ShouldExecuteInMode(vm::ExecutionMode::Adaptive) + ? ExecuteInMode(module.get(), vm::ExecutionMode::Adaptive, exec_ctx.get()) + : 0.0; // - // JIT + // Dump stats // - { - exec_ctx->SetExecutionMode(vm::ExecutionMode::Compiled); - util::ScopedTimer timer(&jit_exec_ms); - - if (IS_SQL) { - std::function main; - if (!module->GetFunction("main", vm::ExecutionMode::Compiled, &main)) { - EXECUTION_LOG_ERROR("Missing 'main' entry function with signature (*ExecutionContext)->int32"); - return; - } - util::Timer x; - x.Start(); - EXECUTION_LOG_INFO("JIT main() returned: {}", main(exec_ctx.get())); - x.Stop(); - EXECUTION_LOG_INFO("Jit exec: {} ms", x.GetElapsed()); - } else { - std::function main; - if (!module->GetFunction("main", vm::ExecutionMode::Compiled, &main)) { - EXECUTION_LOG_ERROR("Missing 'main' entry function with signature ()->int32"); - return; - } - EXECUTION_LOG_INFO("JIT main() returned: {}", main()); - } - } - // Dump stats EXECUTION_LOG_INFO( "Parse: {} ms, Type-check: {} ms, Code-gen: {} ms, Interp. Exec.: {} ms, " - "Adaptive Exec.: {} ms, Jit+Exec.: {} ms", - parse_ms, typecheck_ms, codegen_ms, interp_exec_ms, adaptive_exec_ms, jit_exec_ms); + "JIT Exec.: {} ms, Adaptive Exec.: {} ms", + parse_ms, typecheck_ms, codegen_ms, vm_ms, jit_ms, adaptive_ms); txn_manager->Commit(txn, transaction::TransactionUtil::EmptyCallback, nullptr); } From 2bb9944b774b2625f3db4a81a5b9439752a2a7e9 Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Fri, 9 Jul 2021 11:22:07 -0400 Subject: [PATCH 062/139] update system_functions_test with new ExecutionContext API --- test/execution/system_functions_test.cpp | 28 +++++++++++++++++++----- 1 file changed, 23 insertions(+), 5 deletions(-) diff --git a/test/execution/system_functions_test.cpp b/test/execution/system_functions_test.cpp index 1cb6dbbba2..07c7c38bf7 100644 --- a/test/execution/system_functions_test.cpp +++ b/test/execution/system_functions_test.cpp @@ -4,6 +4,7 @@ #include "common/version.h" #include "execution/exec/execution_context.h" +#include "execution/exec/execution_context_builder.h" #include "execution/exec/execution_settings.h" #include "execution/sql/value.h" #include "execution/tpl_test.h" @@ -12,14 +13,31 @@ namespace noisepage::execution::sql::test { class SystemFunctionsTests : public TplTest { public: - SystemFunctionsTests() - : ctx_(catalog::db_oid_t(0), nullptr, nullptr, nullptr, nullptr, settings_, nullptr, DISABLED, DISABLED) {} - - exec::ExecutionContext *Ctx() { return &ctx_; } + SystemFunctionsTests() { + ctx_ = exec::ExecutionContextBuilder() + .WithDatabaseOID(DATABASE_OID) + .WithTxnContext(nullptr) + .WithExecutionSettings(settings_) + .WithOutputSchema(nullptr) + .WithOutputCallback(nullptr) + .WithCatalogAccessor(nullptr) + .WithMetricsManager(DISABLED) + .WithReplicationManager(DISABLED) + .WithRecoveryManager(DISABLED) + .Build(); + } + + /** @return A non-owning pointer to the execution context */ + exec::ExecutionContext *Ctx() { return ctx_.get(); } private: + /** Dummy database OID */ + constexpr static catalog::db_oid_t DATABASE_OID{15721}; + + /** The execution settings for the test */ exec::ExecutionSettings settings_{}; - exec::ExecutionContext ctx_; + /** The execution context for the test */ + std::unique_ptr ctx_; }; // NOLINTNEXTLINE From 0fa39b0892c08b59e31fd7b4fc040968a25c1b71 Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Fri, 9 Jul 2021 15:11:42 -0400 Subject: [PATCH 063/139] refactor index_nested_loops_join_test to account for updated ExecutionContext API --- sample_tpl/lambda2.tpl | 9 ++ .../index_nested_loops_join_test.cpp | 93 +++++++++++++++---- 2 files changed, 84 insertions(+), 18 deletions(-) create mode 100644 sample_tpl/lambda2.tpl diff --git a/sample_tpl/lambda2.tpl b/sample_tpl/lambda2.tpl new file mode 100644 index 0000000000..3446cc80d5 --- /dev/null +++ b/sample_tpl/lambda2.tpl @@ -0,0 +1,9 @@ +// Expected output: 2 + +fun addOne(x: int32) -> int32 { + return x + 1 +} + +fun main() -> int32 { + return addOne(1) +} diff --git a/test/optimizer/index_nested_loops_join_test.cpp b/test/optimizer/index_nested_loops_join_test.cpp index 8090ac8ca8..c29e00be2e 100644 --- a/test/optimizer/index_nested_loops_join_test.cpp +++ b/test/optimizer/index_nested_loops_join_test.cpp @@ -8,6 +8,7 @@ #include "execution/compiler/executable_query.h" #include "execution/compiler/output_checker.h" #include "execution/exec/execution_context.h" +#include "execution/exec/execution_context_builder.h" #include "execution/exec/execution_settings.h" #include "execution/sql/value.h" #include "execution/vm/module.h" @@ -125,11 +126,17 @@ struct IdxJoinTest : public TerrierTest { void TearDown() override { TerrierTest::TearDown(); } + /** The connection context */ network::ConnectionContext context_; + /** The catalog instance */ common::ManagedPointer catalog_; + /** The transaction manager instance */ common::ManagedPointer txn_manager_; + /** The traffic cop instance */ common::ManagedPointer tcop_; + /** The database instance */ std::unique_ptr db_main_; + /** The database OID */ catalog::db_oid_t db_oid_; }; @@ -203,9 +210,18 @@ TEST_F(IdxJoinTest, SimpleIdxJoinTest) { execution::exec::ExecutionSettings exec_settings{}; exec_settings.is_parallel_execution_enabled_ = false; execution::exec::OutputCallback callback_fn = callback.ConstructOutputCallback(); - auto exec_ctx = std::make_unique( - db_oid_, common::ManagedPointer(txn), callback_fn, out_plan->GetOutputSchema().Get(), - common::ManagedPointer(accessor), exec_settings, db_main_->GetMetricsManager(), DISABLED, DISABLED); + + auto exec_ctx = execution::exec::ExecutionContextBuilder() + .WithDatabaseOID(db_oid_) + .WithTxnContext(common::ManagedPointer{txn}) + .WithExecutionSettings(exec_settings) + .WithOutputSchema(common::ManagedPointer{out_plan->GetOutputSchema().Get()}) + .WithOutputCallback(callback_fn) + .WithCatalogAccessor(common::ManagedPointer{accessor}) + .WithMetricsManager(db_main_->GetMetricsManager()) + .WithReplicationManager(DISABLED) + .WithRecoveryManager(DISABLED) + .Build(); // Run & Check auto executable = execution::compiler::CompilationContext::Compile(*out_plan, exec_ctx->GetExecutionSettings(), @@ -326,9 +342,18 @@ TEST_F(IdxJoinTest, MultiPredicateJoin) { execution::exec::ExecutionSettings exec_settings{}; exec_settings.is_parallel_execution_enabled_ = false; execution::exec::OutputCallback callback_fn = callback.ConstructOutputCallback(); - auto exec_ctx = std::make_unique( - db_oid_, common::ManagedPointer(txn), callback_fn, out_plan->GetOutputSchema().Get(), - common::ManagedPointer(accessor), exec_settings, db_main_->GetMetricsManager(), DISABLED, DISABLED); + + auto exec_ctx = execution::exec::ExecutionContextBuilder() + .WithDatabaseOID(db_oid_) + .WithTxnContext(common::ManagedPointer{txn}) + .WithExecutionSettings(exec_settings) + .WithOutputSchema(common::ManagedPointer{out_plan->GetOutputSchema().Get()}) + .WithOutputCallback(callback_fn) + .WithCatalogAccessor(common::ManagedPointer{accessor}) + .WithMetricsManager(db_main_->GetMetricsManager()) + .WithReplicationManager(DISABLED) + .WithRecoveryManager(DISABLED) + .Build(); // Run & Check auto executable = execution::compiler::CompilationContext::Compile(*out_plan, exec_ctx->GetExecutionSettings(), @@ -409,9 +434,17 @@ TEST_F(IdxJoinTest, MultiPredicateJoinWithExtra) { execution::exec::ExecutionSettings exec_settings{}; exec_settings.is_parallel_execution_enabled_ = false; execution::exec::OutputCallback callback_fn = callback.ConstructOutputCallback(); - auto exec_ctx = std::make_unique( - db_oid_, common::ManagedPointer(txn), callback_fn, out_plan->GetOutputSchema().Get(), - common::ManagedPointer(accessor), exec_settings, db_main_->GetMetricsManager(), DISABLED, DISABLED); + auto exec_ctx = execution::exec::ExecutionContextBuilder() + .WithDatabaseOID(db_oid_) + .WithTxnContext(common::ManagedPointer{txn}) + .WithExecutionSettings(exec_settings) + .WithOutputSchema(common::ManagedPointer{out_plan->GetOutputSchema().Get()}) + .WithOutputCallback(callback_fn) + .WithCatalogAccessor(common::ManagedPointer{accessor}) + .WithMetricsManager(db_main_->GetMetricsManager()) + .WithReplicationManager(DISABLED) + .WithRecoveryManager(DISABLED) + .Build(); // Run & Check auto executable = execution::compiler::CompilationContext::Compile(*out_plan, exec_ctx->GetExecutionSettings(), @@ -478,9 +511,17 @@ TEST_F(IdxJoinTest, FooOnlyScan) { execution::exec::ExecutionSettings exec_settings{}; exec_settings.is_parallel_execution_enabled_ = false; execution::exec::OutputCallback callback_fn = callback.ConstructOutputCallback(); - auto exec_ctx = std::make_unique( - db_oid_, common::ManagedPointer(txn), callback_fn, out_plan->GetOutputSchema().Get(), - common::ManagedPointer(accessor), exec_settings, db_main_->GetMetricsManager(), DISABLED, DISABLED); + auto exec_ctx = execution::exec::ExecutionContextBuilder() + .WithDatabaseOID(db_oid_) + .WithTxnContext(common::ManagedPointer{txn}) + .WithExecutionSettings(exec_settings) + .WithOutputSchema(common::ManagedPointer{out_plan->GetOutputSchema().Get()}) + .WithOutputCallback(callback_fn) + .WithCatalogAccessor(common::ManagedPointer{accessor}) + .WithMetricsManager(db_main_->GetMetricsManager()) + .WithReplicationManager(DISABLED) + .WithRecoveryManager(DISABLED) + .Build(); // Run & Check auto executable = execution::compiler::CompilationContext::Compile(*out_plan, exec_ctx->GetExecutionSettings(), @@ -547,9 +588,17 @@ TEST_F(IdxJoinTest, BarOnlyScan) { execution::exec::ExecutionSettings exec_settings{}; exec_settings.is_parallel_execution_enabled_ = false; execution::exec::OutputCallback callback_fn = callback.ConstructOutputCallback(); - auto exec_ctx = std::make_unique( - db_oid_, common::ManagedPointer(txn), callback_fn, out_plan->GetOutputSchema().Get(), - common::ManagedPointer(accessor), exec_settings, db_main_->GetMetricsManager(), DISABLED, DISABLED); + auto exec_ctx = execution::exec::ExecutionContextBuilder() + .WithDatabaseOID(db_oid_) + .WithTxnContext(common::ManagedPointer{txn}) + .WithExecutionSettings(exec_settings) + .WithOutputSchema(common::ManagedPointer{out_plan->GetOutputSchema().Get()}) + .WithOutputCallback(callback_fn) + .WithCatalogAccessor(common::ManagedPointer{accessor}) + .WithMetricsManager(db_main_->GetMetricsManager()) + .WithReplicationManager(DISABLED) + .WithRecoveryManager(DISABLED) + .Build(); // Run & Check auto executable = execution::compiler::CompilationContext::Compile(*out_plan, exec_ctx->GetExecutionSettings(), @@ -629,9 +678,17 @@ TEST_F(IdxJoinTest, IndexToIndexJoin) { execution::exec::ExecutionSettings exec_settings{}; exec_settings.is_parallel_execution_enabled_ = false; execution::exec::OutputCallback callback_fn = callback.ConstructOutputCallback(); - auto exec_ctx = std::make_unique( - db_oid_, common::ManagedPointer(txn), callback_fn, out_plan->GetOutputSchema().Get(), - common::ManagedPointer(accessor), exec_settings, db_main_->GetMetricsManager(), DISABLED, DISABLED); + auto exec_ctx = execution::exec::ExecutionContextBuilder() + .WithDatabaseOID(db_oid_) + .WithTxnContext(common::ManagedPointer{txn}) + .WithExecutionSettings(exec_settings) + .WithOutputSchema(common::ManagedPointer{out_plan->GetOutputSchema().Get()}) + .WithOutputCallback(callback_fn) + .WithCatalogAccessor(common::ManagedPointer{accessor}) + .WithMetricsManager(db_main_->GetMetricsManager()) + .WithReplicationManager(DISABLED) + .WithRecoveryManager(DISABLED) + .Build(); // Run & Check auto executable = execution::compiler::CompilationContext::Compile(*out_plan, exec_ctx->GetExecutionSettings(), From 3e190981d68b15bc0e511151ef661ac0b6abc96d Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Fri, 9 Jul 2021 18:32:28 -0400 Subject: [PATCH 064/139] some refactoring in lambda code generation, still working out the desired implementation --- src/execution/ast/ast_clone.cpp | 2 +- src/execution/ast/ast_dump.cpp | 2 +- src/execution/ast/ast_pretty_print.cpp | 2 +- src/execution/compiler/pipeline.cpp | 8 ++--- src/execution/compiler/udf/udf_codegen.cpp | 8 ++--- src/execution/sema/sema_expr.cpp | 32 +++++++++++-------- src/execution/vm/bytecode_generator.cpp | 20 ++++++------ src/include/execution/ast/ast.h | 36 +++++++++------------- src/include/execution/ast/type.h | 12 ++++++-- util/execution/tpl.cpp | 9 ++++++ 10 files changed, 74 insertions(+), 57 deletions(-) diff --git a/src/execution/ast/ast_clone.cpp b/src/execution/ast/ast_clone.cpp index 8413d18491..8d3d0aa2e8 100644 --- a/src/execution/ast/ast_clone.cpp +++ b/src/execution/ast/ast_clone.cpp @@ -169,7 +169,7 @@ AstNode *AstCloneImpl::VisitLambdaExpr(LambdaExpr *node) { capture_idents.push_back(reinterpret_cast(Visit(ident))); } return factory_->NewLambdaExpr(node->Position(), - reinterpret_cast(Visit(node->GetFunctionLitExpr())), + reinterpret_cast(Visit(node->GetFunctionLiteralExpr())), std::move(capture_idents)); } diff --git a/src/execution/ast/ast_dump.cpp b/src/execution/ast/ast_dump.cpp index ececaef0c4..c99345d996 100644 --- a/src/execution/ast/ast_dump.cpp +++ b/src/execution/ast/ast_dump.cpp @@ -173,7 +173,7 @@ void AstDumperImpl::VisitFunctionDecl(FunctionDecl *node) { void AstDumperImpl::VisitLambdaExpr(LambdaExpr *node) { DumpNodeCommon(node); - DumpExpr(node->GetFunctionLitExpr()); + DumpExpr(node->GetFunctionLiteralExpr()); } void AstDumperImpl::VisitVariableDecl(VariableDecl *node) { diff --git a/src/execution/ast/ast_pretty_print.cpp b/src/execution/ast/ast_pretty_print.cpp index fa8bfa4fb8..f2bf0b9579 100644 --- a/src/execution/ast/ast_pretty_print.cpp +++ b/src/execution/ast/ast_pretty_print.cpp @@ -293,7 +293,7 @@ void AstPrettyPrintImpl::VisitIndexExpr(IndexExpr *node) { void AstPrettyPrintImpl::VisitLambdaExpr(LambdaExpr *node) { os_ << "lambda "; - VisitFunctionLitExpr(node->GetFunctionLitExpr()); + VisitFunctionLitExpr(node->GetFunctionLiteralExpr()); } void AstPrettyPrintImpl::VisitFunctionTypeRepr(FunctionTypeRepr *node) { diff --git a/src/execution/compiler/pipeline.cpp b/src/execution/compiler/pipeline.cpp index c170026344..768f5db7ee 100644 --- a/src/execution/compiler/pipeline.cpp +++ b/src/execution/compiler/pipeline.cpp @@ -292,8 +292,8 @@ ast::FunctionDecl *Pipeline::GeneratePipelineWrapperFunction(ast::LambdaExpr *ou auto params = compilation_context_->QueryParams(); auto run_params = params; if (output_callback != nullptr) { - run_params.push_back(codegen_->MakeField(output_callback->GetName(), - codegen_->LambdaType(output_callback->GetFunctionLitExpr()->TypeRepr()))); + run_params.push_back(codegen_->MakeField( + output_callback->GetName(), codegen_->LambdaType(output_callback->GetFunctionLiteralExpr()->TypeRepr()))); } FunctionBuilder builder(codegen_, name, std::move(run_params), codegen_->Nil()); { @@ -360,7 +360,7 @@ ast::FunctionDecl *Pipeline::GeneratePipelineWorkFunction(ast::LambdaExpr *outpu if (output_callback != nullptr) { params.push_back(codegen_->MakeField(output_callback->GetName(), - codegen_->LambdaType(output_callback->GetFunctionLitExpr()->TypeRepr()))); + codegen_->LambdaType(output_callback->GetFunctionLiteralExpr()->TypeRepr()))); } FunctionBuilder builder(codegen_, GetWorkFunctionName(), std::move(params), codegen_->Nil()); @@ -464,7 +464,7 @@ ast::FunctionDecl *Pipeline::GenerateRunPipelineFunction(query_id_t query_id, as } if (output_callback != nullptr) { params.push_back(codegen_->MakeField(output_callback->GetName(), - codegen_->LambdaType(output_callback->GetFunctionLitExpr()->TypeRepr()))); + codegen_->LambdaType(output_callback->GetFunctionLiteralExpr()->TypeRepr()))); } FunctionBuilder builder(codegen_, name, std::move(params), codegen_->Nil()); { diff --git a/src/execution/compiler/udf/udf_codegen.cpp b/src/execution/compiler/udf/udf_codegen.cpp index 8e0089dace..77d627d224 100644 --- a/src/execution/compiler/udf/udf_codegen.cpp +++ b/src/execution/compiler/udf/udf_codegen.cpp @@ -392,8 +392,8 @@ void UDFCodegen::Visit(ast::udf::ForStmtAST *ast) { auto decls = exec_query->GetDecls(); aux_decls_.insert(aux_decls_.end(), decls.begin(), decls.end()); - fb_->Append( - codegen_->DeclareVar(lam_var, codegen_->LambdaType(lambda_expr->GetFunctionLitExpr()->TypeRepr()), lambda_expr)); + fb_->Append(codegen_->DeclareVar(lam_var, codegen_->LambdaType(lambda_expr->GetFunctionLiteralExpr()->TypeRepr()), + lambda_expr)); auto query_state = codegen_->MakeFreshIdentifier("query_state"); fb_->Append(codegen_->DeclareVarNoInit(query_state, codegen_->MakeExpr(exec_query->GetQueryStateType()->Name()))); @@ -550,8 +550,8 @@ void UDFCodegen::Visit(ast::udf::SQLStmtAST *ast) { auto decls = exec_query->GetDecls(); aux_decls_.insert(aux_decls_.end(), decls.begin(), decls.end()); - fb_->Append( - codegen_->DeclareVar(lam_var, codegen_->LambdaType(lambda_expr->GetFunctionLitExpr()->TypeRepr()), lambda_expr)); + fb_->Append(codegen_->DeclareVar(lam_var, codegen_->LambdaType(lambda_expr->GetFunctionLiteralExpr()->TypeRepr()), + lambda_expr)); // Make query state auto query_state = codegen_->MakeFreshIdentifier("query_state"); diff --git a/src/execution/sema/sema_expr.cpp b/src/execution/sema/sema_expr.cpp index fc8c68945a..7448e90597 100644 --- a/src/execution/sema/sema_expr.cpp +++ b/src/execution/sema/sema_expr.cpp @@ -169,6 +169,16 @@ void Sema::VisitCallExpr(ast::CallExpr *node) { void Sema::VisitLambdaExpr(ast::LambdaExpr *node) { auto factory = GetContext()->GetNodeFactory(); + + // Resolve the types necessary to get the type representation + // used to implement captures for closures produced by lambdas + + // TODO(Kyle): We perform quite a bit of mutation here during + // semantic analysis because this is where we resolve the type + // of the captures for the closure produced by the lambda expression; + // in the future we might want to revisit this to determine if + // we can perform this resolution during AST construction instead. + util::RegionVector fields(GetContext()->GetRegion()); for (auto expr : node->GetCaptureIdents()) { auto ident = expr->As(); @@ -183,9 +193,9 @@ void Sema::VisitLambdaExpr(ast::LambdaExpr *node) { ->GetTplName()))); fields.push_back(factory->NewFieldDecl(SourcePosition(), ident->Name(), type_repr)); } else { - util::RegionVector fields2(GetContext()->GetRegion()); - for (auto field : ident->GetType()->SafeAs()->GetFieldsWithoutPadding()) { - fields2.push_back(factory->NewFieldDecl( + util::RegionVector nested_fields{GetContext()->GetRegion()}; + for (const auto &field : ident->GetType()->SafeAs()->GetFieldsWithoutPadding()) { + nested_fields.push_back(factory->NewFieldDecl( SourcePosition(), field.name_, factory->NewIdentifierExpr( SourcePosition(), @@ -193,34 +203,32 @@ void Sema::VisitLambdaExpr(ast::LambdaExpr *node) { ast::BuiltinType::Get(GetContext(), field.type_->As()->GetKind()) ->GetTplName())))); } - - auto type_repr = - factory->NewPointerType(SourcePosition(), factory->NewStructType(SourcePosition(), std::move(fields2))); + auto *type_repr = + factory->NewPointerType(SourcePosition(), factory->NewStructType(SourcePosition(), std::move(nested_fields))); fields.push_back(factory->NewFieldDecl(SourcePosition(), ident->Name(), type_repr)); } } + fields.push_back( factory->NewFieldDecl(SourcePosition(), GetContext()->GetIdentifier("function"), - factory->NewPointerType(SourcePosition(), node->GetFunctionLitExpr()->TypeRepr()))); + factory->NewPointerType(SourcePosition(), node->GetFunctionLiteralExpr()->TypeRepr()))); ast::StructTypeRepr *struct_type_repr = factory->NewStructType(SourcePosition(), std::move(fields)); - // TODO(Kyle): Find a better name for this identifier ast::StructDecl *struct_decl = factory->NewStructDecl( SourcePosition(), GetContext()->GetIdentifier("lambda" + std::to_string(node->Position().line_)), struct_type_repr); VisitStructDecl(struct_decl); node->SetCaptureStructType(Resolve(struct_type_repr)); - node->SetType(ast::LambdaType::Get(Resolve(node->GetFunctionLitExpr()->TypeRepr())->As())); + node->SetType(ast::LambdaType::Get(Resolve(node->GetFunctionLiteralExpr()->TypeRepr())->As())); - // TODO(Kyle): Why are we performing so much mutation in semantic analysis? - auto type = Resolve(node->GetFunctionLitExpr()->TypeRepr()); + auto type = Resolve(node->GetFunctionLiteralExpr()->TypeRepr()); auto fn_type = type->As(); fn_type->GetParams().emplace_back(GetContext()->GetIdentifier("captures"), GetBuiltinType(ast::BuiltinType::Kind::Int32)->PointerTo()); fn_type->SetIsLambda(true); fn_type->SetCapturesType(node->GetCaptureStructType()->As()); - VisitFunctionLitExpr(node->GetFunctionLitExpr()); + VisitFunctionLitExpr(node->GetFunctionLiteralExpr()); } void Sema::VisitFunctionLitExpr(ast::FunctionLitExpr *node) { diff --git a/src/execution/vm/bytecode_generator.cpp b/src/execution/vm/bytecode_generator.cpp index b3dcd8e887..a51d89e45d 100644 --- a/src/execution/vm/bytecode_generator.cpp +++ b/src/execution/vm/bytecode_generator.cpp @@ -229,9 +229,9 @@ void BytecodeGenerator::VisitFunctionDecl(ast::FunctionDecl *node) { void BytecodeGenerator::VisitLambdaExpr(ast::LambdaExpr *node) { // The function's TPL type - auto *func_type = node->GetFunctionLitExpr()->GetType()->As(); + auto *func_type = node->GetFunctionLiteralExpr()->GetType()->As(); - // Allocate the function + // Elide code generation for lambda expressions that are not stored if (!GetExecutionResult()->HasDestination()) { return; } @@ -239,15 +239,17 @@ void BytecodeGenerator::VisitLambdaExpr(ast::LambdaExpr *node) { auto captures = GetCurrentFunction()->NewLocal(node->GetCaptureStructType(), node->GetName().GetString() + "captures"); auto fields = node->GetCaptureStructType()->As()->GetFieldsWithoutPadding(); - for (std::size_t i = 0; i < fields.size() - 1; i++) { + + // Capture each of the values for the closure by storing the + // current value of the captured local in the captures struct + for (std::size_t i = 0; i < fields.size() - 1; ++i) { auto field = fields[i]; - ast::IdentifierExpr ident(node->Position(), field.name_); + ast::IdentifierExpr ident{node->Position(), field.name_}; ident.SetType(field.type_->GetPointeeType()); - auto local = VisitExpressionForLValue(&ident); - - LocalVar fieldvar = GetCurrentFunction()->NewLocal(fields[i].type_->PointerTo(), ""); + LocalVar local = VisitExpressionForLValue(&ident); + LocalVar fieldvar = GetCurrentFunction()->NewLocal(field.type_->PointerTo(), ""); GetEmitter()->EmitLea(fieldvar, captures.AddressOf(), - node->GetCaptureStructType()->As()->GetOffsetOfFieldByName(fields[i].name_)); + node->GetCaptureStructType()->As()->GetOffsetOfFieldByName(field.name_)); GetEmitter()->EmitAssign(Bytecode::Assign8, fieldvar.ValueOf(), local); } @@ -269,7 +271,7 @@ void BytecodeGenerator::VisitLambdaExpr(ast::LambdaExpr *node) { // range in the function. EnterFunction(func_info->GetId()); BytecodePositionScope position_scope(this, func_info); - Visit(node->GetFunctionLitExpr()->Body()); + Visit(node->GetFunctionLiteralExpr()->Body()); } for (auto &f : func_info->actions_) { f(); diff --git a/src/include/execution/ast/ast.h b/src/include/execution/ast/ast.h index 8969dd2f1f..d9ed04d7b9 100644 --- a/src/include/execution/ast/ast.h +++ b/src/include/execution/ast/ast.h @@ -1090,17 +1090,15 @@ class BinaryOpExpr : public Expr { class LambdaExpr : public Expr { public: /** - * Construct + * Construct a new LambdaExpr instance. * @param pos source position - * @param func the associated function literal expression + * @param function the associated function literal expression * @param captures a collection of lambda captures */ - LambdaExpr(const SourcePosition &pos, FunctionLitExpr *func, util::RegionVector &&captures) - : Expr{Kind::LambdaExpr, pos}, func_lit_{func}, capture_idents_{std::move(captures)} {} + LambdaExpr(const SourcePosition &pos, FunctionLitExpr *function, util::RegionVector &&captures) + : Expr{Kind::LambdaExpr, pos}, function_literal_{function}, capture_idents_{std::move(captures)} {} - /** - * @return The identifier for this lambda expression. - */ + /** @return The identifier for this lambda expression. */ const Identifier &GetName() const { return name_; } /** @@ -1109,9 +1107,7 @@ class LambdaExpr : public Expr { */ void SetName(Identifier name) { name_ = name; } - /** - * @return Get the capture struct type for this lambda expression. - */ + /** @return Get the capture struct type for this lambda expression. */ ast::Type *GetCaptureStructType() const { return capture_type_; } /** @@ -1120,14 +1116,10 @@ class LambdaExpr : public Expr { */ void SetCaptureStructType(ast::Type *capture_type) { capture_type_ = capture_type; } - /** - * @return The function literal expression associated with this lambda. - */ - FunctionLitExpr *GetFunctionLitExpr() const { return func_lit_; } + /** @return The function literal expression associated with this lambda. */ + FunctionLitExpr *GetFunctionLiteralExpr() const { return function_literal_; } - /** - * @return The identifiers for the captures of this lambda expression. - */ + /** @return The identifiers for the captures of this lambda expression. */ const util::RegionVector &GetCaptureIdents() const { return capture_idents_; } /** @@ -1141,13 +1133,13 @@ class LambdaExpr : public Expr { private: friend class sema::Sema; - // The identifier for the lambda expression. + /** The identifier for the lambda expression. */ Identifier name_; - // The type of the lambda captures struct. + /** The type of the lambda captures struct. */ ast::Type *capture_type_; - // The associated function literal expression. - FunctionLitExpr *func_lit_; - // The collection of identifers for lambda captures. + /** The associated function literal expression. */ + FunctionLitExpr *function_literal_; + /** The collection of identifers for lambda captures. */ util::RegionVector capture_idents_; }; diff --git a/src/include/execution/ast/type.h b/src/include/execution/ast/type.h index 108cc4dee5..d9237ff9df 100644 --- a/src/include/execution/ast/type.h +++ b/src/include/execution/ast/type.h @@ -598,8 +598,8 @@ class ArrayType : public Type { }; /** - * A field is a pair containing a name and a type. It is used to represent both fields within a struct, and parameters - * to a function. + * A Field is a pair containing a name and a type. + * It is used to represent both fields within a struct, and parameters to a function. */ struct Field { /** @@ -613,12 +613,18 @@ struct Field { Type *type_; /** - * Constructor + * Construct a new Field instance. * @param name of the field * @param type of the field */ Field(const Identifier &name, Type *type) : name_(name), type_(type) {} + /** @return The name of the field */ + const Identifier &GetName() const { return name_; } + + /** @return The type of the field */ + Type *GetType() const { return type_; } + /** * @param other rhs of the comparison * @return whether this == other diff --git a/util/execution/tpl.cpp b/util/execution/tpl.cpp index fb3e22ca77..ff64dfa80c 100644 --- a/util/execution/tpl.cpp +++ b/util/execution/tpl.cpp @@ -43,6 +43,9 @@ #include "transaction/deferred_action_manager.h" #include "transaction/timestamp_manager.h" +/** Suppress warnings from unused variables */ +#define SUPPRESS_UNUSED(x) ((void)x) + // --------------------------------------------------------- // CLI options // --------------------------------------------------------- @@ -133,6 +136,7 @@ static double ExecuteInMode(vm::Module *module, vm::ExecutionMode mode, exec::Ex } } + SUPPRESS_UNUSED(mode_identifier); return exec_ms; } @@ -287,7 +291,12 @@ static void CompileAndRun(const std::string &source, const std::string &name = " "Parse: {} ms, Type-check: {} ms, Code-gen: {} ms, Interp. Exec.: {} ms, " "JIT Exec.: {} ms, Adaptive Exec.: {} ms", parse_ms, typecheck_ms, codegen_ms, vm_ms, jit_ms, adaptive_ms); + txn_manager->Commit(txn, transaction::TransactionUtil::EmptyCallback, nullptr); + + SUPPRESS_UNUSED(vm_ms); + SUPPRESS_UNUSED(jit_ms); + SUPPRESS_UNUSED(adaptive_ms); } /** From cb77a2cb9a89e9f0631fd5fe6f2723355423d30e Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Mon, 12 Jul 2021 18:06:23 -0400 Subject: [PATCH 065/139] tpl closures working on both VM and JIT --- docs/design_lambdas.md | 25 ++++++++++++ sample_tpl/{lambda0.tpl => closure0.tpl} | 2 +- sample_tpl/closure1.tpl | 11 ++++++ sample_tpl/lambda1.tpl | 10 ----- sample_tpl/param-lambda.tpl | 11 ------ sample_tpl/tpl_tests.txt | 5 +-- src/execution/vm/bytecode_generator.cpp | 39 +++---------------- src/execution/vm/llvm_engine.cpp | 3 +- src/include/execution/vm/bytecode_generator.h | 12 +----- 9 files changed, 46 insertions(+), 72 deletions(-) create mode 100644 docs/design_lambdas.md rename sample_tpl/{lambda0.tpl => closure0.tpl} (85%) create mode 100644 sample_tpl/closure1.tpl delete mode 100644 sample_tpl/lambda1.tpl delete mode 100644 sample_tpl/param-lambda.tpl diff --git a/docs/design_lambdas.md b/docs/design_lambdas.md new file mode 100644 index 0000000000..01039e9276 --- /dev/null +++ b/docs/design_lambdas.md @@ -0,0 +1,25 @@ +# Design Doc: TPL Closures + +### Overview + +This document describes the implementation of closures in TPL. It includes both a high-level description as well as a complete walkthough of the low-level implementation details. + +### Architecture + +TPL closures are implemented as regular TPL functions with the added ability to capture arbitrary variables. In the same way that return values in TPL are implemented via a "hidden" out-parameter to each function, the variables captured by a TPL closure are represented as a TPL structure that is passed as a second hidden parameter to the function that implements the logic of the closure. + +The closure itself is represented as a stack-allocated structure - a local variable within the frame of the function in which the lambda that produces the closure appears. This structure contains `N` fields. The first `N - 1` fields are the variables captured by the closure. The final field is a pointer to the compiled function that implements the closure's logic - a regular TPL function. + +TPL closures introduce some interesting implementation challenges that manifest during both code generation and during execution of the generated code. These implementation details are explored in further detail in the sections below. + +Closures can be passed like values throughout a TPL program; this allows one to, for instance, construct a closure and pass it to a higher-order function that then invokes it within its body (this is how clsoures are used to implement generalized output callbacks and enable the implementation of user-defined functions). However, there is a major limitation to our current closure implementation design: because the TPL structure that implements the closure is allocated in the stack frame of the function in which the lambda that produced the closure appears, the closure cannot escape the lexical scope of this function. In other words, we cannot return a closure from a TPL function and invoke it elsewhere because the structure that backs its implementation would be deallocated the moment the function that creates it returns. This is a major limitation, and most languages that support closures (read: every language implementation that I can find) include the ability for closures the escape the scope in which they are defined. Adding support for this functionality obviously requires some additional engineering and often significantly complicates the implementation. We get away with this implementation for now because our use-cases for closures never require us to generate code that allows the closure to escape the scope in which it is defined, but this is likely an issue we should address in the future. + +In a future refactor of this design, it may be beneficial (from a software-design perspective, at least) to implement closures as their own first-class type. Rather than implementing a closure as a TPL structure containing the closure's captures with the ad-hoc constraint that the final member is _always_ a pointer to the closure's associated function, we might consider adding a dedicated `ClosureType` type to the TPL DSL. At a first approximation, this would incur zero additional cost (compile-time or runtime) and would simplify some of the implementation because we could more easily distinguish between regular function invocations and closure invocations. + +### Implementation: Code Generation + +TODO + +### Implementation: Runtime + +TODO diff --git a/sample_tpl/lambda0.tpl b/sample_tpl/closure0.tpl similarity index 85% rename from sample_tpl/lambda0.tpl rename to sample_tpl/closure0.tpl index 9929041941..19d97761a2 100644 --- a/sample_tpl/lambda0.tpl +++ b/sample_tpl/closure0.tpl @@ -1,7 +1,7 @@ // Expected output: 2 fun main() -> int32 { - // Lambda without capture + // Closure without capture var addOne = lambda [] (x: int32) -> int32 { return x + 1 } diff --git a/sample_tpl/closure1.tpl b/sample_tpl/closure1.tpl new file mode 100644 index 0000000000..ea87cabcdf --- /dev/null +++ b/sample_tpl/closure1.tpl @@ -0,0 +1,11 @@ +// Expected output: 3 + +fun main() -> int32 { + var x = 1 + // Closure that uses capture in computation; + // the closure does not write captured variable + var addValue = lambda [x] (y: int32) -> int32 { + return x + y + } + return addValue(2) +} \ No newline at end of file diff --git a/sample_tpl/lambda1.tpl b/sample_tpl/lambda1.tpl deleted file mode 100644 index fc3f3505de..0000000000 --- a/sample_tpl/lambda1.tpl +++ /dev/null @@ -1,10 +0,0 @@ -// Expected output: 3 - -fun main(exec : *ExecutionContext) -> int32 { - var x = 1 - var addValue = lambda [x] (y: int32) -> int32 { - x = x + y - } - addValue(2) - return x -} \ No newline at end of file diff --git a/sample_tpl/param-lambda.tpl b/sample_tpl/param-lambda.tpl deleted file mode 100644 index 417e05b26b..0000000000 --- a/sample_tpl/param-lambda.tpl +++ /dev/null @@ -1,11 +0,0 @@ -// Expected output: 10 - -fun check(x: int32) -> int32 { - var ret = x - return ret -} - -fun main() -> int32 { - var fn = lambda (x: int32) -> nil { return x + 1; } - return fn(2) -} diff --git a/sample_tpl/tpl_tests.txt b/sample_tpl/tpl_tests.txt index c6a8f22d57..f1ee6e49f4 100644 --- a/sample_tpl/tpl_tests.txt +++ b/sample_tpl/tpl_tests.txt @@ -10,8 +10,8 @@ array.tpl,false,44 array-iterate.tpl,false,110 array-iterate-2.tpl,false,110 call.tpl,false,70 -#lambda0.tpl,false,2 -#lambda1.tpl,false,3 +closure0.tpl,false,2 +closure1.tpl,false,3 comments.tpl,false,46 compare.tpl,false,200 date-functions.tpl,false,0 @@ -30,7 +30,6 @@ loop4.tpl,false,166167000 nil.tpl,false,0 offsetof.tpl,false,54 param.tpl,false,10 -#param-lambda.tpl,false,10 TODO(Kyle): Requires lambdas point.tpl,false,-20 pointer.tpl,false,10 return-expr.tpl,false,15 diff --git a/src/execution/vm/bytecode_generator.cpp b/src/execution/vm/bytecode_generator.cpp index 8eb5411bb5..015a9573c0 100644 --- a/src/execution/vm/bytecode_generator.cpp +++ b/src/execution/vm/bytecode_generator.cpp @@ -254,9 +254,9 @@ void BytecodeGenerator::VisitLambdaExpr(ast::LambdaExpr *node) { } GetEmitter()->EmitAssign(Bytecode::Assign8, GetExecutionResult()->GetDestination(), captures.AddressOf()); - FunctionInfo *func_info = - AllocateFunction(node->GetName().GetString(), func_type, captures, node->GetCaptureStructType()); - (void)func_info; + // FunctionInfo *func_info = + // AllocateFunction(node->GetName().GetString(), func_type, captures, node->GetCaptureStructType()); + FunctionInfo *func_info = AllocateFunction(node->GetName().GetString(), func_type); // Create a new deferred action for the current function // that visits the body of the lambda; this action is subsequently @@ -4052,39 +4052,10 @@ FunctionInfo *BytecodeGenerator::AllocateFunction(const std::string &function_na } } - // Cache + // Cache the function func_map_[func->GetName()] = func->GetId(); - for (const auto &action : deferred_function_create_actions_[func->GetName()]) { - action(func->GetId()); - } - - return func; -} - -FunctionInfo *BytecodeGenerator::AllocateFunction(const std::string &function_name, - ast::FunctionType *const function_type, LocalVar captures, - ast::Type *capture_type) { - // Allocate function - const auto func_id = static_cast(functions_.size()); - functions_.push_back(std::make_unique(func_id, function_name, function_type)); - FunctionInfo *func = functions_.back().get(); - // Register return type - if (auto *return_type = function_type->GetReturnType(); !return_type->IsNilType()) { - func->NewParameterLocal(return_type->PointerTo(), "hiddenRv"); - } - - // Lambda captures - func->NewParameterLocal(capture_type->PointerTo(), "hiddenCaptures"); - - // Register parameters - for (const auto ¶m : function_type->GetParams()) { - // TODO(Kyle): Why do we never check for SQL value types here? - func->NewParameterLocal(param.type_, param.name_.GetData()); - } - - // Cache - func_map_[func->GetName()] = func->GetId(); + // Execute all deferred creation actions for the function for (const auto &action : deferred_function_create_actions_[func->GetName()]) { action(func->GetId()); } diff --git a/src/execution/vm/llvm_engine.cpp b/src/execution/vm/llvm_engine.cpp index 9803819c6a..10fc1ac2e8 100644 --- a/src/execution/vm/llvm_engine.cpp +++ b/src/execution/vm/llvm_engine.cpp @@ -214,8 +214,7 @@ llvm::Type *LLVMEngine::TypeMap::GetLLVMType(const ast::Type *type) { break; } case ast::Type::TypeId::LambdaType: { - // TODO(Kyle): Implement this - throw NOT_IMPLEMENTED_EXCEPTION("LambdaType Not Implemented"); + llvm_type = Int32Type()->getPointerTo(); break; } default: { diff --git a/src/include/execution/vm/bytecode_generator.h b/src/include/execution/vm/bytecode_generator.h index d46acf48fa..910f9c357a 100644 --- a/src/include/execution/vm/bytecode_generator.h +++ b/src/include/execution/vm/bytecode_generator.h @@ -75,17 +75,7 @@ class BytecodeGenerator final : public ast::AstVisitor { */ FunctionInfo *AllocateFunction(const std::string &function_name, ast::FunctionType *function_type); - /** - * Allocate a new function with captures (for lambda expressions). - * @param function_name The function name - * @param function_type The function type - * @param captures The local variable for the captures structure - * @param capture_type The type of the captures structure - * @return A non-owning pointer to the allocated function - */ - FunctionInfo *AllocateFunction(const std::string &function_name, ast::FunctionType *function_type, LocalVar captures, - ast::Type *capture_type); - + // Visit a transaction abort call expression void VisitAbortTxn(ast::CallExpr *call); // ONLY FOR TESTING! From a8031d6ab9f8d4fcc234d08f9c7daf6fa6b3aa9e Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Thu, 15 Jul 2021 18:16:00 -0400 Subject: [PATCH 066/139] add closures tpl tests, add documentation for the implementation of tpl closures --- docs/design_closures.md | 167 ++++++++++++++++++++++++ docs/design_lambdas.md | 25 ---- sample_tpl/closure2.tpl | 11 ++ sample_tpl/closure3.tpl | 11 ++ sample_tpl/closure4.tpl | 12 ++ sample_tpl/tpl_tests.txt | 4 +- src/execution/sema/sema_checking.cpp | 7 +- src/execution/sema/sema_type.cpp | 9 ++ src/execution/vm/bytecode_generator.cpp | 4 +- 9 files changed, 218 insertions(+), 32 deletions(-) create mode 100644 docs/design_closures.md delete mode 100644 docs/design_lambdas.md create mode 100644 sample_tpl/closure2.tpl create mode 100644 sample_tpl/closure3.tpl create mode 100644 sample_tpl/closure4.tpl diff --git a/docs/design_closures.md b/docs/design_closures.md new file mode 100644 index 0000000000..9a52dcb6a0 --- /dev/null +++ b/docs/design_closures.md @@ -0,0 +1,167 @@ +# Design Doc: TPL Closures + +### Overview + +This document describes the implementation of closures in TPL. It includes both a high-level description as well as a complete walkthough of the low-level implementation details. + +### Architecture + +TPL closures are implemented as regular TPL functions with the added ability to capture arbitrary variables. In the same way that return values in TPL are implemented via a "hidden" out-parameter to each function, the variables captured by a TPL closure are represented as a TPL structure that is passed as a second hidden parameter to the function that implements the logic of the closure. + +The closure itself is represented as a stack-allocated structure - a local variable within the frame of the function in which the lambda that produces the closure appears. This structure contains `N` fields. The first `N - 1` fields are the variables captured by the closure. The final field is a pointer to the compiled function that implements the closure's logic - a regular TPL function. + +TPL closures introduce some interesting implementation challenges that manifest during both code generation and during execution of the generated code. These implementation details are explored in further detail in the sections below. + +Closures can be passed like values throughout a TPL program; this allows one to, for instance, construct a closure and pass it to other functions that may then invoke it to perform computations or produce side-effects. However, there is a major limitation to our current closure implementation design: because the TPL structure that implements the closure is allocated in the stack frame of the function in which the lambda that produced the closure appears, the closure cannot escape the lexical scope of this function. In other words, we cannot return a closure from a TPL function and invoke it elsewhere because the structure that backs its implementation would be deallocated the moment the function that creates it returns. This is a major limitation, and most languages that support closures (read: every language implementation that I can find) include the ability for closures the escape the scope in which they are defined. Adding support for this functionality obviously requires some additional engineering and often significantly complicates the implementation. We get away with this implementation for now because our use-cases for closures never require us to generate code that allows the closure to escape the scope in which it is defined, but this is likely an issue we should address in the future. + +In a future refactor of this design, it may be beneficial (from a software-design perspective, at least) to implement closures as their own first-class type. Rather than implementing a closure as a TPL structure containing the closure's captures with the ad-hoc constraint that the final member is _always_ a pointer to the closure's associated function, we might consider adding a dedicated `ClosureType` type to the TPL DSL. At a first approximation, this would incur zero additional cost (compile-time or runtime) and would simplify some of the implementation because we could more easily distinguish between regular function invocations and closure invocations. + +Another limitation of the current implementation of closures is the inability to easily specify their type. Because TPL is statically typed, this makes implementing higher-order functions that accept closures as arguments or return closures more difficult. At present I am unsure of the best way to address this limitation. Languages like C++ and Rust get around this with either of 1) type erasure (e.g. C++'s `std::function` or Rust's `Fn` trait) or 2) generics. Both of these like relatively involved approaches for our purposes here in TPL. Perhaps we might consider some kind of implicit-conversion facility between function pointer types (which can be concisely specified) and closures (even those that capture). + +### Code Generation Details + +Closures and the lambda expressions that produce them introduce some additional complexity to code generation. The general flow of code generation for a function that contains a lambda expression proceeds as follows: + +- Visit the function declaration for the function in which the lambda expression appears +- Visit the body of the function; during visitation of the statement(s) in the function's body, the lambda expression is encountered +- Visit the lambda expression + - Allocate a new local in the frame of the current function for the closure structure + - Emit the bytecode to "capture" local variables; this is performed by loading the address of all captured locals into the fields in the closure (captures) structure + - Emit the bytecode to pass (a pointer to) the locally-allocated captures structure to the function that will implement the closure's logic + - Allocate the TPL function for the body of the closure + - Defer an action for the current function to visit the body of the closure; this deferred action captures (in C++-land, not in TPL!) the TPL function allocated for the closure +- Complete visitation of the function in which the lambda expression appears +- As the final step in visitation of the function, execute the deferred action to visit the body of the closure's function + +### Walkthough #0: Closure Without Captures + +As a first example, we consider the following TPL program: + +``` +fun main() -> int32 { + var addOne = lambda [] (x: int32) -> int32 { + return x + 1 + } + return addOne(1) +} +``` + +The bytecode generated for this program, with annotations, is shown below. + +``` +Data: + Data section size 0 bytes (0 locals) + +Function 0
: + Frame size 32 bytes (1 parameter, 5 locals) + param hiddenRv: offset=0 size=8 align=8 type=*int32 + + // The addOne local captures the closure that results from evaluating the lambda expression + local addOne: offset=8 size=8 align=8 type=lambda[(int32,*int32)->int32] + + // The captures structure for the closure is allocated in the frame + // of the function in which the lambda expression appears + local addOneCaptures: offset=16 size=8 align=8 type=struct{*(int32,*int32)->int32} + local tmp1: offset=24 size=4 align=4 type=int32 + local tmp2: offset=28 size=4 align=4 type=int32 + + // In the current implementation, the closure is synonymous with a + // pointer to the base of the captures structure; the bytecode in + // the body of the function generated for the closure assumes this + 0x00000000 Assign8 local=&addOne local=&addOneCaptures + 0x0000000c AssignImm4 local=&tmp2 i32=1 + + // Invoke the function generated for the closure; the captures structure is + // passed as an implicit final argument to the function call + 0x00000018 Call func= local=&tmp1 local=tmp2 local=addOne + 0x0000002c Assign4 local=hiddenRv local=tmp1 + 0x00000038 Return + +Function 1 : + Frame size 32 bytes (3 parameters, 5 locals) + param hiddenRv: offset=0 size=8 align=8 type=*int32 + param x: offset=8 size=4 align=4 type=int32 + param captures: offset=16 size=8 align=8 type=*int32 + local tmp1: offset=24 size=4 align=4 type=int32 + local tmp2: offset=28 size=4 align=4 type=int32 + + // The lambda that generated this function has no captures, therefore + // we don't need to do anything special here to handle captured variables + + 0x00000000 AssignImm4 local=&tmp2 i32=1 + // Perform the addition + 0x0000000c Add_int32_t local=&tmp1 local=x local=tmp2 + // Set the return value + 0x0000001c Assign4 local=hiddenRv local=tmp1 + 0x00000028 Return +``` + +### Walkthough #1: Closure With Captures + +As a second example, we consider the following TPL program: + +``` +fun main() -> int32 { + var x = 1 + var addValue = lambda [x] (y: int32) -> int32 { + return x + y + } + return addValue(2) +} +``` + +The bytecode generated for this program, with annotations, is shown below. + +``` +Data: + Data section size 0 bytes (0 locals) + +Function 0
: + Frame size 56 bytes (1 parameter, 7 locals) + param hiddenRv: offset=0 size=8 align=8 type=*int32 + local x: offset=8 size=4 align=4 type=int32 + local addValue: offset=16 size=8 align=8 type=lambda[(int32,*int32)->int32] + + // The first member of the captures structure is a pointer to the captured local; + // the second member of the captures structure is a pointer to the associated function + local addValueCaptures: offset=24 size=16 align=8 type=struct{*int32,*(int32,*int32)->int32} + local tmp1: offset=40 size=8 align=8 type=**int32 + local tmp2: offset=48 size=4 align=4 type=int32 + local tmp3: offset=52 size=4 align=4 type=int32 + + 0x00000000 AssignImm4 local=&x i32=1 + + // Capture the variable `x`; load the address of the local variable `x` into the captures structure + 0x0000000c Lea local=&tmp1 local=&addValueCaptures i32=0 + 0x0000001c Assign8 local=tmp1 local=&x + + // Initialize the closure itself as a pointer to the base of the captures structure + 0x00000028 Assign8 local=&addValue local=&addValueCaptures + + 0x00000034 AssignImm4 local=&tmp3 i32=2 + 0x00000040 Call func= local=&tmp2 local=tmp3 local=addValue + 0x00000054 Assign4 local=hiddenRv local=tmp2 + 0x00000060 Return + +Function 1 : + Frame size 52 bytes (3 parameters, 7 locals) + param hiddenRv: offset=0 size=8 align=8 type=*int32 + param y: offset=8 size=4 align=4 type=int32 + param captures: offset=16 size=8 align=8 type=*int32 + local tmp1: offset=24 size=4 align=4 type=int32 + local tmp2: offset=32 size=8 align=8 type=**int32 + local xptr: offset=40 size=8 align=8 type=*int32 + local tmp3: offset=48 size=4 align=4 type=int32 + + // Load the captured `x` pointer to the local `xptr` + 0x00000000 Lea local=&tmp2 local=captures i32=0 + 0x00000010 DerefN local=&xptr local=tmp2 u32=8 + + // Dereference the pointer to the captured `x` to get its value + 0x00000020 DerefN local=&tmp3 local=xptr u32=4 + + // Perform the addition and return the result + 0x00000030 Add_int32_t local=&tmp1 local=tmp3 local=y + 0x00000040 Assign4 local=hiddenRv local=tmp1 + 0x0000004c Return +``` \ No newline at end of file diff --git a/docs/design_lambdas.md b/docs/design_lambdas.md deleted file mode 100644 index 01039e9276..0000000000 --- a/docs/design_lambdas.md +++ /dev/null @@ -1,25 +0,0 @@ -# Design Doc: TPL Closures - -### Overview - -This document describes the implementation of closures in TPL. It includes both a high-level description as well as a complete walkthough of the low-level implementation details. - -### Architecture - -TPL closures are implemented as regular TPL functions with the added ability to capture arbitrary variables. In the same way that return values in TPL are implemented via a "hidden" out-parameter to each function, the variables captured by a TPL closure are represented as a TPL structure that is passed as a second hidden parameter to the function that implements the logic of the closure. - -The closure itself is represented as a stack-allocated structure - a local variable within the frame of the function in which the lambda that produces the closure appears. This structure contains `N` fields. The first `N - 1` fields are the variables captured by the closure. The final field is a pointer to the compiled function that implements the closure's logic - a regular TPL function. - -TPL closures introduce some interesting implementation challenges that manifest during both code generation and during execution of the generated code. These implementation details are explored in further detail in the sections below. - -Closures can be passed like values throughout a TPL program; this allows one to, for instance, construct a closure and pass it to a higher-order function that then invokes it within its body (this is how clsoures are used to implement generalized output callbacks and enable the implementation of user-defined functions). However, there is a major limitation to our current closure implementation design: because the TPL structure that implements the closure is allocated in the stack frame of the function in which the lambda that produced the closure appears, the closure cannot escape the lexical scope of this function. In other words, we cannot return a closure from a TPL function and invoke it elsewhere because the structure that backs its implementation would be deallocated the moment the function that creates it returns. This is a major limitation, and most languages that support closures (read: every language implementation that I can find) include the ability for closures the escape the scope in which they are defined. Adding support for this functionality obviously requires some additional engineering and often significantly complicates the implementation. We get away with this implementation for now because our use-cases for closures never require us to generate code that allows the closure to escape the scope in which it is defined, but this is likely an issue we should address in the future. - -In a future refactor of this design, it may be beneficial (from a software-design perspective, at least) to implement closures as their own first-class type. Rather than implementing a closure as a TPL structure containing the closure's captures with the ad-hoc constraint that the final member is _always_ a pointer to the closure's associated function, we might consider adding a dedicated `ClosureType` type to the TPL DSL. At a first approximation, this would incur zero additional cost (compile-time or runtime) and would simplify some of the implementation because we could more easily distinguish between regular function invocations and closure invocations. - -### Implementation: Code Generation - -TODO - -### Implementation: Runtime - -TODO diff --git a/sample_tpl/closure2.tpl b/sample_tpl/closure2.tpl new file mode 100644 index 0000000000..20d04de8cb --- /dev/null +++ b/sample_tpl/closure2.tpl @@ -0,0 +1,11 @@ +// Expected output: 6 + +fun main() -> int32 { + var x = 1 + var y = 2 + // Closure that uses multiple captures in computation + var addValues = lambda [x, y] (z: int32) -> int32 { + return x + y + z + } + return addValues(3) +} \ No newline at end of file diff --git a/sample_tpl/closure3.tpl b/sample_tpl/closure3.tpl new file mode 100644 index 0000000000..086d9f782e --- /dev/null +++ b/sample_tpl/closure3.tpl @@ -0,0 +1,11 @@ +// Expected output: 2 + +fun main() -> int32 { + var x = 1 + // Closure that writes to the captured variable + var addOne = lambda [x] () -> nil { + x = x + 1 + } + addOne() + return x +} \ No newline at end of file diff --git a/sample_tpl/closure4.tpl b/sample_tpl/closure4.tpl new file mode 100644 index 0000000000..5575d78a6b --- /dev/null +++ b/sample_tpl/closure4.tpl @@ -0,0 +1,12 @@ +// Expected output: 8 + +fun main() -> int32 { + // Lambda expressions may contain other lambda expressions + var timesFour = lambda [] (x: int32) -> int32 { + var timesTwo = lambda [] (y: int32) -> int32 { + return y*2 + } + return timesTwo(x) + timesTwo(x) + } + return timesFour(2) +} diff --git a/sample_tpl/tpl_tests.txt b/sample_tpl/tpl_tests.txt index f1ee6e49f4..b54b45cae3 100644 --- a/sample_tpl/tpl_tests.txt +++ b/sample_tpl/tpl_tests.txt @@ -12,6 +12,9 @@ array-iterate-2.tpl,false,110 call.tpl,false,70 closure0.tpl,false,2 closure1.tpl,false,3 +closure2.tpl,false,6 +closure3.tpl,false,2 +closure4.tpl,false,8 comments.tpl,false,46 compare.tpl,false,200 date-functions.tpl,false,0 @@ -42,7 +45,6 @@ short-circuit.tpl,false,1 #sql-conversions.tpl,false,0 TODO(WAN): wtf Mac CI? sql-date.tpl,false,0 struct.tpl,false,10 -#struct-lambda.tpl,false,10 TODO(Kyle): Requires lambdas struct-debug.tpl,false,100000 struct-empty.tpl,false,0 struct-field-use.tpl,false,30 diff --git a/src/execution/sema/sema_checking.cpp b/src/execution/sema/sema_checking.cpp index 27c2d6143f..2f74949bac 100644 --- a/src/execution/sema/sema_checking.cpp +++ b/src/execution/sema/sema_checking.cpp @@ -269,10 +269,11 @@ bool Sema::CheckAssignmentConstraints(ast::Type *target_type, ast::Expr **expr) return true; } + // Lambdas (more accurately, the closures produced by lambda expressions) if (target_type->IsLambdaType() && (*expr)->GetType()->IsLambdaType()) { - auto fn_type = (*expr)->GetType()->As()->GetFunctionType(); - auto target_fn = target_type->As()->GetFunctionType(); - return fn_type->IsEqual(target_fn); + auto expr_fn_type = (*expr)->GetType()->As()->GetFunctionType(); + auto target_fn_type = target_type->As()->GetFunctionType(); + return expr_fn_type->IsEqual(target_fn_type); } // Integer expansion diff --git a/src/execution/sema/sema_type.cpp b/src/execution/sema/sema_type.cpp index 49b00f3972..eb619b188d 100644 --- a/src/execution/sema/sema_type.cpp +++ b/src/execution/sema/sema_type.cpp @@ -96,6 +96,15 @@ void Sema::VisitLambdaTypeRepr(ast::LambdaTypeRepr *node) { return; } + // Captures are passed to the function that implements the lambda + // by way of the final parameter to the function; the parameter is + // always specified as an Int32 pointer and then we emit the code + // necessary to dereference the pointers within the structure + // (relative to the base pointer) appropriately to extract captures + + // TODO(Kyle): This seems like a potentially-expedient yet needlessly + // confusing (and potentially unsafe?) way to implement the passage + // the captures structure to the function that implements the closure fn_type->GetParams().emplace_back(GetContext()->GetIdentifier("captures"), GetBuiltinType(ast::BuiltinType::Kind::Int32)->PointerTo()); node->SetType(ast::LambdaType::Get(fn_type)); diff --git a/src/execution/vm/bytecode_generator.cpp b/src/execution/vm/bytecode_generator.cpp index 015a9573c0..2099cc1f2a 100644 --- a/src/execution/vm/bytecode_generator.cpp +++ b/src/execution/vm/bytecode_generator.cpp @@ -237,7 +237,7 @@ void BytecodeGenerator::VisitLambdaExpr(ast::LambdaExpr *node) { } auto captures = - GetCurrentFunction()->NewLocal(node->GetCaptureStructType(), node->GetName().GetString() + "captures"); + GetCurrentFunction()->NewLocal(node->GetCaptureStructType(), node->GetName().GetString() + "Captures"); auto fields = node->GetCaptureStructType()->As()->GetFieldsWithoutPadding(); // Capture each of the values for the closure by storing the @@ -254,8 +254,6 @@ void BytecodeGenerator::VisitLambdaExpr(ast::LambdaExpr *node) { } GetEmitter()->EmitAssign(Bytecode::Assign8, GetExecutionResult()->GetDestination(), captures.AddressOf()); - // FunctionInfo *func_info = - // AllocateFunction(node->GetName().GetString(), func_type, captures, node->GetCaptureStructType()); FunctionInfo *func_info = AllocateFunction(node->GetName().GetString(), func_type); // Create a new deferred action for the current function From f64c71877106dc1163caef7315d469f1348f7c5f Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Fri, 16 Jul 2021 15:23:46 -0400 Subject: [PATCH 067/139] minor tweaks, first end to end test of UDF definition and execution works --- .../expression/function_translator.cpp | 11 ++-- src/execution/compiler/udf/udf_codegen.cpp | 51 ++++++++++--------- .../execution/compiler/udf/udf_codegen.h | 27 +++++----- .../execution/functions/function_context.h | 40 +++++---------- src/network/noisepage_server.cpp | 2 +- 5 files changed, 57 insertions(+), 74 deletions(-) diff --git a/src/execution/compiler/expression/function_translator.cpp b/src/execution/compiler/expression/function_translator.cpp index 1d9f045565..cf69088bcf 100644 --- a/src/execution/compiler/expression/function_translator.cpp +++ b/src/execution/compiler/expression/function_translator.cpp @@ -23,11 +23,8 @@ ast::Expr *FunctionTranslator::DeriveValue(WorkContext *ctx, const ColumnValuePr const auto &func_expr = GetExpressionAs(); auto proc_oid = func_expr.GetProcOid(); auto func_context = codegen->GetCatalogAccessor()->GetFunctionContext(proc_oid); - if (!func_context->IsBuiltin()) { - UNREACHABLE("User-defined functions are not supported"); - } - std::vector params; + std::vector params{}; if (func_context->IsExecCtxRequired()) { params.push_back(GetExecutionContextPtr()); } @@ -37,9 +34,9 @@ ast::Expr *FunctionTranslator::DeriveValue(WorkContext *ctx, const ColumnValuePr } if (!func_context->IsBuiltin()) { - auto ident_expr = main_fn_; - std::vector args{params.cbegin(), params.cbegin()}; - return GetCodeGen()->Call(ident_expr, args); + const auto identifier_expr = main_fn_; + std::vector args{params.cbegin(), params.cend()}; + return GetCodeGen()->Call(identifier_expr, args); } return codegen->CallBuiltin(func_context->GetBuiltin(), params); diff --git a/src/execution/compiler/udf/udf_codegen.cpp b/src/execution/compiler/udf/udf_codegen.cpp index 77d627d224..b48d28cec2 100644 --- a/src/execution/compiler/udf/udf_codegen.cpp +++ b/src/execution/compiler/udf/udf_codegen.cpp @@ -41,7 +41,7 @@ UDFCodegen::UDFCodegen(catalog::CatalogAccessor *accessor, FunctionBuilder *fb, for (auto i = 0UL; fb->GetParameterByPosition(i) != nullptr; ++i) { auto param = fb->GetParameterByPosition(i); const auto &name = param->As()->Name(); - str_to_ident_.emplace(name.GetString(), name); + SymbolTable()[name.GetString()] = name; } } @@ -68,7 +68,7 @@ catalog::type_oid_t UDFCodegen::GetCatalogTypeOidFromSQLType(execution::ast::Bui } default: return accessor_->GetTypeOidFromTypeId(type::TypeId::INVALID); - NOISEPAGE_ASSERT(false, "Unsupported param type"); + NOISEPAGE_ASSERT(false, "Unsupported parameter type"); } } @@ -90,8 +90,8 @@ void UDFCodegen::Visit(ast::udf::CallExprAST *ast) { args_ast.push_back(dst_); args_ast_region_vec.push_back(dst_); auto *builtin = dst_->GetType()->SafeAs(); - NOISEPAGE_ASSERT(builtin != nullptr, "Not builtin parameter"); - NOISEPAGE_ASSERT(builtin->IsSqlValueType(), "Param is not SQL Value Type"); + NOISEPAGE_ASSERT(builtin != nullptr, "Parameter must be a built-in type"); + NOISEPAGE_ASSERT(builtin->IsSqlValueType(), "Parameter must be a SQL value type"); arg_types.push_back(GetCatalogTypeOidFromSQLType(builtin->GetKind())); } auto proc_oid = accessor_->GetProcOid(ast->Callee(), arg_types); @@ -101,9 +101,9 @@ void UDFCodegen::Visit(ast::udf::CallExprAST *ast) { if (context->IsBuiltin()) { fb_->Append(codegen_->MakeStmt(codegen_->CallBuiltin(context->GetBuiltin(), args_ast))); } else { - auto it = str_to_ident_.find(ast->Callee()); + auto it = SymbolTable().find(ast->Callee()); execution::ast::Identifier ident_expr; - if (it != str_to_ident_.end()) { + if (it != SymbolTable().end()) { ident_expr = it->second; } else { auto file = reinterpret_cast( @@ -113,7 +113,7 @@ void UDFCodegen::Visit(ast::udf::CallExprAST *ast) { aux_decls_.push_back(decl); } ident_expr = codegen_->MakeFreshIdentifier(file->Declarations().back()->Name().GetString()); - str_to_ident_[file->Declarations().back()->Name().GetString()] = ident_expr; + SymbolTable()[file->Declarations().back()->Name().GetString()] = ident_expr; } fb_->Append(codegen_->MakeStmt(codegen_->Call(ident_expr, args_ast_region_vec))); } @@ -127,13 +127,14 @@ void UDFCodegen::Visit(ast::udf::DeclStmtAST *ast) { if (ast->Name() == "*internal*") { return; } - execution::ast::Identifier ident = codegen_->MakeFreshIdentifier(ast->Name()); - str_to_ident_.emplace(ast->Name(), ident); + const execution::ast::Identifier ident = codegen_->MakeFreshIdentifier(ast->Name()); + SymbolTable()[ast->Name()] = ident; + auto prev_type = current_type_; execution::ast::Expr *tpl_type = nullptr; if (ast->Type() == type::TypeId::INVALID) { // record type - execution::util::RegionVector fields(codegen_->GetAstContext()->GetRegion()); + execution::util::RegionVector fields{codegen_->GetAstContext()->GetRegion()}; for (const auto &p : udf_ast_context_->GetRecordType(ast->Name())) { fields.push_back(codegen_->MakeField(codegen_->MakeIdentifier(p.first), codegen_->TplType(execution::sql::GetTypeId(p.second)))); @@ -156,15 +157,14 @@ void UDFCodegen::Visit(ast::udf::DeclStmtAST *ast) { void UDFCodegen::Visit(ast::udf::FunctionAST *ast) { for (size_t i = 0; i < ast->ParameterTypes().size(); i++) { - // auto param_type = codegen_->TplType(ast->param_types_[i]); - str_to_ident_.emplace(ast->ParameterNames().at(i), codegen_->MakeFreshIdentifier("udf")); + SymbolTable()[ast->ParameterNames().at(i)] = codegen_->MakeFreshIdentifier("udf"); } ast->Body()->Accept(this); } void UDFCodegen::Visit(ast::udf::VariableExprAST *ast) { - auto it = str_to_ident_.find(ast->Name()); - NOISEPAGE_ASSERT(it != str_to_ident_.end(), "variable not declared"); + auto it = SymbolTable().find(ast->Name()); + NOISEPAGE_ASSERT(it != SymbolTable().end(), "Variable not declared"); dst_ = codegen_->MakeExpr(it->second); } @@ -210,8 +210,8 @@ void UDFCodegen::Visit(ast::udf::AssignStmtAST *ast) { reinterpret_cast(ast->Source())->Accept(this); auto rhs_expr = dst_; - auto it = str_to_ident_.find(ast->Destination()->Name()); - NOISEPAGE_ASSERT(it != str_to_ident_.end(), "Variable not found"); + auto it = SymbolTable().find(ast->Destination()->Name()); + NOISEPAGE_ASSERT(it != SymbolTable().end(), "Variable not found"); auto left_codegen_ident = it->second; auto *left_expr = codegen_->MakeExpr(left_codegen_ident); @@ -264,7 +264,7 @@ void UDFCodegen::Visit(ast::udf::BinaryExprAST *ast) { op_token = execution::parsing::Token::Type::EQUAL_EQUAL; break; default: - // TODO(tanujnay112): figure out concatenation operation from expressions? + // TODO(Kyle): Figure out concatenation operation from expressions? UNREACHABLE("Unsupported expression"); } ast->Left()->Accept(this); @@ -341,9 +341,10 @@ void UDFCodegen::Visit(ast::udf::ForStmtAST *ast) { codegen_->PointerType(codegen_->BuiltinType(execution::ast::BuiltinType::Kind::ExecutionContext)))); std::size_t i{0}; for (const auto &var : ast->Variables()) { - var_idents.push_back(str_to_ident_.find(var)->second); + var_idents.push_back(SymbolTable().find(var)->second); auto var_ident = var_idents.back(); - NOISEPAGE_ASSERT(plan->GetOutputSchema()->GetColumns().size() == 1, "Can't support non scalars yet!"); + NOISEPAGE_ASSERT(plan->GetOutputSchema()->GetColumns().size() == 1, + "UDF support for non-scalars is not implemented"); auto type = codegen_->TplType(execution::sql::GetTypeId(plan->GetOutputSchema()->GetColumn(i).GetType())); fb_->Append(codegen_->Assign(codegen_->MakeExpr(var_ident), @@ -368,7 +369,7 @@ void UDFCodegen::Visit(ast::udf::ForStmtAST *ast) { } execution::util::RegionVector captures{codegen_->GetAstContext()->GetRegion()}; - for (const auto &[name, identifier] : str_to_ident_) { + for (const auto &[name, identifier] : SymbolTable()) { // TODO(Kyle): Why do we skip this particular identifier? if (name == "executionCtx") { continue; @@ -442,7 +443,7 @@ void UDFCodegen::Visit(ast::udf::ForStmtAST *ast) { default: UNREACHABLE("Unsupported parameter type"); } - fb_->Append(codegen_->CallBuiltin(builtin, {exec_ctx, codegen_->MakeExpr(str_to_ident_[entry->first])})); + fb_->Append(codegen_->CallBuiltin(builtin, {exec_ctx, codegen_->MakeExpr(SymbolTable().at(entry->first))})); } fb_->Append(codegen_->Assign( @@ -503,7 +504,7 @@ void UDFCodegen::Visit(ast::udf::SQLStmtAST *ast) { std::vector assignees{}; execution::util::RegionVector captures{codegen_->GetAstContext()->GetRegion()}; for (auto &col : cols) { - execution::ast::Expr *capture_var = codegen_->MakeExpr(str_to_ident_.find(ast->Name())->second); + execution::ast::Expr *capture_var = codegen_->MakeExpr(SymbolTable().find(ast->Name())->second); type::TypeId udf_type{}; udf_ast_context_->GetVariableType(ast->Name(), &udf_type); if (udf_type == type::TypeId::INVALID) { @@ -573,11 +574,11 @@ void UDFCodegen::Visit(ast::udf::SQLStmtAST *ast) { auto &fields = udf_ast_context_->GetRecordType(entry->second.first); auto it = std::find_if(fields.begin(), fields.end(), [=](auto p) { return p.first == entry->first; }); type = it->second; - expr = codegen_->AccessStructMember(codegen_->MakeExpr(str_to_ident_[entry->second.first]), + expr = codegen_->AccessStructMember(codegen_->MakeExpr(SymbolTable().at(entry->second.first)), codegen_->MakeIdentifier(entry->first)); } else { udf_ast_context_->GetVariableType(entry->first, &type); - expr = codegen_->MakeExpr(str_to_ident_[entry->first]); + expr = codegen_->MakeExpr(SymbolTable().at(entry->first)); } execution::ast::Builtin builtin{}; @@ -619,7 +620,7 @@ void UDFCodegen::Visit(ast::udf::SQLStmtAST *ast) { codegen_->AccessStructMember(codegen_->MakeExpr(query_state), codegen_->MakeIdentifier("execCtx")), exec_ctx)); for (auto &col : cols) { - execution::ast::Expr *capture_var = codegen_->MakeExpr(str_to_ident_.find(ast->Name())->second); + execution::ast::Expr *capture_var = codegen_->MakeExpr(SymbolTable().find(ast->Name())->second); auto *lhs = capture_var; if (cols.size() > 1) { // Record struct type diff --git a/src/include/execution/compiler/udf/udf_codegen.h b/src/include/execution/compiler/udf/udf_codegen.h index d2a8da5d2e..50ea90441c 100644 --- a/src/include/execution/compiler/udf/udf_codegen.h +++ b/src/include/execution/compiler/udf/udf_codegen.h @@ -13,12 +13,10 @@ namespace noisepage::catalog { class CatalogAccessor; } -namespace noisepage { -namespace execution { +namespace noisepage::execution { // Forward declarations -namespace ast { -namespace udf { +namespace ast::udf { class AbstractAST; class StmtAST; class ExprAST; @@ -38,11 +36,9 @@ class FunctionAST; class IsNullExprAST; class DynamicSQLStmtAST; class ForStmtAST; -} // namespace udf -} // namespace ast +} // namespace ast::udf -namespace compiler { -namespace udf { +namespace compiler::udf { /** * The UDFCodegen class implements a visitor for UDF AST nodes @@ -206,6 +202,13 @@ class UDFCodegen : ast::udf::ASTNodeVisitor { */ catalog::type_oid_t GetCatalogTypeOidFromSQLType(execution::ast::BuiltinType::Kind type); + /** @return A mutable reference to the symbol table */ + std::unordered_map &SymbolTable() { return symbol_table_; } + + /** @return An immutable reference to the symbol table */ + const std::unordered_map &SymbolTable() const { return symbol_table_; } + + private: /** The catalog access used during code generation */ catalog::CatalogAccessor *accessor_; @@ -234,10 +237,8 @@ class UDFCodegen : ast::udf::ASTNodeVisitor { execution::ast::Expr *dst_; /** Map from human-readable string identifier to internal identifier */ - std::unordered_map str_to_ident_; + std::unordered_map symbol_table_; }; -} // namespace udf -} // namespace compiler -} // namespace execution -} // namespace noisepage +} // namespace compiler::udf +} // namespace noisepage::execution diff --git a/src/include/execution/functions/function_context.h b/src/include/execution/functions/function_context.h index 171f5c829e..9c09ad3e6c 100644 --- a/src/include/execution/functions/function_context.h +++ b/src/include/execution/functions/function_context.h @@ -21,7 +21,7 @@ namespace noisepage::execution::functions { class FunctionContext { public: /** - * Creates a FunctionContext object + * Construct a FunctionContext instance. * @param func_name Name of function * @param func_ret_type Return type of function * @param arg_types Vector of argument types @@ -34,7 +34,7 @@ class FunctionContext { is_exec_ctx_required_{false} {} /** - * Creates a FunctionContext object for a builtin function + * Construct a FunctionContext instance for a builtin function. * @param func_name Name of function * @param func_ret_type Return type of function * @param arg_types Vector of argument types @@ -51,7 +51,7 @@ class FunctionContext { is_exec_ctx_required_{is_exec_ctx_required} {} /** - * Creates a FunctionContext object for a non-builtin function. + * Construct a FunctionContext instance for a non-builtin function. * @param func_name Name of function= * @param func_ret_type Return type of function * @param arg_types Vector of argument types @@ -73,14 +73,10 @@ class FunctionContext { ast_context_{std::move(ast_context)}, file_{file} {} - /** - * @return The name of the function represented by this context object. - */ + /** @return The name of the function represented by this context object. */ const std::string &GetFunctionName() const { return func_name_; } - /** - * @return The vector of type arguments of the function represented by this context object. - */ + /** @return The vector of type arguments of the function represented by this context object. */ const std::vector &GetFunctionArgTypes() const { return arg_types_; } /** @@ -89,47 +85,35 @@ class FunctionContext { */ type::TypeId GetFunctionReturnType() const { return func_ret_type_; } - /** - * @return `true` if this represents a builtin function, `false` otherwise. - */ + /** @return `true` if this represents a builtin function, `false` otherwise. */ bool IsBuiltin() const { return is_builtin_; } - /** - * @return The builtin function this procedure represents. - */ + /** @return The builtin function this procedure represents. */ ast::Builtin GetBuiltin() const { NOISEPAGE_ASSERT(IsBuiltin(), "Getting a builtin from a non-builtin function"); return builtin_; } - /** - * @return `true` if this function requires an execution context, `false` otherwise. - */ + /** @return `true` if this function requires an execution context, `false` otherwise. */ bool IsExecCtxRequired() const { - NOISEPAGE_ASSERT(IsBuiltin(), "IsExecCtxRequired is only valid or a builtin function"); + // TODO(Kyle): Is it valid to query execution context requirement for non-builtins? return is_exec_ctx_required_; } - /** - * @return The main functiondecl of this UDF. - */ + /** @return The main function declaration of this UDF. */ common::ManagedPointer GetMainFunctionDecl() const { NOISEPAGE_ASSERT(!IsBuiltin(), "Getting a non-builtin from a builtin function"); return common::ManagedPointer( reinterpret_cast(file_->Declarations().back())); } - /** - * @return The file with the functiondecl and supporting decls. - */ + /** @return The file with the function declaration and supporting declarations. */ ast::File *GetFile() const { NOISEPAGE_ASSERT(!IsBuiltin(), "Getting a non-builtin from a builtin function"); return file_; } - /** - * @return The AST context for this procedure. - */ + /** @return The AST context for this procedure. */ ast::Context *GetASTContext() const { NOISEPAGE_ASSERT(!IsBuiltin(), "No AST Context associated with builtin function"); return ast_context_.get(); diff --git a/src/network/noisepage_server.cpp b/src/network/noisepage_server.cpp index de83c1dbd1..14b9254b1d 100644 --- a/src/network/noisepage_server.cpp +++ b/src/network/noisepage_server.cpp @@ -140,7 +140,7 @@ void TerrierServer::RunServer() { RegisterSocket(); // Register the Unix domain socket. - RegisterSocket(); + // RegisterSocket(); // Register the ConnectionDispatcherTask. This handles connections to the sockets created above. dispatcher_task_ = thread_registry_->RegisterDedicatedThread( From b2cd8e1736cec12e35af2a5a5c131e0c907afd79 Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Fri, 16 Jul 2021 16:09:39 -0400 Subject: [PATCH 068/139] refactor udf code generator and udf ast context --- src/binder/bind_node_visitor.cpp | 2 +- src/execution/compiler/udf/udf_codegen.cpp | 63 +++++++++++-------- src/execution/sql/ddl_executors.cpp | 8 +-- src/include/binder/bind_node_visitor.h | 4 +- .../execution/ast/udf/udf_ast_context.h | 24 +++---- .../execution/compiler/udf/udf_codegen.h | 30 ++++++--- src/include/parser/udf/udf_parser.h | 12 ++-- src/parser/udf/udf_parser.cpp | 2 +- 8 files changed, 80 insertions(+), 65 deletions(-) diff --git a/src/binder/bind_node_visitor.cpp b/src/binder/bind_node_visitor.cpp index 603f45c3f8..8d8bec949b 100644 --- a/src/binder/bind_node_visitor.cpp +++ b/src/binder/bind_node_visitor.cpp @@ -66,7 +66,7 @@ BindNodeVisitor::~BindNodeVisitor() = default; std::unordered_map> BindNodeVisitor::BindAndGetUDFParams( common::ManagedPointer parse_result, - common::ManagedPointer udf_ast_context) { + common::ManagedPointer udf_ast_context) { NOISEPAGE_ASSERT(parse_result != nullptr, "We shouldn't be trying to bind something without a ParseResult."); sherpa_ = std::make_unique(parse_result, nullptr, nullptr); NOISEPAGE_ASSERT(sherpa_->GetParseResult()->GetStatements().size() == 1, "Binder can only bind one at a time."); diff --git a/src/execution/compiler/udf/udf_codegen.cpp b/src/execution/compiler/udf/udf_codegen.cpp index b48d28cec2..7151f76ec1 100644 --- a/src/execution/compiler/udf/udf_codegen.cpp +++ b/src/execution/compiler/udf/udf_codegen.cpp @@ -29,8 +29,8 @@ namespace noisepage::execution::compiler::udf { -UDFCodegen::UDFCodegen(catalog::CatalogAccessor *accessor, FunctionBuilder *fb, - ast::udf::UDFASTContext *udf_ast_context, CodeGen *codegen, catalog::db_oid_t db_oid) +UdfCodegen::UdfCodegen(catalog::CatalogAccessor *accessor, FunctionBuilder *fb, + ast::udf::UdfAstContext *udf_ast_context, CodeGen *codegen, catalog::db_oid_t db_oid) : accessor_{accessor}, fb_{fb}, udf_ast_context_{udf_ast_context}, @@ -46,19 +46,28 @@ UDFCodegen::UDFCodegen(catalog::CatalogAccessor *accessor, FunctionBuilder *fb, } // Static -const char *UDFCodegen::GetReturnParamString() { return "return_val"; } +execution::ast::File *UdfCodegen::Run(catalog::CatalogAccessor *accessor, FunctionBuilder *function_builder, + ast::udf::UdfAstContext *ast_context, CodeGen *codegen, catalog::db_oid_t db_oid, + ast::udf::FunctionAST *root) { + UdfCodegen generator{accessor, function_builder, ast_context, codegen, db_oid}; + generator.GenerateUDF(root->Body()); + return generator.Finish(); +} + +// Static +const char *UdfCodegen::GetReturnParamString() { return "return_val"; } -void UDFCodegen::GenerateUDF(ast::udf::AbstractAST *ast) { ast->Accept(this); } +void UdfCodegen::GenerateUDF(ast::udf::AbstractAST *ast) { ast->Accept(this); } -void UDFCodegen::Visit(ast::udf::AbstractAST *ast) { - throw NOT_IMPLEMENTED_EXCEPTION("UDFCodegen::Visit(AbstractAST*)"); +void UdfCodegen::Visit(ast::udf::AbstractAST *ast) { + throw NOT_IMPLEMENTED_EXCEPTION("UdfCodegen::Visit(AbstractAST*)"); } -void UDFCodegen::Visit(ast::udf::DynamicSQLStmtAST *ast) { - throw NOT_IMPLEMENTED_EXCEPTION("UDFCodegen::Visit(DynamicSQLStmtAST*)"); +void UdfCodegen::Visit(ast::udf::DynamicSQLStmtAST *ast) { + throw NOT_IMPLEMENTED_EXCEPTION("UdfCodegen::Visit(DynamicSQLStmtAST*)"); } -catalog::type_oid_t UDFCodegen::GetCatalogTypeOidFromSQLType(execution::ast::BuiltinType::Kind type) { +catalog::type_oid_t UdfCodegen::GetCatalogTypeOidFromSQLType(execution::ast::BuiltinType::Kind type) { switch (type) { case execution::ast::BuiltinType::Kind::Integer: { return accessor_->GetTypeOidFromTypeId(type::TypeId::INTEGER); @@ -72,7 +81,7 @@ catalog::type_oid_t UDFCodegen::GetCatalogTypeOidFromSQLType(execution::ast::Bui } } -execution::ast::File *UDFCodegen::Finish() { +execution::ast::File *UdfCodegen::Finish() { auto fn = fb_->Finish(); execution::util::RegionVector decls{{fn}, codegen_->GetAstContext()->GetRegion()}; decls.insert(decls.begin(), aux_decls_.begin(), aux_decls_.end()); @@ -80,7 +89,7 @@ execution::ast::File *UDFCodegen::Finish() { return file; } -void UDFCodegen::Visit(ast::udf::CallExprAST *ast) { +void UdfCodegen::Visit(ast::udf::CallExprAST *ast) { std::vector args_ast{}; std::vector args_ast_region_vec{}; std::vector arg_types{}; @@ -119,11 +128,11 @@ void UDFCodegen::Visit(ast::udf::CallExprAST *ast) { } } -void UDFCodegen::Visit(ast::udf::StmtAST *ast) { UNREACHABLE("Not implemented"); } +void UdfCodegen::Visit(ast::udf::StmtAST *ast) { UNREACHABLE("Not implemented"); } -void UDFCodegen::Visit(ast::udf::ExprAST *ast) { UNREACHABLE("Not implemented"); } +void UdfCodegen::Visit(ast::udf::ExprAST *ast) { UNREACHABLE("Not implemented"); } -void UDFCodegen::Visit(ast::udf::DeclStmtAST *ast) { +void UdfCodegen::Visit(ast::udf::DeclStmtAST *ast) { if (ast->Name() == "*internal*") { return; } @@ -155,20 +164,20 @@ void UDFCodegen::Visit(ast::udf::DeclStmtAST *ast) { current_type_ = prev_type; } -void UDFCodegen::Visit(ast::udf::FunctionAST *ast) { +void UdfCodegen::Visit(ast::udf::FunctionAST *ast) { for (size_t i = 0; i < ast->ParameterTypes().size(); i++) { SymbolTable()[ast->ParameterNames().at(i)] = codegen_->MakeFreshIdentifier("udf"); } ast->Body()->Accept(this); } -void UDFCodegen::Visit(ast::udf::VariableExprAST *ast) { +void UdfCodegen::Visit(ast::udf::VariableExprAST *ast) { auto it = SymbolTable().find(ast->Name()); NOISEPAGE_ASSERT(it != SymbolTable().end(), "Variable not declared"); dst_ = codegen_->MakeExpr(it->second); } -void UDFCodegen::Visit(ast::udf::ValueExprAST *ast) { +void UdfCodegen::Visit(ast::udf::ValueExprAST *ast) { auto val = common::ManagedPointer(ast->Value()).CastManagedPointerTo(); if (val->IsNull()) { dst_ = codegen_->ConstNull(current_type_); @@ -202,7 +211,7 @@ void UDFCodegen::Visit(ast::udf::ValueExprAST *ast) { } } -void UDFCodegen::Visit(ast::udf::AssignStmtAST *ast) { +void UdfCodegen::Visit(ast::udf::AssignStmtAST *ast) { type::TypeId left_type = type::TypeId::INVALID; udf_ast_context_->GetVariableType(ast->Destination()->Name(), &left_type); current_type_ = left_type; @@ -218,7 +227,7 @@ void UDFCodegen::Visit(ast::udf::AssignStmtAST *ast) { fb_->Append(codegen_->Assign(left_expr, rhs_expr)); } -void UDFCodegen::Visit(ast::udf::BinaryExprAST *ast) { +void UdfCodegen::Visit(ast::udf::BinaryExprAST *ast) { execution::parsing::Token::Type op_token; bool compare = false; switch (ast->Op()) { @@ -279,7 +288,7 @@ void UDFCodegen::Visit(ast::udf::BinaryExprAST *ast) { } } -void UDFCodegen::Visit(ast::udf::IfStmtAST *ast) { +void UdfCodegen::Visit(ast::udf::IfStmtAST *ast) { ast->Condition()->Accept(this); auto cond = dst_; @@ -292,7 +301,7 @@ void UDFCodegen::Visit(ast::udf::IfStmtAST *ast) { branch.EndIf(); } -void UDFCodegen::Visit(ast::udf::IsNullExprAST *ast) { +void UdfCodegen::Visit(ast::udf::IsNullExprAST *ast) { ast->Child()->Accept(this); auto chld = dst_; dst_ = codegen_->CallBuiltin(execution::ast::Builtin::IsValNull, {chld}); @@ -301,13 +310,13 @@ void UDFCodegen::Visit(ast::udf::IsNullExprAST *ast) { } } -void UDFCodegen::Visit(ast::udf::SeqStmtAST *ast) { +void UdfCodegen::Visit(ast::udf::SeqStmtAST *ast) { for (auto &stmt : ast->Statements()) { stmt->Accept(this); } } -void UDFCodegen::Visit(ast::udf::WhileStmtAST *ast) { +void UdfCodegen::Visit(ast::udf::WhileStmtAST *ast) { ast->Condition()->Accept(this); auto cond = dst_; Loop loop(fb_, cond); @@ -315,7 +324,7 @@ void UDFCodegen::Visit(ast::udf::WhileStmtAST *ast) { loop.EndLoop(); } -void UDFCodegen::Visit(ast::udf::ForStmtAST *ast) { +void UdfCodegen::Visit(ast::udf::ForStmtAST *ast) { // Once we encounter a For-statement we know we need an execution context needs_exec_ctx_ = true; @@ -462,13 +471,13 @@ void UDFCodegen::Visit(ast::udf::ForStmtAST *ast) { fb_->Append(codegen_->CallBuiltin(execution::ast::Builtin::FinishNewParams, {exec_ctx})); } -void UDFCodegen::Visit(ast::udf::RetStmtAST *ast) { +void UdfCodegen::Visit(ast::udf::RetStmtAST *ast) { ast->Return()->Accept(reinterpret_cast(this)); auto ret_expr = dst_; fb_->Append(codegen_->Return(ret_expr)); } -void UDFCodegen::Visit(ast::udf::SQLStmtAST *ast) { +void UdfCodegen::Visit(ast::udf::SQLStmtAST *ast) { // As soon as we encounter an embedded SQL statement, // we know we need an execution context needs_exec_ctx_ = true; @@ -642,7 +651,7 @@ void UDFCodegen::Visit(ast::udf::SQLStmtAST *ast) { fb_->Append(codegen_->CallBuiltin(execution::ast::Builtin::FinishNewParams, {exec_ctx})); } -void UDFCodegen::Visit(ast::udf::MemberExprAST *ast) { +void UdfCodegen::Visit(ast::udf::MemberExprAST *ast) { ast->Object()->Accept(reinterpret_cast(this)); auto object = dst_; dst_ = codegen_->AccessStructMember(object, codegen_->MakeIdentifier(ast->FieldName())); diff --git a/src/execution/sql/ddl_executors.cpp b/src/execution/sql/ddl_executors.cpp index d790d4b0d6..858246d9b5 100644 --- a/src/execution/sql/ddl_executors.cpp +++ b/src/execution/sql/ddl_executors.cpp @@ -68,7 +68,7 @@ bool DDLExecutors::CreateFunctionExecutor(const common::ManagedPointerGetDatabaseOid()}; std::unique_ptr ast{}; @@ -103,9 +103,9 @@ bool DDLExecutors::CreateFunctionExecutor(const common::ManagedPointerGetReturnType())))}; - compiler::udf::UDFCodegen udf_codegen{accessor.Get(), &fb, &udf_ast_context, &codegen, node->GetDatabaseOid()}; - udf_codegen.GenerateUDF(ast->Body()); - auto *file = udf_codegen.Finish(); + // Run UDF code generation + auto *file = compiler::udf::UdfCodegen::Run(accessor.Get(), &fb, &udf_ast_context, &codegen, node->GetDatabaseOid(), + ast.get()); { sema::Sema type_check{codegen.GetAstContext().Get()}; diff --git a/src/include/binder/bind_node_visitor.h b/src/include/binder/bind_node_visitor.h index 0973a3dd51..ce76085dff 100644 --- a/src/include/binder/bind_node_visitor.h +++ b/src/include/binder/bind_node_visitor.h @@ -61,7 +61,7 @@ class BindNodeVisitor final : public SqlNodeVisitor { */ std::unordered_map> BindAndGetUDFParams( common::ManagedPointer parse_result, - common::ManagedPointer udf_ast_context); + common::ManagedPointer udf_ast_context); /** * Perform binding on the passed in tree. Bind the relation names to oids @@ -118,7 +118,7 @@ class BindNodeVisitor final : public SqlNodeVisitor { common::ManagedPointer context_ = nullptr; /** Context for UDF AST */ - common::ManagedPointer udf_ast_context_{}; + common::ManagedPointer udf_ast_context_{}; /** Parameters for UDF */ std::unordered_map> udf_params_; diff --git a/src/include/execution/ast/udf/udf_ast_context.h b/src/include/execution/ast/udf/udf_ast_context.h index 2b9cc3c547..64f7cde6f7 100644 --- a/src/include/execution/ast/udf/udf_ast_context.h +++ b/src/include/execution/ast/udf/udf_ast_context.h @@ -7,21 +7,18 @@ #include "type/type_id.h" -namespace noisepage { -namespace execution { -namespace ast { -namespace udf { +namespace noisepage::execution::ast::udf { /** - * The UDFASTContext class maintains state that is utilized + * The UdfAstContext class maintains state that is utilized * throughout construction of the UDF abstract syntax tree. */ -class UDFASTContext { +class UdfAstContext { public: /** - * Construct a new UDFASTContext. + * Construct a new AstContext instance. */ - UDFASTContext() = default; + UdfAstContext() = default; /** * Set the type of the variabel identifed by `name`. @@ -84,15 +81,12 @@ class UDFASTContext { } private: - // The symbol table for the UDF. + /** The symbol table for the UDF. */ std::unordered_map symbol_table_; - // Collection of local variable names for the UDF. + /** Collection of local variable names for the UDF. */ std::vector local_variables_; - // Collection of record types for the UDF. + /** Collection of record types for the UDF. */ std::unordered_map>> record_types_; }; -} // namespace udf -} // namespace ast -} // namespace execution -} // namespace noisepage +} // namespace noisepage::execution::ast::udf diff --git a/src/include/execution/compiler/udf/udf_codegen.h b/src/include/execution/compiler/udf/udf_codegen.h index 50ea90441c..9579381101 100644 --- a/src/include/execution/compiler/udf/udf_codegen.h +++ b/src/include/execution/compiler/udf/udf_codegen.h @@ -41,27 +41,41 @@ class ForStmtAST; namespace compiler::udf { /** - * The UDFCodegen class implements a visitor for UDF AST nodes - * and encapsulates all of the logic required to generate code - * from the UDF abstract syntax tree. + * The UdfCodegen class implements a visitor for UDF AST + * nodes and encapsulates all of the logic required to generate + * code from the UDF abstract syntax tree. */ -class UDFCodegen : ast::udf::ASTNodeVisitor { +class UdfCodegen : ast::udf::ASTNodeVisitor { public: /** - * Construct a new UDFCodegen instance. + * Construct a new UdfCodegen instance. * @param accessor The catalog accessor used in code generation * @param fb The function builder instance used for the UDF * @param udf_ast_context The AST context for the UDF * @param codegen The codegen instance * @param db_oid The OID for the relevant database */ - UDFCodegen(catalog::CatalogAccessor *accessor, FunctionBuilder *fb, ast::udf::UDFASTContext *udf_ast_context, + UdfCodegen(catalog::CatalogAccessor *accessor, FunctionBuilder *fb, ast::udf::UdfAstContext *udf_ast_context, CodeGen *codegen, catalog::db_oid_t db_oid); /** * Destroy the UDF code generation context. */ - ~UDFCodegen() override = default; + ~UdfCodegen() override = default; + + /** + * Run UDF code generation. + * @param accessor The catalog accessor + * @param function_builder The function builder to use during code generation + * @param ast_context The UDF AST context + * @param codegen The code generation instance + * @param db_oid The database OID + * @param root The root of the UDF AST for which code is generated + * @return The file containing the generated code + */ + static execution::ast::File *Run(catalog::CatalogAccessor *accessor, FunctionBuilder *function_builder, + ast::udf::UdfAstContext *ast_context, CodeGen *codegen, catalog::db_oid_t db_oid, + ast::udf::FunctionAST *root); /** * Generate a UDF from the given abstract syntax tree. @@ -216,7 +230,7 @@ class UDFCodegen : ast::udf::ASTNodeVisitor { FunctionBuilder *fb_; /** The AST context for the UDF */ - ast::udf::UDFASTContext *udf_ast_context_; + ast::udf::UdfAstContext *udf_ast_context_; /** The code generation instance */ CodeGen *codegen_; diff --git a/src/include/parser/udf/udf_parser.h b/src/include/parser/udf/udf_parser.h index a6056d1e2c..76b13d220e 100644 --- a/src/include/parser/udf/udf_parser.h +++ b/src/include/parser/udf/udf_parser.h @@ -18,8 +18,7 @@ namespace execution::ast::udf { class FunctionAST; } -namespace parser { -namespace udf { +namespace parser::udf { /** * Namespace alias to make below more manageable. @@ -37,7 +36,7 @@ class PLpgSQLParser { * @param accessor The accessor to use during parsing * @param db_oid The database OID */ - PLpgSQLParser(common::ManagedPointer udf_ast_context, + PLpgSQLParser(common::ManagedPointer udf_ast_context, const common::ManagedPointer accessor, catalog::db_oid_t db_oid) : udf_ast_context_(udf_ast_context), accessor_(accessor), db_oid_(db_oid) {} @@ -51,7 +50,7 @@ class PLpgSQLParser { */ std::unique_ptr Parse(std::vector &¶m_names, std::vector &¶m_types, const std::string &func_body, - common::ManagedPointer ast_context); + common::ManagedPointer ast_context); private: /** @@ -126,7 +125,7 @@ class PLpgSQLParser { private: /** The UDF AST context */ - common::ManagedPointer udf_ast_context_; + common::ManagedPointer udf_ast_context_; /** The catalog accessor */ const common::ManagedPointer accessor_; @@ -138,6 +137,5 @@ class PLpgSQLParser { std::unordered_map symbol_table_; }; -} // namespace udf -} // namespace parser +} // namespace parser::udf } // namespace noisepage diff --git a/src/parser/udf/udf_parser.cpp b/src/parser/udf/udf_parser.cpp index 631f012dc2..bb91a374e0 100644 --- a/src/parser/udf/udf_parser.cpp +++ b/src/parser/udf/udf_parser.cpp @@ -46,7 +46,7 @@ static constexpr const char K_PLPGSQL_STMT_DYNEXECUTE[] = "PLpgSQL_stmt_dynexecu std::unique_ptr PLpgSQLParser::Parse( std::vector &¶m_names, std::vector &¶m_types, const std::string &func_body, - common::ManagedPointer ast_context) { + common::ManagedPointer ast_context) { auto result = pg_query_parse_plpgsql(func_body.c_str()); if (result.error != nullptr) { pg_query_free_plpgsql_parse_result(result); From a1cb9e5f2e9ba10b011fc02a0418c91c82d742c9 Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Fri, 16 Jul 2021 16:30:57 -0400 Subject: [PATCH 069/139] refactor udf parser to plpgsql parser --- src/execution/sql/ddl_executors.cpp | 7 ++- .../udf/{udf_parser.h => plpgsql_parser.h} | 46 ++++++++----------- .../{udf_parser.cpp => plpgsql_parser.cpp} | 2 +- 3 files changed, 25 insertions(+), 30 deletions(-) rename src/include/parser/udf/{udf_parser.h => plpgsql_parser.h} (66%) rename src/parser/udf/{udf_parser.cpp => plpgsql_parser.cpp} (99%) diff --git a/src/execution/sql/ddl_executors.cpp b/src/execution/sql/ddl_executors.cpp index 858246d9b5..d2d77513da 100644 --- a/src/execution/sql/ddl_executors.cpp +++ b/src/execution/sql/ddl_executors.cpp @@ -17,7 +17,7 @@ #include "execution/sema/sema.h" #include "loggers/execution_logger.h" #include "parser/expression/column_value_expression.h" -#include "parser/udf/udf_parser.h" +#include "parser/udf/plpgsql_parser.h" #include "planner/plannodes/create_database_plan_node.h" #include "planner/plannodes/create_function_plan_node.h" #include "planner/plannodes/create_index_plan_node.h" @@ -69,8 +69,11 @@ bool DDLExecutors::CreateFunctionExecutor(const common::ManagedPointerGetDatabaseOid()}; + // TODO(Kyle): Revisit this after clearing up what the + // preferred way to report errors is in the system, both + // within components and between components... + parser::udf::PLpgSQLParser udf_parser{(common::ManagedPointer(&udf_ast_context)), accessor, node->GetDatabaseOid()}; std::unique_ptr ast{}; try { ast = udf_parser.Parse(node->GetFunctionParameterNames(), std::move(param_type_ids), body, diff --git a/src/include/parser/udf/udf_parser.h b/src/include/parser/udf/plpgsql_parser.h similarity index 66% rename from src/include/parser/udf/udf_parser.h rename to src/include/parser/udf/plpgsql_parser.h index 76b13d220e..e82dc28013 100644 --- a/src/include/parser/udf/udf_parser.h +++ b/src/include/parser/udf/plpgsql_parser.h @@ -12,18 +12,11 @@ #include "parser/expression_util.h" #include "parser/postgresparser.h" -namespace noisepage { - -namespace execution::ast::udf { +namespace noisepage::execution::ast::udf { class FunctionAST; -} - -namespace parser::udf { +} // namespace noisepage::execution::ast::udf -/** - * Namespace alias to make below more manageable. - */ -namespace udfexec = execution::ast::udf; +namespace noisepage::parser::udf { /** * The PLpgSQLParser class parses source PL/pgSQL to an abstract syntax tree. @@ -36,7 +29,7 @@ class PLpgSQLParser { * @param accessor The accessor to use during parsing * @param db_oid The database OID */ - PLpgSQLParser(common::ManagedPointer udf_ast_context, + PLpgSQLParser(common::ManagedPointer udf_ast_context, const common::ManagedPointer accessor, catalog::db_oid_t db_oid) : udf_ast_context_(udf_ast_context), accessor_(accessor), db_oid_(db_oid) {} @@ -48,9 +41,9 @@ class PLpgSQLParser { * @param ast_context The AST context to use during parsing * @return The abstract syntax tree for the source function */ - std::unique_ptr Parse(std::vector &¶m_names, - std::vector &¶m_types, const std::string &func_body, - common::ManagedPointer ast_context); + std::unique_ptr Parse( + std::vector &¶m_names, std::vector &¶m_types, const std::string &func_body, + common::ManagedPointer ast_context); private: /** @@ -58,74 +51,74 @@ class PLpgSQLParser { * @param block The input JSON object * @return The AST for the block */ - std::unique_ptr ParseBlock(const nlohmann::json &block); + std::unique_ptr ParseBlock(const nlohmann::json &block); /** * Parse a function statement. * @param block The input JSON object * @return The AST for the function */ - std::unique_ptr ParseFunction(const nlohmann::json &function); + std::unique_ptr ParseFunction(const nlohmann::json &function); /** * Parse a declaration statement. * @param decl The input JSON object * @return The AST for the declaration */ - std::unique_ptr ParseDecl(const nlohmann::json &decl); + std::unique_ptr ParseDecl(const nlohmann::json &decl); /** * Parse an if-statement. * @param block The input JSON object * @return The AST for the if-statement */ - std::unique_ptr ParseIf(const nlohmann::json &branch); + std::unique_ptr ParseIf(const nlohmann::json &branch); /** * Parse a while-statement. * @param block The input JSON object * @return The AST for the while-statement */ - std::unique_ptr ParseWhile(const nlohmann::json &loop); + std::unique_ptr ParseWhile(const nlohmann::json &loop); /** * Parse a for-statement. * @param block The input JSON object * @return The AST for the for-statement */ - std::unique_ptr ParseFor(const nlohmann::json &loop); + std::unique_ptr ParseFor(const nlohmann::json &loop); /** * Parse a SQL statement. * @param sql_stmt The input JSON object * @return The AST for the SQL statement */ - std::unique_ptr ParseSQL(const nlohmann::json &sql_stmt); + std::unique_ptr ParseSQL(const nlohmann::json &sql_stmt); /** * Parse a dynamic SQL statement. * @param block The input JSON object * @return The AST for the dynamic SQL statement */ - std::unique_ptr ParseDynamicSQL(const nlohmann::json &sql_stmt); + std::unique_ptr ParseDynamicSQL(const nlohmann::json &sql_stmt); /** * Parse a SQL expression. * @param sql The SQL expression string * @return The AST for the SQL expression */ - std::unique_ptr ParseExprSQL(const std::string &sql); + std::unique_ptr ParseExprSQL(const std::string &sql); /** * Parse an expression. * @param expr The expression * @return The AST for the expression */ - std::unique_ptr ParseExpr(common::ManagedPointer expr); + std::unique_ptr ParseExpr(common::ManagedPointer expr); private: /** The UDF AST context */ - common::ManagedPointer udf_ast_context_; + common::ManagedPointer udf_ast_context_; /** The catalog accessor */ const common::ManagedPointer accessor_; @@ -137,5 +130,4 @@ class PLpgSQLParser { std::unordered_map symbol_table_; }; -} // namespace parser::udf -} // namespace noisepage +} // namespace noisepage::parser::udf diff --git a/src/parser/udf/udf_parser.cpp b/src/parser/udf/plpgsql_parser.cpp similarity index 99% rename from src/parser/udf/udf_parser.cpp rename to src/parser/udf/plpgsql_parser.cpp index bb91a374e0..e16e98a8e6 100644 --- a/src/parser/udf/udf_parser.cpp +++ b/src/parser/udf/plpgsql_parser.cpp @@ -2,7 +2,7 @@ #include "binder/bind_node_visitor.h" #include "execution/ast/udf/udf_ast_nodes.h" -#include "parser/udf/udf_parser.h" +#include "parser/udf/plpgsql_parser.h" #include "libpg_query/pg_query.h" #include "nlohmann/json.hpp" From 4bd7b3cb0dcaa6dc9d4ec68e1d481049029b3d3f Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Fri, 16 Jul 2021 22:09:45 -0400 Subject: [PATCH 070/139] some updates to GenerateTrace.java to handle integration tests for user defined functions --- script/testing/junit/README.md | 2 +- script/testing/junit/sql/udf.sql | 23 ++ script/testing/junit/src/GenerateTrace.java | 293 +++++++++++++++++--- script/testing/junit/traces/udf.test | 64 +++++ src/include/parser/udf/plpgsql_parser.h | 4 + 5 files changed, 347 insertions(+), 39 deletions(-) create mode 100644 script/testing/junit/sql/udf.sql create mode 100644 script/testing/junit/traces/udf.test diff --git a/script/testing/junit/README.md b/script/testing/junit/README.md index 09951522fc..49ecf077e0 100644 --- a/script/testing/junit/README.md +++ b/script/testing/junit/README.md @@ -96,7 +96,7 @@ The procedure for running the `GenerateTrace.java` program is as follows: 1. Establish a local Postgres database and start the database server. The procedure to accomplish this depends on the particulars of your development environment. If you are using a CMU DB development machine, see the _PostgreSQL on CMU DB Development Machines_ section below. 2. Write your own SQL input file. The format of this file consists of SQL statements, one per line. Comments (denoted by `#`)) are permitted. 3. Compile the test infrastructure: `ant compile` -4. Run the filter trace program: `ant filter-trace`. The program expects 6 arguments: +4. Run the filter trace program: `ant generate-trace`. The program expects 6 arguments: - `path`: The path to the input file - `db-url`: The JDBC URL for the DBMS server - `db-user`: The database username diff --git a/script/testing/junit/sql/udf.sql b/script/testing/junit/sql/udf.sql new file mode 100644 index 0000000000..f70f447740 --- /dev/null +++ b/script/testing/junit/sql/udf.sql @@ -0,0 +1,23 @@ +-- udf.sql +-- Integration tests for user-defined functions. +-- +-- Currently, these tests rely on the fact that we +-- utilize Postgres as a reference implementation +-- because all user-defined functions are implemented +-- in the Postgres PL/SQL dialect, PL/pgSQL. + +-- Create the test table +CREATE TABLE test(id INT PRIMARY KEY, x INT); + +-- Insert some data +INSERT INTO test (id, x) VALUES (0, 1), (1, 2), (2, 3); + +-- Create functions +CREATE FUNCTION return_constant() RETURNS INT AS $$ \ +BEGIN \ + RETURN 1; \ +END \ +$$ LANGUAGE PLPGSQL; + +-- Invoke +SELECT x, return_constant() FROM test; diff --git a/script/testing/junit/src/GenerateTrace.java b/script/testing/junit/src/GenerateTrace.java index 1d849059a4..ad208b85cb 100644 --- a/script/testing/junit/src/GenerateTrace.java +++ b/script/testing/junit/src/GenerateTrace.java @@ -1,47 +1,140 @@ -import java.io.*; -import java.sql.*; -import java.util.ArrayList; -import java.util.Arrays; +/** + * GenerateTrace.java + */ + +import java.io.File; +import java.io.FileReader; +import java.io.FileWriter; +import java.io.IOException; +import java.io.BufferedReader; + +import java.sql.ResultSet; +import java.sql.Statement; +import java.sql.Connection; +import java.sql.SQLException; +import java.sql.ResultSetMetaData; + import java.util.List; -import moglib.*; +import moglib.MogDb; +import moglib.MogSqlite; +import moglib.Constants; + +/** + * A generic logger interface. + * (Apparently `Logger` is already taken) + */ +interface ILogger { + public void info(final String message); + public void error(final String message); +} + +/** + * A dummy logger class that just writes to standard output. + * + * We might want to replace this eventually with an actual + * logger implementation, and this dummy class might(?) make + * that transition slightly less painful. For now, it also + * provides the slight benefit of making logging less verbose. + */ +class StandardLogger implements ILogger { + /** + * Construct a logger instance. + */ + StandardLogger() {} + + /** + * Log an informational message. + * @param message + */ + public void info(final String message) { + System.out.println(message); + } + + /** + * Log an error message. + * @param message + */ + public void error(final String message) { + System.err.println(message); + } +} /** - * class that convert sql statements to trace format - * first, establish a local postgresql database - * second, start the database server with "pg_ctl -D /usr/local/var/postgres start" - * third, modify the url, user and password string to match the database you set up - * finally, provide path to a file, run generateTrace with the file path as argument - * input file format: sql statements, one per line - * output file: to be tested by TracefileTest + * The GenerateTrace class converts SQL statements to the tracefile + * format used for integration testing. For instructions on how to + * use this program to generate a tracefile, see junit/README. */ public class GenerateTrace { + /** + * Error code for process exit on program success. + */ + private static final int EXIT_SUCCESS = 0; + + /** + * Error code for process exit on program failure. + */ + private static final int EXIT_ERROR = 1; + + /** + * The expected number of commandline arguments. + */ + private static final int EXPECTED_ARGUMENT_COUNT = 5; + + /** + * The character used to delimit multiline statements (e.g. UDF definition). + */ + private static final String MULTILINE_DELIMITER = "\\"; + + /** + * The logger instance. + */ + private static final ILogger LOGGER = new StandardLogger(); + + /** + * Program entry point. + * @param args Commandline arguments + * @throws Throwable + */ public static void main(String[] args) throws Throwable { - System.out.println("Working Directory = " + System.getProperty("user.dir")); - String path = args[0]; + if (args.length < EXPECTED_ARGUMENT_COUNT) { + LOGGER.error("Error: invalid arguments"); + LOGGER.error("Usage: see junit/README.md"); + System.exit(EXIT_ERROR); + } + + LOGGER.info("Working Directory = " + System.getProperty("user.dir")); + + final String path = args[0]; File file = new File(path); - System.out.println("File path: " + path); + MogSqlite mog = new MogSqlite(file); - // open connection to postgresql database with jdbc + + // Open connection to Postgre database over JDBC MogDb db = new MogDb(args[1], args[2], args[3]); - Connection conn = db.getDbTest().newConn(); - // remove existing table name - List tab = getAllExistingTableName(mog,conn); - removeExistingTable(tab,conn); + Connection connection = db.getDbTest().newConn(); + + // Initialize the database + removeAllTables(mog, connection); + removeAllFunctions(mog, connection); String line; String label; Statement statement = null; BufferedReader br = new BufferedReader(new FileReader(file)); - // create output file + + // Create output file FileWriter writer = new FileWriter(new File(Constants.DEST_DIR, args[4])); + int expected_result_num = -1; boolean include_result = false; - while (null != (line = br.readLine())) { + while (null != (line = readLine(br, MULTILINE_DELIMITER))) { line = line.trim(); - // execute sql statement + LOGGER.info(line); + + // Execute SQL statement try{ - statement = conn.createStatement(); + statement = connection.createStatement(); statement.execute(line); label = Constants.STATEMENT_OK; } catch (SQLException e) { @@ -163,26 +256,150 @@ public static void main(String[] args) throws Throwable { } writer.close(); br.close(); + + System.exit(EXIT_SUCCESS); } - public static void writeToFile(FileWriter writer, String str) throws IOException { - writer.write(str); + /** + * Read a line from the specified `BufferedReader` instance. + * @param reader The instance from which lines are read + * @param delimiter The character used to delimit multiline statements + * @return The input line, or `null` on end of input + */ + private static String readLine(BufferedReader reader, final String delimiter) throws IOException { + StringBuilder builder = new StringBuilder(); + for (;;) { + final String input = reader.readLine(); + if (input == null) { + return null; + } + + if (input.endsWith(delimiter)) { + builder.append( + input.substring(0, input.length() - delimiter.length() - 1) + .trim() + " "); + } else { + builder.append(input); + break; + } + } + return builder.toString(); + } + + /** + * Write the specified line to a file using the provided `FileWriter`. + * @param writer The `FileWriter` instance + * @param line The line to be written + * @throws IOException On IO error + */ + public static void writeToFile(FileWriter writer, final String line) throws IOException { + writer.write(line); writer.write('\n'); } - public static void removeExistingTable(List tab, Connection connection) throws SQLException { - for(String i:tab){ - Statement st = connection.createStatement(); - String sql = "DROP TABLE IF EXISTS " + i + " CASCADE"; - st.execute(sql); + /* ------------------------------------------------------------------------ + Table Management + ------------------------------------------------------------------------ */ + + /** + * Remove all existing tables from the database + * @param mog The `MogSqlite` instance + * @param connection The database connection + * @throws SQLException On SQL error + */ + private static void removeAllTables(MogSqlite mog, Connection connection) throws SQLException { + final List tableNames = getExistingTableNames(mog, connection); + removeTables(tableNames, connection); + } + + /** + * Get the names of all existing tables in the database. + * @param mog The `MogSqlite` instance + * @param connection The database connection + * @return A list of all table names + * @throws SQLException On SQL exception + */ + public static List getExistingTableNames(MogSqlite mog, Connection connection) throws SQLException { + final String query = "SELECT TABLENAME FROM pg_tables WHERE schemaname = 'public';"; + Statement statement = connection.createStatement(); + statement.execute(query); + return mog.processResults(statement.getResultSet()); + } + + /** + * Remove all specified tables from the database. + * @param tableNames The collection of table names to remove + * @param connection The database connection + * @throws SQLException On SQL error + */ + private static void removeTables(final List tableNames, Connection connection) throws SQLException { + for (final String tableName : tableNames){ + removeTable(tableName, connection); } } - public static List getAllExistingTableName(MogSqlite mog,Connection connection) throws SQLException { - Statement st = connection.createStatement(); - String getTableName = "SELECT tablename FROM pg_tables WHERE schemaname = 'public';"; - st.execute(getTableName); - ResultSet rs = st.getResultSet(); - List res = mog.processResults(rs); - return res; + + /** + * Remove the specified table from the database. + * @param tableName The name of the table to remove + * @param connection The database connection + * @throws SQLException On SQL error + */ + private static void removeTable(final String tableName, Connection connection) throws SQLException { + final String query = "DROP TABLE IF EXISTS " + tableName + " CASCADE"; + Statement statement = connection.createStatement(); + statement.execute(query); + } + + /* ------------------------------------------------------------------------ + Function Management + ------------------------------------------------------------------------ */ + + /** + * Remove all existing functions from the database. + * @param mog The `MogSqlite` instance. + * @param connection The database connection. + * @throws SQLException On SQL error + */ + private static void removeAllFunctions(MogSqlite mog, Connection connection) throws SQLException { + final List functionNames = getExistingFunctions(mog, connection); + removeFunctions(functionNames, connection); + } + + /** + * Get the names of all existing functions in the database. + * @param mog The MogSqlite instance + * @param connection The databse connection + * @return A collection of the function names + * @throws SQLException On SQL error + */ + private static List getExistingFunctions(MogSqlite mog, Connection connection) throws SQLException { + final String query = "SELECT proname FROM pg_proc WHERE pronamespace = 'public'::regnamespace;"; + Statement statement = connection.createStatement(); + statement.execute(query); + return mog.processResults(statement.getResultSet()); + } + + /** + * Remove all of the functions in `functionNames` from the database. + * @param functionNames The names of the functions to remove + * @param connection The database connection + * @throws SQLException On SQL error + */ + private static void removeFunctions(final List functionNames, Connection connection) throws SQLException { + for (final String functionName : functionNames) { + removeFunction(functionName, connection); + } + } + + /** + * Remove the function identified by `functionName` from the database. + * @param functionName The name of the function to remove + * @param connection The database connection + * @throws SQLException On SQL error + */ + private static void removeFunction(final String functionName, Connection connection) throws SQLException { + final String query = "DROP FUNCTION IF EXISTS " + functionName + " CASCADE;"; + Statement statement = connection.createStatement(); + statement.execute(query); } } diff --git a/script/testing/junit/traces/udf.test b/script/testing/junit/traces/udf.test new file mode 100644 index 0000000000..b190716b86 --- /dev/null +++ b/script/testing/junit/traces/udf.test @@ -0,0 +1,64 @@ +statement ok +-- udf.sql + +statement ok +-- Integration tests for user-defined functions. + +statement ok +-- + +statement ok +-- Currently, these tests rely on the fact that we + +statement ok +-- utilize Postgres as a reference implementation + +statement ok +-- because all user-defined functions are implemented + +statement ok +-- in the Postgres PL/SQL dialect, PL/pgSQL. + +statement ok + + +statement ok +-- Create the test table + +statement ok +CREATE TABLE test(id INT PRIMARY KEY, x INT); + +statement ok + + +statement ok +-- Insert some data + +statement ok +INSERT INTO test (id, x) VALUES (0, 1), (1, 2), (2, 3); + +statement ok + + +statement ok +-- Create functions + +statement ok +CREATE FUNCTION return_constant() RETURNS INT AS $$ BEGIN RETURN 1; END $$ LANGUAGE PLPGSQL; + +statement ok + + +statement ok +-- Invoke + +query II rowsort +SELECT x, return_constant() FROM test; +---- +1 +1 +2 +1 +3 +1 + diff --git a/src/include/parser/udf/plpgsql_parser.h b/src/include/parser/udf/plpgsql_parser.h index e82dc28013..ae194044dd 100644 --- a/src/include/parser/udf/plpgsql_parser.h +++ b/src/include/parser/udf/plpgsql_parser.h @@ -20,6 +20,10 @@ namespace noisepage::parser::udf { /** * The PLpgSQLParser class parses source PL/pgSQL to an abstract syntax tree. + * + * Internally, PLpgSQLParser utilizes libpg_query to perform the actual parsing + * of the input PL/pgSQL source, and then maps the representation from libpg_query + * to our our internal representation that then proceeds through code generation. */ class PLpgSQLParser { public: From 9343add13700734549af51580edec70f87654cd0 Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Sat, 17 Jul 2021 10:58:27 -0400 Subject: [PATCH 071/139] a little more refactoring for GenerateTrace.java, happy with it now --- script/testing/junit/src/GenerateTrace.java | 169 +++++++++++--------- script/testing/junit/traces/udf.test | 1 + 2 files changed, 95 insertions(+), 75 deletions(-) diff --git a/script/testing/junit/src/GenerateTrace.java b/script/testing/junit/src/GenerateTrace.java index ad208b85cb..893a7bba9b 100644 --- a/script/testing/junit/src/GenerateTrace.java +++ b/script/testing/junit/src/GenerateTrace.java @@ -86,6 +86,11 @@ public class GenerateTrace { */ private static final String MULTILINE_DELIMITER = "\\"; + /** + * The current working directory. + */ + private static final String WORKING_DRIECTORY = System.getProperty("user.dir"); + /** * The logger instance. */ @@ -103,58 +108,74 @@ public static void main(String[] args) throws Throwable { System.exit(EXIT_ERROR); } - LOGGER.info("Working Directory = " + System.getProperty("user.dir")); - - final String path = args[0]; - File file = new File(path); + LOGGER.info("Working Directory = " + WORKING_DRIECTORY); + + // Parse commandline arguments + final String inputPath = args[0]; + final String jdbcUrl = args[1]; + final String dbUsername = args[2]; + final String dbPassword = args[3]; + final String outputPath = args[4]; - MogSqlite mog = new MogSqlite(file); + MogSqlite mog = new MogSqlite(new File(inputPath)); // Open connection to Postgre database over JDBC - MogDb db = new MogDb(args[1], args[2], args[3]); + MogDb db = new MogDb(jdbcUrl, dbUsername, dbPassword); Connection connection = db.getDbTest().newConn(); // Initialize the database removeAllTables(mog, connection); removeAllFunctions(mog, connection); + BufferedReader reader = new BufferedReader(new FileReader(new File(inputPath))); + FileWriter writer = new FileWriter(new File(Constants.DEST_DIR, outputPath)); + + System.exit(run(db, mog, connection, reader, writer)); + } + + /** + * Run trace generation. + * @param db The `MogDb` instance + * @param mog The `MogSqlite` instance + * @param connection The database connection + * @param reader The buffered reader for the input file + * @param writer The file writer for the output file + * @return The status code + */ + private static int run(MogDb db, MogSqlite mog, Connection connection, + BufferedReader reader, FileWriter writer) throws SQLException, IOException { String line; String label; Statement statement = null; - BufferedReader br = new BufferedReader(new FileReader(file)); - - // Create output file - FileWriter writer = new FileWriter(new File(Constants.DEST_DIR, args[4])); - + int expected_result_num = -1; boolean include_result = false; - while (null != (line = readLine(br, MULTILINE_DELIMITER))) { + while (null != (line = readLine(reader, MULTILINE_DELIMITER))) { line = line.trim(); - LOGGER.info(line); // Execute SQL statement - try{ + try { statement = connection.createStatement(); statement.execute(line); label = Constants.STATEMENT_OK; } catch (SQLException e) { - System.err.println("Error executing SQL Statement: '" + line + "'; " + e.getMessage()); + LOGGER.error("Error executing SQL Statement: '" + line + "'; " + e.getMessage()); label = Constants.STATEMENT_ERROR; } catch (Throwable e) { label = Constants.STATEMENT_ERROR; } - if(line.startsWith("SELECT") || line.toLowerCase().startsWith("with")) { + if (line.startsWith("SELECT") || line.startsWith("WITH")) { ResultSet rs = statement.getResultSet(); - if (line.toLowerCase().startsWith("with") && null == rs) { + if (line.startsWith("WITH") && null == rs) { // We might have a query that begins with `WITH` that has a null result set int updateCount = statement.getUpdateCount(); // check if expected number is equal to update count - if(expected_result_num>=0 && expected_result_num!=updateCount){ + if (expected_result_num >= 0 && expected_result_num != updateCount) { label = Constants.STATEMENT_ERROR; } - writeToFile(writer, label); - writeToFile(writer, line); + writeLine(writer, label); + writeLine(writer, line); writer.write('\n'); expected_result_num = -1; continue; @@ -165,13 +186,13 @@ public static void main(String[] args) throws Throwable { for (int i = 1; i <= rsmd.getColumnCount(); ++i) { String colTypeName = rsmd.getColumnTypeName(i); MogDb.DbColumnType colType = db.getDbTest().getDbColumnType(colTypeName); - if(colType==MogDb.DbColumnType.FLOAT){ + if (colType == MogDb.DbColumnType.FLOAT) { typeString += "R"; - }else if(colType==MogDb.DbColumnType.INTEGER){ + } else if (colType == MogDb.DbColumnType.INTEGER) { typeString += "I"; - }else if(colType==MogDb.DbColumnType.TEXT){ + } else if(colType == MogDb.DbColumnType.TEXT) { typeString += "T"; - }else{ + } else { System.out.println(colTypeName + " column invalid"); } } @@ -186,78 +207,76 @@ public static void main(String[] args) throws Throwable { sortOption = "rowsort"; mog.sortMode = "rowsort"; } - String query_sort = Constants.QUERY + " " + typeString + " " + sortOption; - writeToFile(writer, query_sort); - writeToFile(writer, line); - writeToFile(writer, Constants.SEPARATION); - List res = mog.processResults(rs); - // compute the hash - String hash = TestUtility.getHashFromDb(res); - String queryResult = ""; - // when include_result is true, set queryResult to be exact result instead of hash - if(include_result){ - for(String i:res){ - queryResult += i; - queryResult += "\n"; + final String query_sort = Constants.QUERY + " " + typeString + " " + sortOption; + writeLine(writer, query_sort); + writeLine(writer, line); + writeLine(writer, Constants.SEPARATION); + + final List results = mog.processResults(rs); + final String hash = TestUtility.getHashFromDb(results); + + StringBuilder resultBuilder = new StringBuilder(); + if (include_result) { + for (final String result : results) { + resultBuilder.append(result); + resultBuilder.append('\n'); } - queryResult = queryResult.trim(); - }else{ - // if expected number of results is specified - if(expected_result_num>=0){ - queryResult = "Expected " + expected_result_num + " values hashing to " + hash; - }else{ - if(res.size()>0){ - // set queryResult to format x values hashing to xxx - queryResult = res.size() + " values hashing to " + hash; + } else { + // Expected number of results is specified + if (expected_result_num >= 0) { + resultBuilder.append("Expected " + expected_result_num + " values hashing to " + hash); + } else { + if (results.size() > 0) { + resultBuilder.append(results.size() + " values hashing to " + hash); } - // set queryResult to be exact result instead of hash when - // result size is smaller than Constants.DISPLAY_RESULT_SIZE - if(res.size() < Constants.DISPLAY_RESULT_SIZE){ - queryResult = ""; - for(String i:res){ - queryResult += i; - queryResult += "\n"; + if (results.size() < Constants.DISPLAY_RESULT_SIZE) { + resultBuilder.setLength(0); + for (final String result : results) { + resultBuilder.append(result); + resultBuilder.append('\n'); } - queryResult = queryResult.trim(); } } } - writeToFile(writer, queryResult); - if(res.size()>0){ + + writeLine(writer, resultBuilder.toString()); + if (results.size() > 0) { writer.write('\n'); } + include_result = false; expected_result_num = -1; - } else if(line.startsWith(Constants.HASHTAG)){ - writeToFile(writer, line); - if(line.contains(Constants.NUM_OUTPUT_FLAG)){ - // case for specifying the expected number of outputs - String[] arr = line.split(" "); - expected_result_num = Integer.parseInt(arr[arr.length-1]); - }else if(line.contains(Constants.FAIL_FLAG)){ - // case for expecting the query to fail + } else if (line.startsWith(Constants.HASHTAG)) { + writeLine(writer, line); + if (line.contains(Constants.NUM_OUTPUT_FLAG)) { + // Case for specifying the expected number of outputs + final String[] arr = line.split(" "); + expected_result_num = Integer.parseInt(arr[arr.length - 1]); + } else if (line.contains(Constants.FAIL_FLAG)) { + // Case for expecting the query to fail label = Constants.STATEMENT_ERROR; - } else if(line.contains(Constants.EXPECTED_OUTPUT_FLAG)){ - // case for including exact result in mog.queryResult + } else if (line.contains(Constants.EXPECTED_OUTPUT_FLAG)) { + // Case for including exact result in mog.queryResult include_result = true; } - } else{ - // other sql statements - int rs = statement.getUpdateCount(); + } else { + // Other sql statements + final int updateCount = statement.getUpdateCount(); // check if expected number is equal to update count - if(expected_result_num>=0 && expected_result_num!=rs){ + if (expected_result_num >= 0 && expected_result_num != updateCount){ label = Constants.STATEMENT_ERROR; } - writeToFile(writer, label); - writeToFile(writer, line); + writeLine(writer, label); + writeLine(writer, line); writer.write('\n'); expected_result_num = -1; } } + writer.close(); - br.close(); + reader.close(); - System.exit(EXIT_SUCCESS); + return EXIT_SUCCESS; } /** @@ -292,7 +311,7 @@ private static String readLine(BufferedReader reader, final String delimiter) th * @param line The line to be written * @throws IOException On IO error */ - public static void writeToFile(FileWriter writer, final String line) throws IOException { + public static void writeLine(FileWriter writer, final String line) throws IOException { writer.write(line); writer.write('\n'); } diff --git a/script/testing/junit/traces/udf.test b/script/testing/junit/traces/udf.test index b190716b86..48700f576e 100644 --- a/script/testing/junit/traces/udf.test +++ b/script/testing/junit/traces/udf.test @@ -62,3 +62,4 @@ SELECT x, return_constant() FROM test; 3 1 + From 635a6de7c828191819db4d530cfcbb0cf0fa4a03 Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Sat, 17 Jul 2021 11:06:47 -0400 Subject: [PATCH 072/139] passing first integration test for UDFs --- script/testing/junit/sql/udf.sql | 5 +++-- script/testing/junit/src/GenerateTrace.java | 9 +++++++++ script/testing/junit/traces/udf.test | 9 ++++++--- src/network/noisepage_server.cpp | 2 +- 4 files changed, 19 insertions(+), 6 deletions(-) diff --git a/script/testing/junit/sql/udf.sql b/script/testing/junit/sql/udf.sql index f70f447740..260c1547d6 100644 --- a/script/testing/junit/sql/udf.sql +++ b/script/testing/junit/sql/udf.sql @@ -12,12 +12,13 @@ CREATE TABLE test(id INT PRIMARY KEY, x INT); -- Insert some data INSERT INTO test (id, x) VALUES (0, 1), (1, 2), (2, 3); --- Create functions +-- ---------------------------------------------------------------------------- +-- return_constant() + CREATE FUNCTION return_constant() RETURNS INT AS $$ \ BEGIN \ RETURN 1; \ END \ $$ LANGUAGE PLPGSQL; --- Invoke SELECT x, return_constant() FROM test; diff --git a/script/testing/junit/src/GenerateTrace.java b/script/testing/junit/src/GenerateTrace.java index 893a7bba9b..8f64e0a8bb 100644 --- a/script/testing/junit/src/GenerateTrace.java +++ b/script/testing/junit/src/GenerateTrace.java @@ -25,7 +25,16 @@ * (Apparently `Logger` is already taken) */ interface ILogger { + /** + * Log an informational message. + * @param message The message + */ public void info(final String message); + + /** + * Log an error message. + * @param message The message + */ public void error(final String message); } diff --git a/script/testing/junit/traces/udf.test b/script/testing/junit/traces/udf.test index 48700f576e..55f820ee4c 100644 --- a/script/testing/junit/traces/udf.test +++ b/script/testing/junit/traces/udf.test @@ -41,16 +41,19 @@ statement ok statement ok --- Create functions +-- ---------------------------------------------------------------------------- statement ok -CREATE FUNCTION return_constant() RETURNS INT AS $$ BEGIN RETURN 1; END $$ LANGUAGE PLPGSQL; +-- return_constant() statement ok statement ok --- Invoke +CREATE FUNCTION return_constant() RETURNS INT AS $$ BEGIN RETURN 1; END $$ LANGUAGE PLPGSQL; + +statement ok + query II rowsort SELECT x, return_constant() FROM test; diff --git a/src/network/noisepage_server.cpp b/src/network/noisepage_server.cpp index 14b9254b1d..de83c1dbd1 100644 --- a/src/network/noisepage_server.cpp +++ b/src/network/noisepage_server.cpp @@ -140,7 +140,7 @@ void TerrierServer::RunServer() { RegisterSocket(); // Register the Unix domain socket. - // RegisterSocket(); + RegisterSocket(); // Register the ConnectionDispatcherTask. This handles connections to the sockets created above. dispatcher_task_ = thread_registry_->RegisterDedicatedThread( From da68ea7c52803732a283d48a52fc1b93d3e40e19 Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Sat, 17 Jul 2021 15:56:44 -0400 Subject: [PATCH 073/139] fixed bug in CreateFunction DDL executor, another integration test passing --- script/testing/junit/sql/udf.sql | 11 +++++++ script/testing/junit/traces/udf.test | 29 +++++++++++++++++++ src/catalog/postgres/pg_proc_impl.cpp | 4 --- src/execution/sql/ddl_executors.cpp | 6 ++-- .../execution/ast/udf/udf_ast_context.h | 6 ++-- src/include/execution/ast/udf/udf_ast_nodes.h | 6 ++-- src/include/parser/udf/plpgsql_parser.h | 6 ++-- .../plannodes/create_function_plan_node.h | 4 +-- src/parser/udf/plpgsql_parser.cpp | 12 ++++---- 9 files changed, 63 insertions(+), 21 deletions(-) diff --git a/script/testing/junit/sql/udf.sql b/script/testing/junit/sql/udf.sql index 260c1547d6..2a08396d52 100644 --- a/script/testing/junit/sql/udf.sql +++ b/script/testing/junit/sql/udf.sql @@ -22,3 +22,14 @@ END \ $$ LANGUAGE PLPGSQL; SELECT x, return_constant() FROM test; + +-- ---------------------------------------------------------------------------- +-- return_input() + +CREATE FUNCTION return_input(x INT) RETURNS INT AS $$ \ +BEGIN \ + RETURN x; \ +END \ +$$ LANGUAGE PLPGSQL; + +SELECT x, return_input(x) FROM test; diff --git a/script/testing/junit/traces/udf.test b/script/testing/junit/traces/udf.test index 55f820ee4c..15f8818825 100644 --- a/script/testing/junit/traces/udf.test +++ b/script/testing/junit/traces/udf.test @@ -66,3 +66,32 @@ SELECT x, return_constant() FROM test; 1 +statement ok + + +statement ok +-- ---------------------------------------------------------------------------- + +statement ok +-- return_input() + +statement ok + + +statement ok +CREATE FUNCTION return_input(x INT) RETURNS INT AS $$ BEGIN RETURN x; END $$ LANGUAGE PLPGSQL; + +statement ok + + +query II rowsort +SELECT x, return_input(x) FROM test; +---- +1 +1 +2 +2 +3 +3 + + diff --git a/src/catalog/postgres/pg_proc_impl.cpp b/src/catalog/postgres/pg_proc_impl.cpp index ea1a457e4a..d6093b0813 100644 --- a/src/catalog/postgres/pg_proc_impl.cpp +++ b/src/catalog/postgres/pg_proc_impl.cpp @@ -92,10 +92,6 @@ bool PgProcImpl::CreateProcedure(const common::ManagedPointer arg_name_vec{}; arg_name_vec.reserve(args.size() * sizeof(storage::VarlenEntry)); std::copy(args.cbegin(), args.cend(), std::back_inserter(arg_name_vec)); - // arg_name_vec.reserve(args.size() * ); - // for (auto &arg : args) { - // arg_name_vec.push_back(arg); - // } const auto arg_names_varlen = storage::StorageUtil::CreateVarlen(args); const auto arg_types_varlen = storage::StorageUtil::CreateVarlen(arg_types); diff --git a/src/execution/sql/ddl_executors.cpp b/src/execution/sql/ddl_executors.cpp index d2d77513da..bdfca15b56 100644 --- a/src/execution/sql/ddl_executors.cpp +++ b/src/execution/sql/ddl_executors.cpp @@ -76,7 +76,7 @@ bool DDLExecutors::CreateFunctionExecutor(const common::ManagedPointerGetDatabaseOid()}; std::unique_ptr ast{}; try { - ast = udf_parser.Parse(node->GetFunctionParameterNames(), std::move(param_type_ids), body, + ast = udf_parser.Parse(node->GetFunctionParameterNames(), param_type_ids, body, (common::ManagedPointer(&udf_ast_context))); } catch (Exception &e) { return false; @@ -95,8 +95,8 @@ bool DDLExecutors::CreateFunctionExecutor(const common::ManagedPointerGetFunctionParameterNames().size(); i++) { - const auto &name = node->GetFunctionParameterNames()[i]; - const auto &type = parser::ReturnType::DataTypeToTypeId(node->GetFunctionParameterTypes()[i]); + const auto name = node->GetFunctionParameterNames()[i]; + const auto type = parser::ReturnType::DataTypeToTypeId(node->GetFunctionParameterTypes()[i]); fn_params.emplace_back( codegen.MakeField(ast_context->GetIdentifier(name), codegen.TplType(execution::sql::GetTypeId(type)))); } diff --git a/src/include/execution/ast/udf/udf_ast_context.h b/src/include/execution/ast/udf/udf_ast_context.h index 64f7cde6f7..49a3be5af4 100644 --- a/src/include/execution/ast/udf/udf_ast_context.h +++ b/src/include/execution/ast/udf/udf_ast_context.h @@ -57,8 +57,10 @@ class UdfAstContext { * @return The name of the variable at the specified index */ const std::string &GetLocalVariableAtIndex(const std::size_t index) { - NOISEPAGE_ASSERT(local_variables_.size() >= index, "Bad variable"); - // TODO(Kyle): Why did this originally have index - 1? + NOISEPAGE_ASSERT(local_variables_.size() >= index, "Index out of range"); + // TODO(Kyle): I moved the subtraction to the call site because + // it seems misleading to have a getter for an index but deliver + // a local that does not actually appear at that index... return local_variables_.at(index); } diff --git a/src/include/execution/ast/udf/udf_ast_nodes.h b/src/include/execution/ast/udf/udf_ast_nodes.h index 78f71e2755..376c9d242a 100644 --- a/src/include/execution/ast/udf/udf_ast_nodes.h +++ b/src/include/execution/ast/udf/udf_ast_nodes.h @@ -639,12 +639,14 @@ class FunctionAST : public AbstractAST { * @param parameter_names The names of the parameters to the function * @param parameter_types The types of the parameters to the function */ - FunctionAST(std::unique_ptr &&body, std::vector &¶meter_names, - std::vector &¶meter_types) + FunctionAST(std::unique_ptr &&body, std::vector parameter_names, + std::vector parameter_types) : body_{std::move(body)}, parameter_names_{std::move(parameter_names)}, parameter_types_{std::move(parameter_types)} { NOISEPAGE_ASSERT(parameter_names_.size() == parameter_types_.size(), "Parameter Name and Type Mismatch"); + // TODO(Kyle): The copies made in this constructor may not be necessary, + // I need to look more closely at the ownership for this data } /** diff --git a/src/include/parser/udf/plpgsql_parser.h b/src/include/parser/udf/plpgsql_parser.h index ae194044dd..8e77d64943 100644 --- a/src/include/parser/udf/plpgsql_parser.h +++ b/src/include/parser/udf/plpgsql_parser.h @@ -20,7 +20,7 @@ namespace noisepage::parser::udf { /** * The PLpgSQLParser class parses source PL/pgSQL to an abstract syntax tree. - * + * * Internally, PLpgSQLParser utilizes libpg_query to perform the actual parsing * of the input PL/pgSQL source, and then maps the representation from libpg_query * to our our internal representation that then proceeds through code generation. @@ -46,8 +46,8 @@ class PLpgSQLParser { * @return The abstract syntax tree for the source function */ std::unique_ptr Parse( - std::vector &¶m_names, std::vector &¶m_types, const std::string &func_body, - common::ManagedPointer ast_context); + const std::vector ¶m_names, const std::vector ¶m_types, + const std::string &func_body, common::ManagedPointer ast_context); private: /** diff --git a/src/include/planner/plannodes/create_function_plan_node.h b/src/include/planner/plannodes/create_function_plan_node.h index 2a3cac21d8..b4cf421c8a 100644 --- a/src/include/planner/plannodes/create_function_plan_node.h +++ b/src/include/planner/plannodes/create_function_plan_node.h @@ -243,12 +243,12 @@ class CreateFunctionPlanNode : public AbstractPlanNode { /** * @return parameter names of the user defined function */ - std::vector GetFunctionParameterNames() const { return function_param_names_; } + const std::vector &GetFunctionParameterNames() const { return function_param_names_; } /** * @return parameter types of the user defined function */ - std::vector GetFunctionParameterTypes() const { + const std::vector &GetFunctionParameterTypes() const { return function_param_types_; } diff --git a/src/parser/udf/plpgsql_parser.cpp b/src/parser/udf/plpgsql_parser.cpp index e16e98a8e6..b8da04963a 100644 --- a/src/parser/udf/plpgsql_parser.cpp +++ b/src/parser/udf/plpgsql_parser.cpp @@ -45,8 +45,8 @@ static constexpr const char K_PLPGSQL_ROW[] = "PLpgSQL_row"; static constexpr const char K_PLPGSQL_STMT_DYNEXECUTE[] = "PLpgSQL_stmt_dynexecute"; std::unique_ptr PLpgSQLParser::Parse( - std::vector &¶m_names, std::vector &¶m_types, const std::string &func_body, - common::ManagedPointer ast_context) { + const std::vector ¶m_names, const std::vector ¶m_types, + const std::string &func_body, common::ManagedPointer ast_context) { auto result = pg_query_parse_plpgsql(func_body.c_str()); if (result.error != nullptr) { pg_query_free_plpgsql_parse_result(result); @@ -72,8 +72,8 @@ std::unique_ptr PLpgSQLParser::Parse( udf_ast_context_->SetVariableType(udf_name, param_types[i++]); } const auto function = function_list[0][K_PLPGSQL_FUNCTION]; - auto function_ast = std::make_unique( - ParseFunction(function), std::move(param_names), std::move(param_types)); + auto function_ast = + std::make_unique(ParseFunction(function), param_names, param_types); return function_ast; } @@ -110,8 +110,10 @@ std::unique_ptr PLpgSQLParser::ParseBlock(const nl stmts.push_back(ParseIf(stmt[K_PLPGSQL_STMT_IF])); } else if (stmt_names.key() == K_PLPGSQL_STMT_ASSIGN) { // TODO(Kyle): Need to fix Assignment expression / statement + // NOTE(Kyle): We subtract 1 here because variable numbers from + // the Postres parser index from 1 rather than 0 (?) const auto &var_name = - udf_ast_context_->GetLocalVariableAtIndex(stmt[K_PLPGSQL_STMT_ASSIGN][K_VARNO].get()); + udf_ast_context_->GetLocalVariableAtIndex(stmt[K_PLPGSQL_STMT_ASSIGN][K_VARNO].get() - 1); auto lhs = std::make_unique(var_name); auto rhs = ParseExprSQL(stmt[K_PLPGSQL_STMT_ASSIGN][K_EXPR][K_PLPGSQL_EXPR][K_QUERY].get()); stmts.push_back(std::make_unique(std::move(lhs), std::move(rhs))); From 8a2caa1c849084d99cbf7b5771a1bef91bcc00ce Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Mon, 19 Jul 2021 17:50:34 -0400 Subject: [PATCH 074/139] update 20min timeout in SQL integration tests to 40 because I keep failing CI on timeout and just want to see if this will go all the way through --- Jenkinsfile-utils.groovy | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/Jenkinsfile-utils.groovy b/Jenkinsfile-utils.groovy index ff7e7df8b6..8603ce07ed 100644 --- a/Jenkinsfile-utils.groovy +++ b/Jenkinsfile-utils.groovy @@ -76,13 +76,13 @@ void stageTest(Boolean runPipelineMetrics, Map args = [:]) { buildType = (args.cmake.toUpperCase().contains("CMAKE_BUILD_TYPE=RELEASE")) ? "release" : "debug" - sh script: "cd build && PYTHONPATH=.. timeout 20m python3 -m script.testing.junit --build-type=$buildType --query-mode=simple", label: 'UnitTest (Simple)' + sh script: "cd build && PYTHONPATH=.. timeout 40m python3 -m script.testing.junit --build-type=$buildType --query-mode=simple", label: 'UnitTest (Simple)' sh script: "cd build && PYTHONPATH=.. timeout 60m python3 -m script.testing.junit --build-type=$buildType --query-mode=simple -a 'compiled_query_execution=True' -a 'bytecode_handlers_path=./bytecode_handlers_ir.bc'", label: 'UnitTest (Simple, Compiled Execution)' - sh script: "cd build && PYTHONPATH=.. timeout 20m python3 -m script.testing.junit --build-type=$buildType --query-mode=extended", label: 'UnitTest (Extended)' + sh script: "cd build && PYTHONPATH=.. timeout 40m python3 -m script.testing.junit --build-type=$buildType --query-mode=extended", label: 'UnitTest (Extended)' sh script: "cd build && PYTHONPATH=.. timeout 60m python3 -m script.testing.junit --build-type=$buildType --query-mode=extended -a 'compiled_query_execution=True' -a 'bytecode_handlers_path=./bytecode_handlers_ir.bc'", label: 'UnitTest (Extended, Compiled Execution)' if (runPipelineMetrics) { - sh script: "cd build && PYTHONPATH=.. timeout 20m python3 -m script.testing.junit --build-type=$buildType --query-mode=extended -a 'pipeline_metrics_enable=True' -a 'pipeline_metrics_sample_rate=100' -a 'counters_enable=True' -a 'query_trace_metrics_enable=True'", label: 'UnitTest (Extended with pipeline metrics, counters, and query trace metrics)' + sh script: "cd build && PYTHONPATH=.. timeout 40m python3 -m script.testing.junit --build-type=$buildType --query-mode=extended -a 'pipeline_metrics_enable=True' -a 'pipeline_metrics_sample_rate=100' -a 'counters_enable=True' -a 'query_trace_metrics_enable=True'", label: 'UnitTest (Extended with pipeline metrics, counters, and query trace metrics)' sh script: "cd build && PYTHONPATH=.. timeout 60m python3 -m script.testing.junit --build-type=$buildType --query-mode=extended -a 'pipeline_metrics_enable=True' -a 'pipeline_metrics_sample_rate=100' -a 'counters_enable=True' -a 'query_trace_metrics_enable=True' -a 'compiled_query_execution=True' -a 'bytecode_handlers_path=./bytecode_handlers_ir.bc'", label: 'UnitTest (Extended, Compiled Execution with pipeline metrics, counters, and query trace metrics)' } From b72aa9504fdf41231ed428450176f44a33217017 Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Mon, 19 Jul 2021 22:43:49 -0400 Subject: [PATCH 075/139] fix bug in plpgsql parser that caused it to fail on integer variable declarations --- script/testing/junit/sql/udf.sql | 45 +++++++++-- script/testing/junit/traces/udf.test | 106 ++++++++++++++++++++++++-- src/include/parser/udf/string_utils.h | 29 +++++++ src/parser/udf/plpgsql_parser.cpp | 46 ++++++++--- src/parser/udf/string_utils.cpp | 32 ++++++++ test/parser/plpgsql_parser_test.cpp | 24 ++++++ 6 files changed, 260 insertions(+), 22 deletions(-) create mode 100644 src/include/parser/udf/string_utils.h create mode 100644 src/parser/udf/string_utils.cpp create mode 100644 test/parser/plpgsql_parser_test.cpp diff --git a/script/testing/junit/sql/udf.sql b/script/testing/junit/sql/udf.sql index 2a08396d52..5a63c3ae70 100644 --- a/script/testing/junit/sql/udf.sql +++ b/script/testing/junit/sql/udf.sql @@ -6,11 +6,11 @@ -- because all user-defined functions are implemented -- in the Postgres PL/SQL dialect, PL/pgSQL. --- Create the test table -CREATE TABLE test(id INT PRIMARY KEY, x INT); +-- Create a test table +CREATE TABLE integers(x INT, y INT); -- Insert some data -INSERT INTO test (id, x) VALUES (0, 1), (1, 2), (2, 3); +INSERT INTO integers (x, y) VALUES (1, 1), (2, 2), (3, 3); -- ---------------------------------------------------------------------------- -- return_constant() @@ -21,7 +21,7 @@ BEGIN \ END \ $$ LANGUAGE PLPGSQL; -SELECT x, return_constant() FROM test; +SELECT x, return_constant() FROM integers; -- ---------------------------------------------------------------------------- -- return_input() @@ -32,4 +32,39 @@ BEGIN \ END \ $$ LANGUAGE PLPGSQL; -SELECT x, return_input(x) FROM test; +SELECT x, return_input(x) FROM integers; + +-- ---------------------------------------------------------------------------- +-- return_sum() + +CREATE FUNCTION return_sum(x INT, y INT) RETURNS INT AS $$ \ +BEGIN \ + RETURN x + y; \ +END \ +$$ LANGUAGE PLPGSQL; + +SELECT x, y, return_sum(x, y) FROM integers; + +-- ---------------------------------------------------------------------------- +-- return_prod() + +CREATE FUNCTION return_product(x INT, y INT) RETURNS INT AS $$ \ +BEGIN \ + RETURN x * y; \ +END \ +$$ LANGUAGE PLPGSQL; + +SELECT x, y, return_product(x, y) FROM integers; + +-- ---------------------------------------------------------------------------- +-- integer_decl() + +CREATE FUNCTION integer_decl() RETURNS INT AS $$ \ +DECLARE \ + x INT := 0; \ +BEGIN \ + RETURN x; \ +END \ +$$ LANGUAGE PLPGSQL; + +SELECT x, y, integer_decl() FROM integers; diff --git a/script/testing/junit/traces/udf.test b/script/testing/junit/traces/udf.test index 15f8818825..8c61e9d468 100644 --- a/script/testing/junit/traces/udf.test +++ b/script/testing/junit/traces/udf.test @@ -23,10 +23,10 @@ statement ok statement ok --- Create the test table +-- Create a test table statement ok -CREATE TABLE test(id INT PRIMARY KEY, x INT); +CREATE TABLE integers(x INT, y INT); statement ok @@ -35,7 +35,7 @@ statement ok -- Insert some data statement ok -INSERT INTO test (id, x) VALUES (0, 1), (1, 2), (2, 3); +INSERT INTO integers (x, y) VALUES (1, 1), (2, 2), (3, 3); statement ok @@ -56,7 +56,7 @@ statement ok query II rowsort -SELECT x, return_constant() FROM test; +SELECT x, return_constant() FROM integers; ---- 1 1 @@ -85,7 +85,7 @@ statement ok query II rowsort -SELECT x, return_input(x) FROM test; +SELECT x, return_input(x) FROM integers; ---- 1 1 @@ -95,3 +95,99 @@ SELECT x, return_input(x) FROM test; 3 +statement ok + + +statement ok +-- ---------------------------------------------------------------------------- + +statement ok +-- return_sum() + +statement ok + + +statement ok +CREATE FUNCTION return_sum(x INT, y INT) RETURNS INT AS $$ BEGIN RETURN x + y; END $$ LANGUAGE PLPGSQL; + +statement ok + + +query III rowsort +SELECT x, y, return_sum(x, y) FROM integers; +---- +1 +1 +2 +2 +2 +4 +3 +3 +6 + + +statement ok + + +statement ok +-- ---------------------------------------------------------------------------- + +statement ok +-- return_prod() + +statement ok + + +statement ok +CREATE FUNCTION return_product(x INT, y INT) RETURNS INT AS $$ BEGIN RETURN x * y; END $$ LANGUAGE PLPGSQL; + +statement ok + + +query III rowsort +SELECT x, y, return_product(x, y) FROM integers; +---- +1 +1 +1 +2 +2 +4 +3 +3 +9 + + +statement ok + + +statement ok +-- ---------------------------------------------------------------------------- + +statement ok +-- integer_decl() + +statement ok + + +statement ok +CREATE FUNCTION integer_decl() RETURNS INT AS $$ DECLARE x INT := 0; BEGIN RETURN x; END $$ LANGUAGE PLPGSQL; + +statement ok + + +query III rowsort +SELECT x, y, integer_decl() FROM integers; +---- +1 +1 +0 +2 +2 +0 +3 +3 +0 + + diff --git a/src/include/parser/udf/string_utils.h b/src/include/parser/udf/string_utils.h new file mode 100644 index 0000000000..94165545c9 --- /dev/null +++ b/src/include/parser/udf/string_utils.h @@ -0,0 +1,29 @@ +#pragma once + +#include + +namespace noisepage::parser::udf { + +/** + * StringUtils is a static class that implements some basic + * string-processing utilities. Eventually, we might want to + * move functionality like this to our own internal algo library. + */ +class StringUtils { + public: + /** + * Convert a non-owned string to lowercase. + * @param string The input string + * @return The lowercased string + */ + static std::string Lower(const std::string &string); + + /** + * Strip whitespace from the start and end of a non-owned string. + * @param string The input string + * @return The stripped string + */ + static std::string Strip(const std::string &string); +}; + +} // namespace noisepage::parser::udf diff --git a/src/parser/udf/plpgsql_parser.cpp b/src/parser/udf/plpgsql_parser.cpp index b8da04963a..3b4ce7f0e1 100644 --- a/src/parser/udf/plpgsql_parser.cpp +++ b/src/parser/udf/plpgsql_parser.cpp @@ -3,15 +3,14 @@ #include "binder/bind_node_visitor.h" #include "execution/ast/udf/udf_ast_nodes.h" #include "parser/udf/plpgsql_parser.h" +#include "parser/udf/string_utils.h" #include "libpg_query/pg_query.h" #include "nlohmann/json.hpp" namespace noisepage::parser::udf { -/** - * @brief The identifiers used as keys in the parse tree. - */ +/** The identifiers used as keys in the parse tree */ static constexpr const char K_FUNCTION_LIST[] = "FunctionList"; static constexpr const char K_DATUMS[] = "datums"; static constexpr const char K_PLPGSQL_VAR[] = "PLpgSQL_var"; @@ -44,6 +43,15 @@ static constexpr const char K_NAME[] = "name"; static constexpr const char K_PLPGSQL_ROW[] = "PLpgSQL_row"; static constexpr const char K_PLPGSQL_STMT_DYNEXECUTE[] = "PLpgSQL_stmt_dynexecute"; +/** Variable declaration type identifiers */ +static constexpr const char DECL_TYPE_ID_INT[] = "int"; +static constexpr const char DECL_TYPE_ID_INTEGER[] = "integer"; +static constexpr const char DECL_TYPE_ID_DOUBLE[] = "double"; +static constexpr const char DECL_TYPE_ID_NUMERIC[] = "numeric"; +static constexpr const char DECL_TYPE_ID_VARCHAR[] = "varchar"; +static constexpr const char DECL_TYPE_ID_DATE[] = "date"; +static constexpr const char DECL_TYPE_ID_RECORD[] = "record"; + std::unique_ptr PLpgSQLParser::Parse( const std::vector ¶m_names, const std::vector ¶m_types, const std::string &func_body, common::ManagedPointer ast_context) { @@ -67,6 +75,7 @@ std::unique_ptr PLpgSQLParser::Parse( throw PARSER_EXCEPTION("Function list has size other than 1"); } + // TODO(Kyle): This is a zip() std::size_t i{0}; for (const auto &udf_name : param_names) { udf_ast_context_->SetVariableType(udf_name, param_types[i++]); @@ -135,48 +144,61 @@ std::unique_ptr PLpgSQLParser::ParseBlock(const nl std::unique_ptr PLpgSQLParser::ParseDecl(const nlohmann::json &decl) { const auto &decl_names = decl.items().begin(); - if (decl_names.key() == K_PLPGSQL_VAR) { auto var_name = decl[K_PLPGSQL_VAR][K_REFNAME].get(); udf_ast_context_->AddVariable(var_name); - auto type = decl[K_PLPGSQL_VAR][K_DATATYPE][K_PLPGSQL_TYPE][K_TYPENAME].get(); - std::unique_ptr initial = nullptr; + + // Grab the type identifier from the PL/pgSQL parser + const std::string type = StringUtils::Strip( + StringUtils::Lower(decl[K_PLPGSQL_VAR][K_DATATYPE][K_PLPGSQL_TYPE][K_TYPENAME].get())); + + // Parse the initializer, if present + std::unique_ptr initial{nullptr}; if (decl[K_PLPGSQL_VAR].find(K_DEFAULT_VAL) != decl[K_PLPGSQL_VAR].end()) { initial = ParseExprSQL(decl[K_PLPGSQL_VAR][K_DEFAULT_VAL][K_PLPGSQL_EXPR][K_QUERY].get()); } + // Detemine if the variable has already been declared; + // if so, just re-use this type that has already been resolved type::TypeId temp_type{}; if (udf_ast_context_->GetVariableType(var_name, &temp_type)) { return std::make_unique(var_name, temp_type, std::move(initial)); } - if ((type.find("integer") != std::string::npos) || type.find("INTEGER") != std::string::npos) { + + // Otherwise, we perform a string comparison with the type identifier + // for the variable to determine the type for the declaration + + if ((type == DECL_TYPE_ID_INT) || (type == DECL_TYPE_ID_INTEGER)) { udf_ast_context_->SetVariableType(var_name, type::TypeId::INTEGER); return std::make_unique(var_name, type::TypeId::INTEGER, std::move(initial)); } - if (type == "double" || type.rfind("numeric") == 0) { + if ((type == DECL_TYPE_ID_DOUBLE) || (type == DECL_TYPE_ID_NUMERIC)) { + // TODO(Kyle): type.rfind("numeric") udf_ast_context_->SetVariableType(var_name, type::TypeId::DECIMAL); return std::make_unique(var_name, type::TypeId::DECIMAL, std::move(initial)); } - if (type == "varchar") { + if (type == DECL_TYPE_ID_VARCHAR) { udf_ast_context_->SetVariableType(var_name, type::TypeId::VARCHAR); return std::make_unique(var_name, type::TypeId::VARCHAR, std::move(initial)); } - if (type.find("date") != std::string::npos) { + if (type == DECL_TYPE_ID_DATE) { udf_ast_context_->SetVariableType(var_name, type::TypeId::DATE); return std::make_unique(var_name, type::TypeId::DATE, std::move(initial)); } - if (type == "record") { + if (type == DECL_TYPE_ID_RECORD) { udf_ast_context_->SetVariableType(var_name, type::TypeId::INVALID); return std::make_unique(var_name, type::TypeId::INVALID, std::move(initial)); } NOISEPAGE_ASSERT(false, "Unsupported Type"); } else if (decl_names.key() == K_PLPGSQL_ROW) { - auto var_name = decl[K_PLPGSQL_ROW][K_REFNAME].get(); + const auto var_name = decl[K_PLPGSQL_ROW][K_REFNAME].get(); NOISEPAGE_ASSERT(var_name == "*internal*", "Unexpected refname"); + // TODO(Kyle): Support row types later udf_ast_context_->SetVariableType(var_name, type::TypeId::INVALID); return std::make_unique(var_name, type::TypeId::INVALID, nullptr); } + // TODO(Kyle): Need to handle other types like row, table etc; throw PARSER_EXCEPTION("Declaration type not supported"); } diff --git a/src/parser/udf/string_utils.cpp b/src/parser/udf/string_utils.cpp new file mode 100644 index 0000000000..4cd0db11ab --- /dev/null +++ b/src/parser/udf/string_utils.cpp @@ -0,0 +1,32 @@ +#include "parser/udf/string_utils.h" + +#include + +namespace noisepage::parser::udf { + +std::string StringUtils::Lower(const std::string &string) { + std::string result{}; + std::transform(string.cbegin(), string.cend(), std::back_inserter(result), + [](unsigned char c) { return std::tolower(c); }); + return result; +} + +std::string StringUtils::Strip(const std::string &string) { + auto not_whitespace = [](unsigned char c) { return std::isspace(c) == 0; }; + + // Find the first non-whitespace character + auto begin = std::find_if(string.cbegin(), string.cend(), not_whitespace); + if (begin == string.cend()) { + return std::string{}; + } + + // Find the last non whitespace character + auto end = std::find_if(string.rbegin(), string.rend(), not_whitespace); + + // Construct the result + std::string result{}; + std::copy(begin, end.base(), std::back_inserter(result)); + return result; +} + +} // namespace noisepage::parser::udf diff --git a/test/parser/plpgsql_parser_test.cpp b/test/parser/plpgsql_parser_test.cpp new file mode 100644 index 0000000000..e1a19f2cd8 --- /dev/null +++ b/test/parser/plpgsql_parser_test.cpp @@ -0,0 +1,24 @@ +#include + +#include "parser/udf/string_utils.h" +#include "test_util/test_harness.h" + +namespace noisepage::parser { + +class PLpgSQLParserTest : public TerrierTest {}; + +TEST_F(PLpgSQLParserTest, LowerTest0) { + const std::string input{"HELLO WORLD"}; + const std::string expected{"hello world"}; + const auto result = udf::StringUtils::Lower(input); + EXPECT_EQ(expected, result); +} + +TEST_F(PLpgSQLParserTest, StripTest0) { + const std::string input{" hello "}; + const std::string expected{"hello"}; + const auto result = udf::StringUtils::Strip(input); + EXPECT_EQ(expected, result); +} + +} // namespace noisepage::parser From eb5ebc3b528a1f09d842cb9baab257e8438edf54 Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Wed, 21 Jul 2021 14:21:55 -0400 Subject: [PATCH 076/139] integration tests for conditionals and while loop iteration, found a bug in for-loop iteration that we need to address --- ...n_cte_implementation.md => design_ctes.md} | 10 +- docs/design_udfs.md | 33 +++++ script/testing/junit/sql/udf.sql | 71 ++++++++++ script/testing/junit/traces/udf.test | 121 ++++++++++++++++++ src/parser/udf/plpgsql_parser.cpp | 8 +- 5 files changed, 239 insertions(+), 4 deletions(-) rename docs/{discussion_cte_implementation.md => design_ctes.md} (98%) create mode 100644 docs/design_udfs.md diff --git a/docs/discussion_cte_implementation.md b/docs/design_ctes.md similarity index 98% rename from docs/discussion_cte_implementation.md rename to docs/design_ctes.md index a733bf4988..c1dc33f151 100644 --- a/docs/discussion_cte_implementation.md +++ b/docs/design_ctes.md @@ -1,8 +1,14 @@ -# Discussion Doc: Common Table Expression Implementation +# Design Doc: Common Table Expressions + +### Overview This document provides an overview of some of the important features of our implementation of common table expressions (CTEs). -## Known Limitations +### Design + +TODO(Kyle): Fill this in. + +### Limitations Our current implementation of CTEs suffers from some known limitations which limits the queries we are able to execute. This section provides a comprehensive overview of the queries on which the system currently fails, a best-estimate of the underlying reason for the failure, and what might be required to address it. diff --git a/docs/design_udfs.md b/docs/design_udfs.md new file mode 100644 index 0000000000..67d0d895e6 --- /dev/null +++ b/docs/design_udfs.md @@ -0,0 +1,33 @@ +# Design Doc: User-Defined Functions + +### Overview + +This document describes important aspects of the design and implementation of user-defined functions in NoisePage. + +### Limitations + +This section describes known limitations of our implementation of UDFs. + +**Missing `RETURN`** + +In Postgres, a PL/pgSQL function that declares a return type but is missing a `RETURN` statement in the body of the function parses successfully, but results in a runtime error when the function is executed. Currently, we fail to parse such functions (which may be directly related to the issue below). + +**Implicit `RETURN`s** + +Currently, the following control flow is not supported: + +```sql +CREATE FUNCTION fun(x INT) RETURNS INT AS $$ +BEGIN + IF x > 10 THEN + RETURN 0; + ELSE + RETURN 1; + END IF; +END +$$ LANGUAGE PLPGSQL; +``` + +This fails in the parser because the library we use to parse the raw UDF (libpg_query) inserts an implicit empty `RETURN` at the end of the body of the function. This implicit `RETURN` has no associated expression, and therefore it fails when we attempt to parse it. + +Obviously, we can see that this implicit `RETURN` is unreachable code, so we know this UDF body is valid. diff --git a/script/testing/junit/sql/udf.sql b/script/testing/junit/sql/udf.sql index 5a63c3ae70..cc0e7fb9f7 100644 --- a/script/testing/junit/sql/udf.sql +++ b/script/testing/junit/sql/udf.sql @@ -68,3 +68,74 @@ END \ $$ LANGUAGE PLPGSQL; SELECT x, y, integer_decl() FROM integers; + +-- ---------------------------------------------------------------------------- +-- conditional() +-- +-- TODO(Kyle): The final RETURN 0 is unreachable, but we +-- need this temporary hack to deal with missing logic in parser + +CREATE FUNCTION conditional(x INT) RETURNS INT AS $$ \ +BEGIN \ + IF x > 1 THEN \ + RETURN 1; \ + ELSE \ + RETURN 2; \ + END IF; \ + RETURN 0; \ +END \ +$$ LANGUAGE PLPGSQL; + +SELECT x, conditional(x) FROM integers; + +-- ---------------------------------------------------------------------------- +-- proc_while() + +CREATE FUNCTION proc_while() RETURNS INT AS $$ \ +DECLARE \ + x INT := 0; \ +BEGIN \ + WHILE x < 10 LOOP \ + x = x + 1; \ + END LOOP; \ + RETURN x; \ +END \ +$$ LANGUAGE PLPGSQL; + +SELECT x, proc_while() FROM integers; + +-- ---------------------------------------------------------------------------- +-- proc_fori() + +-- CREATE FUNCTION proc_fori() RETURNS INT AS $$ \ +-- DECLARE \ +-- x INT := 0; \ +-- BEGIN \ +-- FOR i IN 1..10 LOOP \ +-- x = x + 1; \ +-- END LOOP; \ +-- RETURN x; \ +-- END \ +-- $$ LANGUAGE PLPGSQL; + +-- SELECT x, proc_fori() FROM integers; + +-- ---------------------------------------------------------------------------- +-- proc_fors() + +-- CREATE TABLE temp(z INT); +-- INSERT INTO temp(z) VALUES (0), (1); + +-- CREATE FUNCTION proc_fors() RETURNS INT AS $$ \ +-- DECLARE \ +-- x INT := 0; \ +-- BEGIN \ +-- FOR v IN SELECT z FROM temp \ +-- LOOP \ +-- x = x + 1; \ +-- END LOOP; \ +-- RETURN x; \ +-- END \ +-- $$ LANGUAGE PLPGSQL; + +-- SELECT x, proc_fors() FROM integers; diff --git a/script/testing/junit/traces/udf.test b/script/testing/junit/traces/udf.test index 8c61e9d468..90414bd5e2 100644 --- a/script/testing/junit/traces/udf.test +++ b/script/testing/junit/traces/udf.test @@ -191,3 +191,124 @@ SELECT x, y, integer_decl() FROM integers; 0 +statement ok + + +statement ok +-- ---------------------------------------------------------------------------- + +statement ok +-- conditional() + +statement ok +-- + +statement ok +-- TODO(Kyle): The final RETURN 0 is unreachable, but we + +statement ok +-- need this temporary hack to deal with missing logic in parser + +statement ok + + +statement ok +CREATE FUNCTION conditional(x INT) RETURNS INT AS $$ BEGIN IF x > 1 THEN RETURN 1; ELSE RETURN 2; END IF; RETURN 0; END $$ LANGUAGE PLPGSQL; + +statement ok + + +query II rowsort +SELECT x, conditional(x) FROM integers; +---- +1 +2 +2 +1 +3 +1 + + +statement ok + + +statement ok +-- ---------------------------------------------------------------------------- + +statement ok +-- proc_while() + +statement ok + + +statement ok +CREATE FUNCTION proc_while() RETURNS INT AS $$ DECLARE x INT := 0; BEGIN WHILE x < 10 LOOP x = x + 1; END LOOP; RETURN x; END $$ LANGUAGE PLPGSQL; + +statement ok + + +query II rowsort +SELECT x, proc_while() FROM integers; +---- +1 +10 +2 +10 +3 +10 + + +statement ok + + +statement ok +-- ---------------------------------------------------------------------------- + +statement ok +-- proc_fori() + +statement ok + + +statement ok +-- CREATE FUNCTION proc_fori() RETURNS INT AS $$ -- DECLARE -- x INT := 0; -- BEGIN -- FOR i IN 1..10 LOOP -- x = x + 1; -- END LOOP; -- RETURN x; -- END -- $$ LANGUAGE PLPGSQL; + +statement ok + + +statement ok +-- SELECT x, proc_fori() FROM integers; + +statement ok + + +statement ok +-- ---------------------------------------------------------------------------- + +statement ok +-- proc_fors() + +statement ok + + +statement ok +-- CREATE TABLE temp(z INT); + +statement ok +-- INSERT INTO temp(z) VALUES (0), (1); + +statement ok + + +statement ok +-- CREATE FUNCTION proc_fors() RETURNS INT AS $$ -- DECLARE \ + +statement ok +-- x INT := 0; -- BEGIN -- FOR v IN SELECT z FROM temp -- LOOP -- x = x + 1; -- END LOOP; -- RETURN x; -- END -- $$ LANGUAGE PLPGSQL; + +statement ok + + +statement ok +-- SELECT x, proc_fors() FROM integers; + diff --git a/src/parser/udf/plpgsql_parser.cpp b/src/parser/udf/plpgsql_parser.cpp index 3b4ce7f0e1..378c672691 100644 --- a/src/parser/udf/plpgsql_parser.cpp +++ b/src/parser/udf/plpgsql_parser.cpp @@ -75,7 +75,8 @@ std::unique_ptr PLpgSQLParser::Parse( throw PARSER_EXCEPTION("Function list has size other than 1"); } - // TODO(Kyle): This is a zip() + // TODO(Kyle): This is a zip(), can we add our own generic + // algorithms library somewhere for stuff like this? std::size_t i{0}; for (const auto &udf_name : param_names) { udf_ast_context_->SetVariableType(udf_name, param_types[i++]); @@ -112,8 +113,11 @@ std::unique_ptr PLpgSQLParser::ParseBlock(const nl for (const auto &stmt : block) { const auto stmt_names = stmt.items().begin(); if (stmt_names.key() == K_PLPGSQL_STMT_RETURN) { + // TODO(Kyle): Handle RETURN without expression + if (stmt[K_PLPGSQL_STMT_RETURN].empty()) { + throw NOT_IMPLEMENTED_EXCEPTION("RETURN without expression not implemented."); + } auto expr = ParseExprSQL(stmt[K_PLPGSQL_STMT_RETURN][K_EXPR][K_PLPGSQL_EXPR][K_QUERY].get()); - // TODO(Kyle): Handle return stmt w/o expression stmts.push_back(std::make_unique(std::move(expr))); } else if (stmt_names.key() == K_PLPGSQL_STMT_IF) { stmts.push_back(ParseIf(stmt[K_PLPGSQL_STMT_IF])); From 0f2154b2d06bc1db4c73a31a85943e22228b44ea Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Thu, 22 Jul 2021 16:05:00 -0400 Subject: [PATCH 077/139] some refactoring of UDF AST context to make querying variables and record types easier --- script/testing/junit/sql/udf.sql | 30 ++--- src/binder/bind_node_visitor.cpp | 23 ++-- src/execution/compiler/udf/udf_codegen.cpp | 91 +++++++++----- src/execution/sql/ddl_executors.cpp | 5 +- src/include/binder/bind_node_visitor.h | 3 + .../execution/ast/udf/udf_ast_context.h | 113 +++++++++++++----- .../execution/compiler/udf/udf_codegen.h | 23 ++++ src/include/parser/udf/plpgsql_parser.h | 6 +- src/parser/udf/plpgsql_parser.cpp | 26 ++-- 9 files changed, 215 insertions(+), 105 deletions(-) diff --git a/script/testing/junit/sql/udf.sql b/script/testing/junit/sql/udf.sql index cc0e7fb9f7..f1ac82c2be 100644 --- a/script/testing/junit/sql/udf.sql +++ b/script/testing/junit/sql/udf.sql @@ -123,19 +123,19 @@ SELECT x, proc_while() FROM integers; -- ---------------------------------------------------------------------------- -- proc_fors() --- CREATE TABLE temp(z INT); --- INSERT INTO temp(z) VALUES (0), (1); - --- CREATE FUNCTION proc_fors() RETURNS INT AS $$ \ --- DECLARE \ --- x INT := 0; \ --- BEGIN \ --- FOR v IN SELECT z FROM temp \ --- LOOP \ --- x = x + 1; \ --- END LOOP; \ --- RETURN x; \ --- END \ --- $$ LANGUAGE PLPGSQL; +CREATE TABLE temp(z INT); +INSERT INTO temp(z) VALUES (0), (1); + +CREATE FUNCTION proc_fors() RETURNS INT AS $$ \ +DECLARE \ + x INT := 0; \ + v RECORD; \ +BEGIN \ + FOR v IN (SELECT z FROM temp) LOOP \ + x = x + 1; \ + END LOOP; \ + RETURN x; \ +END \ +$$ LANGUAGE PLPGSQL; --- SELECT x, proc_fors() FROM integers; +SELECT x, proc_fors() FROM integers; diff --git a/src/binder/bind_node_visitor.cpp b/src/binder/bind_node_visitor.cpp index 8d8bec949b..19e250f1c5 100644 --- a/src/binder/bind_node_visitor.cpp +++ b/src/binder/bind_node_visitor.cpp @@ -700,10 +700,10 @@ void BindNodeVisitor::Visit(common::ManagedPointerGetVariableType(expr->GetColumnName(), &the_type)) { - expr->SetReturnValueType(the_type); + if (BindingForUDF() && udf_ast_context_->HasVariable(expr->GetColumnName())) { + const type::TypeId type = udf_ast_context_->GetVariableTypeFailFast(expr->GetColumnName()); + expr->SetReturnValueType(type); std::size_t idx = 0; if (udf_params_.count(expr->GetColumnName()) == 0) { udf_params_[expr->GetColumnName()] = std::make_pair("", udf_params_.size()); @@ -722,12 +722,15 @@ void BindNodeVisitor::Visit(common::ManagedPointerGetVariableType(expr->GetTableName(), &the_type)) { - NOISEPAGE_ASSERT(the_type == type::TypeId::INVALID, "unknown type"); - auto &fields = udf_ast_context_->GetRecordType(expr->GetTableName()); - auto it = std::find_if(fields.begin(), fields.end(), [=](auto p) { return p.first == expr->GetColumnName(); }); + } else if (BindingForUDF() && udf_ast_context_->HasVariable(expr->GetTableName())) { + const type::TypeId type = udf_ast_context_->GetVariableTypeFailFast(expr->GetTableName()); + NOISEPAGE_ASSERT(type == type::TypeId::INVALID, "Must be a RECORD type"); + + const auto fields = udf_ast_context_->GetRecordTypeFailFast(expr->GetTableName()); + auto it = + std::find_if(fields.cbegin(), fields.cend(), [=](auto p) { return p.first == expr->GetColumnName(); }); std::size_t idx = 0; - if (it != fields.end()) { + if (it != fields.cend()) { if (udf_params_.count(expr->GetColumnName()) == 0) { udf_params_[expr->GetColumnName()] = std::make_pair(expr->GetTableName(), udf_params_.size()); idx = udf_params_.size() - 1; @@ -758,7 +761,7 @@ void BindNodeVisitor::Visit(common::ManagedPointer SqlNodeVisitor::Visit(expr); // If any of the operands are typecasts, the typecast children should have been casted by now. Pull the children up. - for (size_t i = 0; i < expr->GetChildrenSize(); ++i) { + for (std::size_t i = 0; i < expr->GetChildrenSize(); ++i) { auto child = expr->GetChild(i); if (parser::ExpressionType::OPERATOR_CAST == child->GetExpressionType()) { NOISEPAGE_ASSERT(parser::ExpressionType::VALUE_CONSTANT == child->GetChild(0)->GetExpressionType(), @@ -1138,4 +1141,6 @@ void BindNodeVisitor::ValidateAndCorrectInsertValues( } } +bool BindNodeVisitor::BindingForUDF() const { return udf_ast_context_ != nullptr; } + } // namespace noisepage::binder diff --git a/src/execution/compiler/udf/udf_codegen.cpp b/src/execution/compiler/udf/udf_codegen.cpp index 7151f76ec1..83ba28e2af 100644 --- a/src/execution/compiler/udf/udf_codegen.cpp +++ b/src/execution/compiler/udf/udf_codegen.cpp @@ -1,3 +1,6 @@ +#include "execution/compiler/udf/udf_codegen.h" + +#include "common/error/error_code.h" #include "common/error/exception.h" #include "binder/bind_node_visitor.h" @@ -23,7 +26,6 @@ #include "parser/postgresparser.h" #include "execution/ast/udf/udf_ast_nodes.h" -#include "execution/compiler/udf/udf_codegen.h" #include "planner/plannodes/abstract_plan_node.h" @@ -133,18 +135,27 @@ void UdfCodegen::Visit(ast::udf::StmtAST *ast) { UNREACHABLE("Not implemented"); void UdfCodegen::Visit(ast::udf::ExprAST *ast) { UNREACHABLE("Not implemented"); } void UdfCodegen::Visit(ast::udf::DeclStmtAST *ast) { - if (ast->Name() == "*internal*") { + if (ast->Name() == INTERNAL_DECL_ID) { return; } - const execution::ast::Identifier ident = codegen_->MakeFreshIdentifier(ast->Name()); - SymbolTable()[ast->Name()] = ident; + + const execution::ast::Identifier identifier = codegen_->MakeFreshIdentifier(ast->Name()); + SymbolTable()[ast->Name()] = identifier; auto prev_type = current_type_; execution::ast::Expr *tpl_type = nullptr; if (ast->Type() == type::TypeId::INVALID) { - // record type + // Record type execution::util::RegionVector fields{codegen_->GetAstContext()->GetRegion()}; - for (const auto &p : udf_ast_context_->GetRecordType(ast->Name())) { + + // TODO(Kyle): Handle unbound record types + const auto record_type = udf_ast_context_->GetRecordType(ast->Name()); + if (!record_type.has_value()) { + // Unbound record type + throw NOT_IMPLEMENTED_EXCEPTION("Unbound RECORD types not supported"); + } + + for (const auto &p : record_type.value()) { fields.push_back(codegen_->MakeField(codegen_->MakeIdentifier(p.first), codegen_->TplType(execution::sql::GetTypeId(p.second)))); } @@ -157,9 +168,9 @@ void UdfCodegen::Visit(ast::udf::DeclStmtAST *ast) { current_type_ = ast->Type(); if (ast->Initial() != nullptr) { ast->Initial()->Accept(this); - fb_->Append(codegen_->DeclareVar(ident, tpl_type, dst_)); + fb_->Append(codegen_->DeclareVar(identifier, tpl_type, dst_)); } else { - fb_->Append(codegen_->DeclareVarNoInit(ident, tpl_type)); + fb_->Append(codegen_->DeclareVarNoInit(identifier, tpl_type)); } current_type_ = prev_type; } @@ -212,8 +223,7 @@ void UdfCodegen::Visit(ast::udf::ValueExprAST *ast) { } void UdfCodegen::Visit(ast::udf::AssignStmtAST *ast) { - type::TypeId left_type = type::TypeId::INVALID; - udf_ast_context_->GetVariableType(ast->Destination()->Name(), &left_type); + const type::TypeId left_type = GetVariableType(ast->Destination()->Name()); current_type_ = left_type; reinterpret_cast(ast->Source())->Accept(this); @@ -418,8 +428,7 @@ void UdfCodegen::Visit(ast::udf::ForStmtAST *ast) { std::sort(sorted_vec.begin(), sorted_vec.end(), [](auto x, auto y) { return x->second < y->second; }); for (auto entry : sorted_vec) { // TODO(Kyle): Order these - type::TypeId type = type::TypeId::INVALID; - udf_ast_context_->GetVariableType(entry->first, &type); + const type::TypeId type = GetVariableType(entry->first); execution::ast::Builtin builtin{}; switch (type) { case type::TypeId::BOOLEAN: @@ -514,11 +523,10 @@ void UdfCodegen::Visit(ast::udf::SQLStmtAST *ast) { execution::util::RegionVector captures{codegen_->GetAstContext()->GetRegion()}; for (auto &col : cols) { execution::ast::Expr *capture_var = codegen_->MakeExpr(SymbolTable().find(ast->Name())->second); - type::TypeId udf_type{}; - udf_ast_context_->GetVariableType(ast->Name(), &udf_type); - if (udf_type == type::TypeId::INVALID) { + const type::TypeId type = GetVariableType(ast->Name()); + if (type == type::TypeId::INVALID) { // Record type - auto &struct_vars = udf_ast_context_->GetRecordType(ast->Name()); + const auto struct_vars = GetRecordType(ast->Name()); if (captures.empty()) { captures.push_back(capture_var); } @@ -528,10 +536,9 @@ void UdfCodegen::Visit(ast::udf::SQLStmtAST *ast) { assignees.push_back(capture_var); captures.push_back(capture_var); } - auto *type = codegen_->TplType(execution::sql::GetTypeId(col.GetType())); - + auto *tpl_type = codegen_->TplType(execution::sql::GetTypeId(col.GetType())); auto input_param = codegen_->MakeFreshIdentifier("input"); - params.push_back(codegen_->MakeField(input_param, type)); + params.push_back(codegen_->MakeField(input_param, tpl_type)); i++; } @@ -577,18 +584,22 @@ void UdfCodegen::Visit(ast::udf::SQLStmtAST *ast) { std::sort(sorted_vec.begin(), sorted_vec.end(), [](auto x, auto y) { return x->second.second < y->second.second; }); for (auto entry : sorted_vec) { // TODO(Kyle): Order these - type::TypeId type = type::TypeId::INVALID; - execution::ast::Expr *expr = nullptr; - if (entry->second.first.length() > 0) { - auto &fields = udf_ast_context_->GetRecordType(entry->second.first); - auto it = std::find_if(fields.begin(), fields.end(), [=](auto p) { return p.first == entry->first; }); - type = it->second; - expr = codegen_->AccessStructMember(codegen_->MakeExpr(SymbolTable().at(entry->second.first)), - codegen_->MakeIdentifier(entry->first)); - } else { - udf_ast_context_->GetVariableType(entry->first, &type); - expr = codegen_->MakeExpr(SymbolTable().at(entry->first)); - } + + // TODO(Kyle): This IILE is cool and all... but way more + // complex than I would like, all of the logic in this + // function deserves a second look to refactor + auto [type, expr] = [=, &entry]() { + if (entry->second.first.length() > 0) { + const auto fields = GetRecordType(entry->second.first); + auto it = std::find_if(fields.cbegin(), fields.cend(), [=](auto p) { return p.first == entry->first; }); + NOISEPAGE_ASSERT(it != fields.cend(), "Broken invariant"); + return std::pair{ + it->second, codegen_->AccessStructMember(codegen_->MakeExpr(SymbolTable().at(entry->second.first)), + codegen_->MakeIdentifier(entry->first))}; + } + const type::TypeId type = GetVariableType(entry->first); + return std::pair{type, codegen_->MakeExpr(SymbolTable().at(entry->first))}; + }(); execution::ast::Builtin builtin{}; switch (type) { @@ -657,4 +668,22 @@ void UdfCodegen::Visit(ast::udf::MemberExprAST *ast) { dst_ = codegen_->AccessStructMember(object, codegen_->MakeIdentifier(ast->FieldName())); } +type::TypeId UdfCodegen::GetVariableType(const std::string &name) const { + auto type = udf_ast_context_->GetVariableType(name); + if (!type.has_value()) { + throw EXECUTION_EXCEPTION(fmt::format("Failed to resolve type for variable '{}'", name), + common::ErrorCode::ERRCODE_PLPGSQL_ERROR); + } + return type.value(); +} + +std::vector> UdfCodegen::GetRecordType(const std::string &name) const { + auto type = udf_ast_context_->GetRecordType(name); + if (!type.has_value()) { + throw EXECUTION_EXCEPTION(fmt::format("Failed to resolve type for record variable '{}'", name), + common::ErrorCode::ERRCODE_PLPGSQL_ERROR); + } + return type.value(); +} + } // namespace noisepage::execution::compiler::udf diff --git a/src/execution/sql/ddl_executors.cpp b/src/execution/sql/ddl_executors.cpp index bdfca15b56..96abc1c0c6 100644 --- a/src/execution/sql/ddl_executors.cpp +++ b/src/execution/sql/ddl_executors.cpp @@ -73,11 +73,10 @@ bool DDLExecutors::CreateFunctionExecutor(const common::ManagedPointerGetDatabaseOid()}; + parser::udf::PLpgSQLParser udf_parser{common::ManagedPointer{&udf_ast_context}, accessor, node->GetDatabaseOid()}; std::unique_ptr ast{}; try { - ast = udf_parser.Parse(node->GetFunctionParameterNames(), param_type_ids, body, - (common::ManagedPointer(&udf_ast_context))); + ast = udf_parser.Parse(node->GetFunctionParameterNames(), param_type_ids, body); } catch (Exception &e) { return false; } diff --git a/src/include/binder/bind_node_visitor.h b/src/include/binder/bind_node_visitor.h index ce76085dff..851abea615 100644 --- a/src/include/binder/bind_node_visitor.h +++ b/src/include/binder/bind_node_visitor.h @@ -157,6 +157,9 @@ class BindNodeVisitor final : public SqlNodeVisitor { void ValidateAndCorrectInsertValues(common::ManagedPointer node, std::vector> *values, const catalog::Schema &table_schema); + + /** @return `true` if we are binding within the context of a UDF, `false` otherwise */ + bool BindingForUDF() const; }; } // namespace binder diff --git a/src/include/execution/ast/udf/udf_ast_context.h b/src/include/execution/ast/udf/udf_ast_context.h index 49a3be5af4..6d56018adb 100644 --- a/src/include/execution/ast/udf/udf_ast_context.h +++ b/src/include/execution/ast/udf/udf_ast_context.h @@ -1,5 +1,6 @@ #pragma once +#include #include #include #include @@ -14,12 +15,32 @@ namespace noisepage::execution::ast::udf { * throughout construction of the UDF abstract syntax tree. */ class UdfAstContext { + /** An invidual entry for a record type, (name, type ID) */ + using RecordTypeEntry = std::pair; + + /** A full description of a record type */ + using RecordType = std::vector; + public: /** * Construct a new AstContext instance. */ UdfAstContext() = default; + /** + * Add a new variable to the symbol table. + * @param name The name of the variable + */ + void AddVariable(const std::string &name) { local_variables_.push_back(name); } + + /** + * Determine if a variable with name `name` is present in the UDF AST. + * @param name The name of the variable + * @return `true` if the UDF AST context contains a variable + * identified by `name`, `false` otherwise + */ + bool HasVariable(const std::string &name) const { return (symbol_table_.find(name) != symbol_table_.cend()); } + /** * Set the type of the variabel identifed by `name`. * @param name The name of the variable @@ -30,56 +51,86 @@ class UdfAstContext { /** * Get the type of the variable identified by `name`. * @param name The name of the variable - * @param type The out-parameter used to store the result - * @return `true` if the variable is present in the symbol - * table and the Get() succeeds, `false` otherwise + * @return The type ID for the specified variable if present, + * empty optional value otherwise */ - bool GetVariableType(const std::string &name, type::TypeId *type) { + std::optional GetVariableType(const std::string &name) const { auto it = symbol_table_.find(name); - if (it == symbol_table_.end()) { - return false; - } - if (type != nullptr) { - *type = it->second; - } - return true; + return (it == symbol_table_.cend()) ? std::nullopt : std::make_optional(it->second); } /** - * Add a new variable to the symbol table. + * Get the type of the variable identified by `name`. * @param name The name of the variable + * @return The type ID for the specified variable + * + * NOTE: This function terminates the program in the event + * that the variable is not present; for variable queries + * that may fail, use UdfAstContext::GetVariableType(). */ - void AddVariable(const std::string &name) { local_variables_.push_back(name); } + type::TypeId GetVariableTypeFailFast(const std::string &name) const { + auto it = symbol_table_.find(name); + NOISEPAGE_ASSERT(it != symbol_table_.cend(), "Required variable is not present in UDF AST"); + return it->second; + } /** - * Get the local variable at index `index`. - * @param index The index of interest - * @return The name of the variable at the specified index + * Determine if a record variable with name `name` is present in the UDF AST. + * @param name The name of the variable + * @return `true` if the UDF AST context contains a record variable + * identified by `name`, `false` otherwise */ - const std::string &GetLocalVariableAtIndex(const std::size_t index) { - NOISEPAGE_ASSERT(local_variables_.size() >= index, "Index out of range"); - // TODO(Kyle): I moved the subtraction to the call site because - // it seems misleading to have a getter for an index but deliver - // a local that does not actually appear at that index... - return local_variables_.at(index); + bool HasRecord(const std::string &name) const { return (record_types_.find(name) != record_types_.cend()); } + + /** + * Set the record type for the variable identified by `name`. + * @param name The name of the variable + * @param elems The record + */ + void SetRecordType(const std::string &name, std::vector> &&elems) { + record_types_[name] = std::move(elems); } /** - * Get the record type for the specified variable. + * Get the record type for the variable identified by `name`. * @param name The name of the variable - * @return The record + * @return The type of the record variable if present, + * empty optional value otherwise */ - const std::vector> &GetRecordType(const std::string &name) const { - return record_types_.find(name)->second; + std::optional GetRecordType(const std::string &name) const { + auto it = record_types_.find(name); + // TODO(Kyle): I updated the API for this function to use std::optional, + // I like this more, but it makes it impossible to return a reference to + // the underlying data so this now materializes a copy every time + return (it == record_types_.cend()) ? std::nullopt : std::make_optional(it->second); } /** - * Set the record type for the specified variable. + * Get the record type for the variable identified by `name`. * @param name The name of the variable - * @param elems The record + * @return The type of the record variable + * + * NOTE: This function terminates the program in the event + * that the variable is not present; for variable queries + * that may fail, use UdfAstContext::GetRecordType(). */ - void SetRecordType(const std::string &name, std::vector> &&elems) { - record_types_[name] = std::move(elems); + RecordType GetRecordTypeFailFast(const std::string &name) const { + auto it = record_types_.find(name); + NOISEPAGE_ASSERT(it != record_types_.cend(), "Required record variable is not present in UDF AST"); + return it->second; + } + + /** + * Get the local variable at index `index`. + * @param index The index of interest + * @return The name of the variable at the specified index + */ + const std::string &GetLocalVariableAtIndex(const std::size_t index) const { + NOISEPAGE_ASSERT(local_variables_.size() >= index, "Index out of range"); + // TODO(Kyle): I moved the subtraction to the call site because + // it seems misleading to have a getter for an index but deliver + // a local that does not actually appear at that index... + return local_variables_.at(index); } private: @@ -88,7 +139,7 @@ class UdfAstContext { /** Collection of local variable names for the UDF. */ std::vector local_variables_; /** Collection of record types for the UDF. */ - std::unordered_map>> record_types_; + std::unordered_map record_types_; }; } // namespace noisepage::execution::ast::udf diff --git a/src/include/execution/compiler/udf/udf_codegen.h b/src/include/execution/compiler/udf/udf_codegen.h index 9579381101..c4dd9d8d31 100644 --- a/src/include/execution/compiler/udf/udf_codegen.h +++ b/src/include/execution/compiler/udf/udf_codegen.h @@ -1,7 +1,10 @@ #pragma once #include +#include #include +#include +#include #include "execution/ast/udf/udf_ast_context.h" #include "execution/ast/udf/udf_ast_node_visitor.h" @@ -77,6 +80,7 @@ class UdfCodegen : ast::udf::ASTNodeVisitor { ast::udf::UdfAstContext *ast_context, CodeGen *codegen, catalog::db_oid_t db_oid, ast::udf::FunctionAST *root); + private: /** * Generate a UDF from the given abstract syntax tree. * @param ast The AST from which to generate the UDF @@ -222,7 +226,26 @@ class UdfCodegen : ast::udf::ASTNodeVisitor { /** @return An immutable reference to the symbol table */ const std::unordered_map &SymbolTable() const { return symbol_table_; } + /** + * Get the type of the variable identified by `name`. + * @param name The name of the variable + * @return The type of the variable identified by `name` + * @throw EXECUTION_EXCEPTION on failure to resolve type + */ + type::TypeId GetVariableType(const std::string &name) const; + + /** + * Get the type of the record variable identified by `name`. + * @param name The name of the variable + * @return The type of the record variable identified by `name` + * @throw EXECUTION_EXCEPTION on failure to resolve type + */ + std::vector> GetRecordType(const std::string &name) const; + private: + /** The string identifier for internal declarations */ + constexpr static const char INTERNAL_DECL_ID[] = "*internal*"; + /** The catalog access used during code generation */ catalog::CatalogAccessor *accessor_; diff --git a/src/include/parser/udf/plpgsql_parser.h b/src/include/parser/udf/plpgsql_parser.h index 8e77d64943..f501d3490a 100644 --- a/src/include/parser/udf/plpgsql_parser.h +++ b/src/include/parser/udf/plpgsql_parser.h @@ -45,9 +45,9 @@ class PLpgSQLParser { * @param ast_context The AST context to use during parsing * @return The abstract syntax tree for the source function */ - std::unique_ptr Parse( - const std::vector ¶m_names, const std::vector ¶m_types, - const std::string &func_body, common::ManagedPointer ast_context); + std::unique_ptr Parse(const std::vector ¶m_names, + const std::vector ¶m_types, + const std::string &func_body); private: /** diff --git a/src/parser/udf/plpgsql_parser.cpp b/src/parser/udf/plpgsql_parser.cpp index 378c672691..2bb7920c70 100644 --- a/src/parser/udf/plpgsql_parser.cpp +++ b/src/parser/udf/plpgsql_parser.cpp @@ -52,9 +52,9 @@ static constexpr const char DECL_TYPE_ID_VARCHAR[] = "varchar"; static constexpr const char DECL_TYPE_ID_DATE[] = "date"; static constexpr const char DECL_TYPE_ID_RECORD[] = "record"; -std::unique_ptr PLpgSQLParser::Parse( - const std::vector ¶m_names, const std::vector ¶m_types, - const std::string &func_body, common::ManagedPointer ast_context) { +std::unique_ptr PLpgSQLParser::Parse(const std::vector ¶m_names, + const std::vector ¶m_types, + const std::string &func_body) { auto result = pg_query_parse_plpgsql(func_body.c_str()); if (result.error != nullptr) { pg_query_free_plpgsql_parse_result(result); @@ -164,9 +164,9 @@ std::unique_ptr PLpgSQLParser::ParseDecl(const nlo // Detemine if the variable has already been declared; // if so, just re-use this type that has already been resolved - type::TypeId temp_type{}; - if (udf_ast_context_->GetVariableType(var_name, &temp_type)) { - return std::make_unique(var_name, temp_type, std::move(initial)); + const auto resolved_type = udf_ast_context_->GetVariableType(var_name); + if (resolved_type.has_value()) { + return std::make_unique(var_name, resolved_type.value(), std::move(initial)); } // Otherwise, we perform a string comparison with the type identifier @@ -262,20 +262,20 @@ std::unique_ptr PLpgSQLParser::ParseSQL(const nloh } // Check to see if a record type can be bound to this - type::TypeId type{}; - auto ret = udf_ast_context_->GetVariableType(var_name, &type); - if (!ret) { + const auto type = udf_ast_context_->GetVariableType(var_name); + if (!type.has_value()) { throw PARSER_EXCEPTION("PL/pgSQL parser: variable was not declared"); } - if (type == type::TypeId::INVALID) { + if (type.value() == type::TypeId::INVALID) { std::vector> elems{}; const auto &select_columns = parse_result->GetStatement(0).CastManagedPointerTo()->GetSelectColumns(); elems.reserve(select_columns.size()); - for (const auto &col : select_columns) { - elems.emplace_back(col->GetAlias().GetName(), col->GetReturnValueType()); - } + std::transform(select_columns.cbegin(), select_columns.cend(), std::back_inserter(elems), + [](const common::ManagedPointer &column) { + return std::make_pair(column->GetAlias().GetName(), column->GetReturnValueType()); + }); udf_ast_context_->SetRecordType(var_name, std::move(elems)); } From 72a0507a947ad9358eac06226af577ae3adfc2d4 Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Thu, 22 Jul 2021 19:25:36 -0400 Subject: [PATCH 078/139] refactor to add support for integer for loop variant to parser, now need to complete it by adding support in code generation --- src/execution/compiler/udf/udf_codegen.cpp | 11 ++- src/execution/sql/ddl_executors.cpp | 3 +- .../execution/ast/udf/udf_ast_context.h | 34 +++---- .../execution/ast/udf/udf_ast_node_visitor.h | 23 ++--- src/include/execution/ast/udf/udf_ast_nodes.h | 95 ++++++++++++++++--- .../execution/compiler/udf/udf_codegen.h | 12 ++- src/include/parser/udf/plpgsql_parser.h | 23 +++-- src/parser/udf/plpgsql_parser.cpp | 92 +++++++++++------- 8 files changed, 201 insertions(+), 92 deletions(-) diff --git a/src/execution/compiler/udf/udf_codegen.cpp b/src/execution/compiler/udf/udf_codegen.cpp index 83ba28e2af..a269f8bf52 100644 --- a/src/execution/compiler/udf/udf_codegen.cpp +++ b/src/execution/compiler/udf/udf_codegen.cpp @@ -334,8 +334,11 @@ void UdfCodegen::Visit(ast::udf::WhileStmtAST *ast) { loop.EndLoop(); } -void UdfCodegen::Visit(ast::udf::ForStmtAST *ast) { - // Once we encounter a For-statement we know we need an execution context +void UdfCodegen::Visit(ast::udf::ForIStmtAST *ast) { throw NOT_IMPLEMENTED_EXCEPTION("ForIStmtAST Not Implemented"); } + +void UdfCodegen::Visit(ast::udf::ForSStmtAST *ast) { + // Once we encounter a for-statement we know we need an execution + // context because the loop always draws values from a query needs_exec_ctx_ = true; const auto query = common::ManagedPointer(ast->Query()); @@ -352,7 +355,7 @@ void UdfCodegen::Visit(ast::udf::ForStmtAST *ast) { auto plan = optimizer_result->GetPlanNode(); // Make a lambda that just writes into this - std::vector var_idents; + std::vector var_idents{}; auto lam_var = codegen_->MakeFreshIdentifier("looplamb"); execution::util::RegionVector params(codegen_->GetAstContext()->GetRegion()); params.push_back(codegen_->MakeField( @@ -420,7 +423,7 @@ void UdfCodegen::Visit(ast::udf::ForStmtAST *ast) { // Set its execution context to whatever exec context was passed in here fb_->Append(codegen_->CallBuiltin(execution::ast::Builtin::StartNewParams, {exec_ctx})); - std::vector>::iterator> sorted_vec; + std::vector>::iterator> sorted_vec{}; for (auto it = query_params.begin(); it != query_params.end(); it++) { sorted_vec.push_back(it); } diff --git a/src/execution/sql/ddl_executors.cpp b/src/execution/sql/ddl_executors.cpp index 96abc1c0c6..41a654665f 100644 --- a/src/execution/sql/ddl_executors.cpp +++ b/src/execution/sql/ddl_executors.cpp @@ -77,7 +77,8 @@ bool DDLExecutors::CreateFunctionExecutor(const common::ManagedPointer ast{}; try { ast = udf_parser.Parse(node->GetFunctionParameterNames(), param_type_ids, body); - } catch (Exception &e) { + } catch (const ParserException &e) { + PARSER_LOG_ERROR(e.what()); return false; } diff --git a/src/include/execution/ast/udf/udf_ast_context.h b/src/include/execution/ast/udf/udf_ast_context.h index 6d56018adb..2e039cf3d7 100644 --- a/src/include/execution/ast/udf/udf_ast_context.h +++ b/src/include/execution/ast/udf/udf_ast_context.h @@ -28,10 +28,23 @@ class UdfAstContext { UdfAstContext() = default; /** - * Add a new variable to the symbol table. + * Push a new local variable. * @param name The name of the variable */ - void AddVariable(const std::string &name) { local_variables_.push_back(name); } + void AddLocal(const std::string &name) { locals_.push_back(name); } + + /** + * Get the local variable at index `index`. + * @param index The index of interest + * @return The name of the variable at the specified index + */ + const std::string &GetLocalAtIndex(const std::size_t index) const { + NOISEPAGE_ASSERT(locals_.size() >= index, "Index out of range"); + // TODO(Kyle): I moved the subtraction to the call site because + // it seems misleading to have a getter for an index but deliver + // a local that does not actually appear at that index... + return locals_.at(index); + } /** * Determine if a variable with name `name` is present in the UDF AST. @@ -120,24 +133,11 @@ class UdfAstContext { return it->second; } - /** - * Get the local variable at index `index`. - * @param index The index of interest - * @return The name of the variable at the specified index - */ - const std::string &GetLocalVariableAtIndex(const std::size_t index) const { - NOISEPAGE_ASSERT(local_variables_.size() >= index, "Index out of range"); - // TODO(Kyle): I moved the subtraction to the call site because - // it seems misleading to have a getter for an index but deliver - // a local that does not actually appear at that index... - return local_variables_.at(index); - } - private: + /** Collection of local variable names for the UDF. */ + std::vector locals_; /** The symbol table for the UDF. */ std::unordered_map symbol_table_; - /** Collection of local variable names for the UDF. */ - std::vector local_variables_; /** Collection of record types for the UDF. */ std::unordered_map record_types_; }; diff --git a/src/include/execution/ast/udf/udf_ast_node_visitor.h b/src/include/execution/ast/udf/udf_ast_node_visitor.h index 4c24b4a2a2..9a115c5ec6 100644 --- a/src/include/execution/ast/udf/udf_ast_node_visitor.h +++ b/src/include/execution/ast/udf/udf_ast_node_visitor.h @@ -1,9 +1,6 @@ #pragma once -namespace noisepage { -namespace execution { -namespace ast { -namespace udf { +namespace noisepage::execution::ast::udf { class AbstractAST; class StmtAST; @@ -22,7 +19,8 @@ class RetStmtAST; class AssignStmtAST; class SQLStmtAST; class DynamicSQLStmtAST; -class ForStmtAST; +class ForIStmtAST; +class ForSStmtAST; class FunctionAST; /** @@ -133,10 +131,16 @@ class ASTNodeVisitor { virtual void Visit(AssignStmtAST *ast) = 0; /** - * Visit an ForStmtAST node. + * Visit a ForIStmtAST node. * @param ast The node to visit */ - virtual void Visit(ForStmtAST *ast) = 0; + virtual void Visit(ForIStmtAST *ast) = 0; + + /** + * Visit an ForSStmtAST node. + * @param ast The node to visit + */ + virtual void Visit(ForSStmtAST *ast) = 0; /** * Visit an SQLStmtAST node. @@ -151,7 +155,4 @@ class ASTNodeVisitor { virtual void Visit(DynamicSQLStmtAST *ast) = 0; }; -} // namespace udf -} // namespace ast -} // namespace execution -} // namespace noisepage +} // namespace noisepage::execution::ast::udf diff --git a/src/include/execution/ast/udf/udf_ast_nodes.h b/src/include/execution/ast/udf/udf_ast_nodes.h index 376c9d242a..42c0b82d3d 100644 --- a/src/include/execution/ast/udf/udf_ast_nodes.h +++ b/src/include/execution/ast/udf/udf_ast_nodes.h @@ -13,10 +13,7 @@ #include "execution/ast/udf/udf_ast_node_visitor.h" #include "execution/sql/value.h" -namespace noisepage { -namespace execution { -namespace ast { -namespace udf { +namespace noisepage::execution::ast::udf { /** * The AbstractAST class serves as a base class for all AST nodes. @@ -399,18 +396,93 @@ class IfStmtAST : public StmtAST { }; /** - * The ForStmtAST class represents a `for`-loop construct. + * The ForIStmtAST class represents a `for`-loop construct. + * + * Ex: FOR i IN 1..10 LOOP... */ -class ForStmtAST : public StmtAST { +class ForIStmtAST : public StmtAST { public: /** - * Construct a new ForStmtAST instance. + * The default query that defines the "step" expression. + * + * The PLpgSQL documentation specifies this behavior. + */ + constexpr static const char DEFAULT_STEP_EXPR[] = "SELECT 1"; + + /** + * Construct a new ForIStmtAST instance. + * @param variables The collection of variables in the loop + * @param body The body of the loop + */ + ForIStmtAST(std::string variable, std::unique_ptr lower, std::unique_ptr upper, + std::unique_ptr step, std::unique_ptr body) + : variable_{std::move(variable)}, + lower_{std::move(lower)}, + upper_{std::move(upper)}, + step_{std::move(step)}, + body_{std::move(body)} {} + + /** + * AST visitor pattern. + * @param visitor The visitor + */ + void Accept(ASTNodeVisitor *visitor) override { visitor->Visit(this); }; + + /** @return The loop variable */ + const std::string &Variable() const { return variable_; } + + /** @return A mutable pointer to the loop lower-bound expression */ + ExprAST *Lower() { return lower_.get(); } + + /** @return An immutable pointer to the loop lower-bound expression */ + const ExprAST *Lower() const { return lower_.get(); } + + /** @return A mutable pointer to the loop upper-bound expression */ + ExprAST *Upper() { return upper_.get(); } + + /** @return An immutable pointer to the loop upper-bound expression */ + const ExprAST *Upper() const { return upper_.get(); } + + /** @return A mutable pointer to the loop step expression */ + ExprAST *Step() { return step_.get(); } + + /** @return An immutable pointer to the loop step expression */ + const ExprAST *Step() const { return step_.get(); } + + /** @return A mutable pointer to the loop body statement */ + StmtAST *Body() { return body_.get(); } + + /** @return An immutable pointer to the loop body statement */ + const StmtAST *Body() const { return body_.get(); } + + private: + /** The identifier for the loop variable */ + const std::string variable_; + /** The expression that defines the loop lower-bound */ + std::unique_ptr lower_; + /** The expression that defines the loop upper-bound */ + std::unique_ptr upper_; + /** The expression that defines the loop step */ + std::unique_ptr step_; + /** The loop body */ + std::unique_ptr body_; +}; + +/** + * The ForSStmtAST class represents a `for`-loop construct. + * + * Ex: FOR record IN (SELECT * FROM tmp) LOOP ... + */ +class ForSStmtAST : public StmtAST { + public: + /** + * Construct a new ForSStmtAST instance. * @param variables The collection of variables in the loop * @param query The associated query * @param body The body of the loop */ - ForStmtAST(std::vector &&variables, std::unique_ptr &&query, - std::unique_ptr body) + ForSStmtAST(std::vector &&variables, std::unique_ptr &&query, + std::unique_ptr body) : variables_{std::move(variables)}, query_{std::move(query)}, body_{std::move(body)} {} /** @@ -684,7 +756,4 @@ class FunctionAST : public AbstractAST { std::unique_ptr LogError(const char *str); -} // namespace udf -} // namespace ast -} // namespace execution -} // namespace noisepage +} // namespace noisepage::execution::ast::udf diff --git a/src/include/execution/compiler/udf/udf_codegen.h b/src/include/execution/compiler/udf/udf_codegen.h index c4dd9d8d31..c7da0457a1 100644 --- a/src/include/execution/compiler/udf/udf_codegen.h +++ b/src/include/execution/compiler/udf/udf_codegen.h @@ -38,7 +38,7 @@ class SQLStmtAST; class FunctionAST; class IsNullExprAST; class DynamicSQLStmtAST; -class ForStmtAST; +class ForSStmtAST; } // namespace ast::udf namespace compiler::udf { @@ -189,10 +189,16 @@ class UdfCodegen : ast::udf::ASTNodeVisitor { void Visit(ast::udf::DynamicSQLStmtAST *ast) override; /** - * Visit a ForStmtAST node. + * Visit a ForIStmtAST node. * @param ast The AST node to visit */ - void Visit(ast::udf::ForStmtAST *ast) override; + void Visit(ast::udf::ForIStmtAST *ast) override; + + /** + * Visit a ForSStmtAST node. + * @param ast The AST node to visit + */ + void Visit(ast::udf::ForSStmtAST *ast) override; /** * Visit a MemberExprAST node. diff --git a/src/include/parser/udf/plpgsql_parser.h b/src/include/parser/udf/plpgsql_parser.h index f501d3490a..15465807a0 100644 --- a/src/include/parser/udf/plpgsql_parser.h +++ b/src/include/parser/udf/plpgsql_parser.h @@ -86,28 +86,35 @@ class PLpgSQLParser { std::unique_ptr ParseWhile(const nlohmann::json &loop); /** - * Parse a for-statement. - * @param block The input JSON object + * Parse a for-statement (integer variant). + * @param loop The input JSON object + * @return The AST for the for-statement + */ + std::unique_ptr ParseForI(const nlohmann::json &loop); + + /** + * Parse a for-statement (query variant). + * @param loop The input JSON object * @return The AST for the for-statement */ - std::unique_ptr ParseFor(const nlohmann::json &loop); + std::unique_ptr ParseForS(const nlohmann::json &loop); /** * Parse a SQL statement. - * @param sql_stmt The input JSON object + * @param sql The input JSON object * @return The AST for the SQL statement */ - std::unique_ptr ParseSQL(const nlohmann::json &sql_stmt); + std::unique_ptr ParseSQL(const nlohmann::json &sql); /** * Parse a dynamic SQL statement. - * @param block The input JSON object + * @param sql The input JSON object * @return The AST for the dynamic SQL statement */ - std::unique_ptr ParseDynamicSQL(const nlohmann::json &sql_stmt); + std::unique_ptr ParseDynamicSQL(const nlohmann::json &sql); /** - * Parse a SQL expression. + * Parse a SQL expression to an expression AST. * @param sql The SQL expression string * @return The AST for the SQL expression */ diff --git a/src/parser/udf/plpgsql_parser.cpp b/src/parser/udf/plpgsql_parser.cpp index 2bb7920c70..ce5fa4564e 100644 --- a/src/parser/udf/plpgsql_parser.cpp +++ b/src/parser/udf/plpgsql_parser.cpp @@ -27,6 +27,7 @@ static constexpr const char K_PLPGSQL_STMT_RETURN[] = "PLpgSQL_stmt_return"; static constexpr const char K_PLPGSQL_STMT_IF[] = "PLpgSQL_stmt_if"; static constexpr const char K_PLPGSQL_STMT_WHILE[] = "PLpgSQL_stmt_while"; static constexpr const char K_PLPGSQL_STMT_FORS[] = "PLpgSQL_stmt_fors"; +static constexpr const char K_PLPGSQL_STMT_FORI[] = "PLpgSQL_stmt_fori"; static constexpr const char K_COND[] = "cond"; static constexpr const char K_THEN_BODY[] = "then_body"; static constexpr const char K_ELSE_BODY[] = "else_body"; @@ -42,6 +43,10 @@ static constexpr const char K_FIELDS[] = "fields"; static constexpr const char K_NAME[] = "name"; static constexpr const char K_PLPGSQL_ROW[] = "PLpgSQL_row"; static constexpr const char K_PLPGSQL_STMT_DYNEXECUTE[] = "PLpgSQL_stmt_dynexecute"; +static constexpr const char K_LOWER[] = "lower"; +static constexpr const char K_UPPER[] = "upper"; +static constexpr const char K_STEP[] = "step"; +static constexpr const char K_VAR[] = "var"; /** Variable declaration type identifiers */ static constexpr const char DECL_TYPE_ID_INT[] = "int"; @@ -71,6 +76,7 @@ std::unique_ptr PLpgSQLParser::Parse(const std ss >> ast_json; const auto function_list = ast_json[K_FUNCTION_LIST]; NOISEPAGE_ASSERT(function_list.is_array(), "Function list is not an array"); + if (function_list.size() != 1) { throw PARSER_EXCEPTION("Function list has size other than 1"); } @@ -109,48 +115,53 @@ std::unique_ptr PLpgSQLParser::ParseBlock(const nl throw PARSER_EXCEPTION("PL/pgSQL parser : Empty block is not supported"); } - std::vector> stmts{}; - for (const auto &stmt : block) { - const auto stmt_names = stmt.items().begin(); - if (stmt_names.key() == K_PLPGSQL_STMT_RETURN) { + std::vector> statements{}; + for (const auto &statement : block) { + std::cout << statement << std::endl; + const std::string &statement_type = statement.items().begin().key(); + if (statement_type == K_PLPGSQL_STMT_RETURN) { // TODO(Kyle): Handle RETURN without expression - if (stmt[K_PLPGSQL_STMT_RETURN].empty()) { + if (statement[K_PLPGSQL_STMT_RETURN].empty()) { throw NOT_IMPLEMENTED_EXCEPTION("RETURN without expression not implemented."); } - auto expr = ParseExprSQL(stmt[K_PLPGSQL_STMT_RETURN][K_EXPR][K_PLPGSQL_EXPR][K_QUERY].get()); - stmts.push_back(std::make_unique(std::move(expr))); - } else if (stmt_names.key() == K_PLPGSQL_STMT_IF) { - stmts.push_back(ParseIf(stmt[K_PLPGSQL_STMT_IF])); - } else if (stmt_names.key() == K_PLPGSQL_STMT_ASSIGN) { + auto expr = ParseExprSQL(statement[K_PLPGSQL_STMT_RETURN][K_EXPR][K_PLPGSQL_EXPR][K_QUERY].get()); + statements.push_back(std::make_unique(std::move(expr))); + } else if (statement_type == K_PLPGSQL_STMT_IF) { + statements.push_back(ParseIf(statement[K_PLPGSQL_STMT_IF])); + } else if (statement_type == K_PLPGSQL_STMT_ASSIGN) { // TODO(Kyle): Need to fix Assignment expression / statement // NOTE(Kyle): We subtract 1 here because variable numbers from // the Postres parser index from 1 rather than 0 (?) const auto &var_name = - udf_ast_context_->GetLocalVariableAtIndex(stmt[K_PLPGSQL_STMT_ASSIGN][K_VARNO].get() - 1); + udf_ast_context_->GetLocalAtIndex(statement[K_PLPGSQL_STMT_ASSIGN][K_VARNO].get() - 1); auto lhs = std::make_unique(var_name); - auto rhs = ParseExprSQL(stmt[K_PLPGSQL_STMT_ASSIGN][K_EXPR][K_PLPGSQL_EXPR][K_QUERY].get()); - stmts.push_back(std::make_unique(std::move(lhs), std::move(rhs))); - } else if (stmt_names.key() == K_PLPGSQL_STMT_WHILE) { - stmts.push_back(ParseWhile(stmt[K_PLPGSQL_STMT_WHILE])); - } else if (stmt_names.key() == K_PLPGSQL_STMT_FORS) { - stmts.push_back(ParseFor(stmt[K_PLPGSQL_STMT_FORS])); - } else if (stmt_names.key() == K_PLGPSQL_STMT_EXECSQL) { - stmts.push_back(ParseSQL(stmt[K_PLGPSQL_STMT_EXECSQL])); - } else if (stmt_names.key() == K_PLPGSQL_STMT_DYNEXECUTE) { - stmts.push_back(ParseDynamicSQL(stmt[K_PLPGSQL_STMT_DYNEXECUTE])); + auto rhs = ParseExprSQL(statement[K_PLPGSQL_STMT_ASSIGN][K_EXPR][K_PLPGSQL_EXPR][K_QUERY].get()); + statements.push_back(std::make_unique(std::move(lhs), std::move(rhs))); + } else if (statement_type == K_PLPGSQL_STMT_WHILE) { + statements.push_back(ParseWhile(statement[K_PLPGSQL_STMT_WHILE])); + } else if (statement_type == K_PLPGSQL_STMT_FORI) { + statements.push_back(ParseForI(statement[K_PLPGSQL_STMT_FORI])); + } else if (statement_type == K_PLPGSQL_STMT_FORS) { + statements.push_back(ParseForS(statement[K_PLPGSQL_STMT_FORS])); + } else if (statement_type == K_PLGPSQL_STMT_EXECSQL) { + statements.push_back(ParseSQL(statement[K_PLGPSQL_STMT_EXECSQL])); + } else if (statement_type == K_PLPGSQL_STMT_DYNEXECUTE) { + statements.push_back(ParseDynamicSQL(statement[K_PLPGSQL_STMT_DYNEXECUTE])); } else { - throw PARSER_EXCEPTION("Statement type not supported"); + throw PARSER_EXCEPTION(fmt::format("Statement type '{}' not supported", statement_type)); } } - return std::make_unique(std::move(stmts)); + return std::make_unique(std::move(statements)); } std::unique_ptr PLpgSQLParser::ParseDecl(const nlohmann::json &decl) { const auto &decl_names = decl.items().begin(); if (decl_names.key() == K_PLPGSQL_VAR) { auto var_name = decl[K_PLPGSQL_VAR][K_REFNAME].get(); - udf_ast_context_->AddVariable(var_name); + + // Track the local variable (for assignment) + udf_ast_context_->AddLocal(var_name); // Grab the type identifier from the PL/pgSQL parser const std::string type = StringUtils::Strip( @@ -193,7 +204,7 @@ std::unique_ptr PLpgSQLParser::ParseDecl(const nlo udf_ast_context_->SetVariableType(var_name, type::TypeId::INVALID); return std::make_unique(var_name, type::TypeId::INVALID, std::move(initial)); } - NOISEPAGE_ASSERT(false, "Unsupported Type"); + throw PARSER_EXCEPTION(fmt::format("Unsupported type '{}' for variable '{}'", type, var_name)); } else if (decl_names.key() == K_PLPGSQL_ROW) { const auto var_name = decl[K_PLPGSQL_ROW][K_REFNAME].get(); NOISEPAGE_ASSERT(var_name == "*internal*", "Unexpected refname"); @@ -224,7 +235,18 @@ std::unique_ptr PLpgSQLParser::ParseWhile(const nl return std::make_unique(std::move(cond_expr), std::move(body_stmt)); } -std::unique_ptr PLpgSQLParser::ParseFor(const nlohmann::json &loop) { +std::unique_ptr PLpgSQLParser::ParseForI(const nlohmann::json &loop) { + const auto name = loop[K_VAR][K_PLPGSQL_VAR][K_REFNAME].get(); + auto lower = ParseExprSQL(loop[K_LOWER][K_PLPGSQL_EXPR][K_QUERY]); + auto upper = ParseExprSQL(loop[K_UPPER][K_PLPGSQL_EXPR][K_QUERY]); + auto step = loop.contains(K_STEP) ? ParseExprSQL(loop[K_STEP][K_PLPGSQL_EXPR][K_QUERY]) + : ParseExprSQL(execution::ast::udf::ForIStmtAST::DEFAULT_STEP_EXPR); + auto body = ParseBlock(loop[K_BODY]); + return std::make_unique(name, std::move(lower), std::move(upper), std::move(step), + std::move(body)); +} + +std::unique_ptr PLpgSQLParser::ParseForS(const nlohmann::json &loop) { const auto sql_query = loop[K_QUERY][K_PLPGSQL_EXPR][K_QUERY].get(); auto parse_result = PostgresParser::BuildParseTree(sql_query); if (parse_result == nullptr) { @@ -236,15 +258,15 @@ std::unique_ptr PLpgSQLParser::ParseFor(const nloh variables.reserve(var_array.size()); std::transform(var_array.cbegin(), var_array.cend(), std::back_inserter(variables), [](const nlohmann::json &var) { return var[K_NAME].get(); }); - return std::make_unique(std::move(variables), std::move(parse_result), - std::move(body_stmt)); + return std::make_unique(std::move(variables), std::move(parse_result), + std::move(body_stmt)); } -std::unique_ptr PLpgSQLParser::ParseSQL(const nlohmann::json &sql_stmt) { +std::unique_ptr PLpgSQLParser::ParseSQL(const nlohmann::json &sql) { // The query text - const auto sql_query = sql_stmt[K_SQLSTMT][K_PLPGSQL_EXPR][K_QUERY].get(); + const auto sql_query = sql[K_SQLSTMT][K_PLPGSQL_EXPR][K_QUERY].get(); // The variable name (non-const for later std::move) - auto var_name = sql_stmt[K_ROW][K_PLPGSQL_ROW][K_FIELDS][0][K_NAME].get(); + auto var_name = sql[K_ROW][K_PLPGSQL_ROW][K_FIELDS][0][K_NAME].get(); auto parse_result = PostgresParser::BuildParseTree(sql_query); if (parse_result == nullptr) { @@ -283,9 +305,9 @@ std::unique_ptr PLpgSQLParser::ParseSQL(const nloh std::move(query_params)); } -std::unique_ptr PLpgSQLParser::ParseDynamicSQL(const nlohmann::json &sql_stmt) { - auto sql_expr = ParseExprSQL(sql_stmt[K_QUERY][K_PLPGSQL_EXPR][K_QUERY].get()); - auto var_name = sql_stmt[K_ROW][K_PLPGSQL_ROW][K_FIELDS][0][K_NAME].get(); +std::unique_ptr PLpgSQLParser::ParseDynamicSQL(const nlohmann::json &sql) { + auto sql_expr = ParseExprSQL(sql[K_QUERY][K_PLPGSQL_EXPR][K_QUERY].get()); + auto var_name = sql[K_ROW][K_PLPGSQL_ROW][K_FIELDS][0][K_NAME].get(); return std::make_unique(std::move(sql_expr), std::move(var_name)); } @@ -300,7 +322,7 @@ std::unique_ptr PLpgSQLParser::ParseExprSQL(const NOISEPAGE_ASSERT(stmt.CastManagedPointerTo()->GetSelectTable() == nullptr, "Unsupported SQL Expr in UDF"); auto &select_list = stmt.CastManagedPointerTo()->GetSelectColumns(); - NOISEPAGE_ASSERT(select_list.size() == 1, "Unsupported number of select columns in udf"); + NOISEPAGE_ASSERT(select_list.size() == 1, "Unsupported number of select columns in UDF"); return PLpgSQLParser::ParseExpr(select_list[0]); } From 66c222b47fd016c9628bcff4c58151b85ca3d1b4 Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Thu, 22 Jul 2021 19:29:22 -0400 Subject: [PATCH 079/139] make clang tidy happy by removing else-if after throw --- src/parser/udf/plpgsql_parser.cpp | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/parser/udf/plpgsql_parser.cpp b/src/parser/udf/plpgsql_parser.cpp index ce5fa4564e..c00f5fc9dd 100644 --- a/src/parser/udf/plpgsql_parser.cpp +++ b/src/parser/udf/plpgsql_parser.cpp @@ -156,8 +156,8 @@ std::unique_ptr PLpgSQLParser::ParseBlock(const nl } std::unique_ptr PLpgSQLParser::ParseDecl(const nlohmann::json &decl) { - const auto &decl_names = decl.items().begin(); - if (decl_names.key() == K_PLPGSQL_VAR) { + const auto &declaration_type = decl.items().begin().key(); + if (declaration_type == K_PLPGSQL_VAR) { auto var_name = decl[K_PLPGSQL_VAR][K_REFNAME].get(); // Track the local variable (for assignment) @@ -204,8 +204,11 @@ std::unique_ptr PLpgSQLParser::ParseDecl(const nlo udf_ast_context_->SetVariableType(var_name, type::TypeId::INVALID); return std::make_unique(var_name, type::TypeId::INVALID, std::move(initial)); } + throw PARSER_EXCEPTION(fmt::format("Unsupported type '{}' for variable '{}'", type, var_name)); - } else if (decl_names.key() == K_PLPGSQL_ROW) { + } + + if (declaration_type == K_PLPGSQL_ROW) { const auto var_name = decl[K_PLPGSQL_ROW][K_REFNAME].get(); NOISEPAGE_ASSERT(var_name == "*internal*", "Unexpected refname"); @@ -215,7 +218,7 @@ std::unique_ptr PLpgSQLParser::ParseDecl(const nlo } // TODO(Kyle): Need to handle other types like row, table etc; - throw PARSER_EXCEPTION("Declaration type not supported"); + throw PARSER_EXCEPTION(fmt::format("Declaration type '{}' not supported", declaration_type)); } std::unique_ptr PLpgSQLParser::ParseIf(const nlohmann::json &branch) { From d5fc469cbaaa2bf78c26414e7731817fd3b0b8f1 Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Thu, 22 Jul 2021 22:33:36 -0400 Subject: [PATCH 080/139] refactor formal parameters for plpgsql parser to reflect types --- src/include/parser/udf/plpgsql_parser.h | 47 +++++----- src/parser/udf/plpgsql_parser.cpp | 114 ++++++++++++------------ 2 files changed, 79 insertions(+), 82 deletions(-) diff --git a/src/include/parser/udf/plpgsql_parser.h b/src/include/parser/udf/plpgsql_parser.h index 15465807a0..792342827f 100644 --- a/src/include/parser/udf/plpgsql_parser.h +++ b/src/include/parser/udf/plpgsql_parser.h @@ -52,80 +52,81 @@ class PLpgSQLParser { private: /** * Parse a block statement. - * @param block The input JSON object + * @param json The input JSON object * @return The AST for the block */ - std::unique_ptr ParseBlock(const nlohmann::json &block); + std::unique_ptr ParseBlock(const nlohmann::json &json); /** * Parse a function statement. - * @param block The input JSON object - * @return The AST for the function + * @param json The input JSON object + * @return json AST for the function */ - std::unique_ptr ParseFunction(const nlohmann::json &function); + std::unique_ptr ParseFunction(const nlohmann::json &json); /** * Parse a declaration statement. - * @param decl The input JSON object + * @param json The input JSON object * @return The AST for the declaration */ - std::unique_ptr ParseDecl(const nlohmann::json &decl); + std::unique_ptr ParseDecl(const nlohmann::json &json); /** * Parse an if-statement. - * @param block The input JSON object + * @param json The input JSON object * @return The AST for the if-statement */ - std::unique_ptr ParseIf(const nlohmann::json &branch); + std::unique_ptr ParseIf(const nlohmann::json &json); /** * Parse a while-statement. - * @param block The input JSON object + * @param json The input JSON object * @return The AST for the while-statement */ - std::unique_ptr ParseWhile(const nlohmann::json &loop); + std::unique_ptr ParseWhile(const nlohmann::json &json); /** * Parse a for-statement (integer variant). - * @param loop The input JSON object + * @param json The input JSON object * @return The AST for the for-statement */ - std::unique_ptr ParseForI(const nlohmann::json &loop); + std::unique_ptr ParseForI(const nlohmann::json &json); /** * Parse a for-statement (query variant). - * @param loop The input JSON object + * @param json The input JSON object * @return The AST for the for-statement */ - std::unique_ptr ParseForS(const nlohmann::json &loop); + std::unique_ptr ParseForS(const nlohmann::json &json); /** * Parse a SQL statement. - * @param sql The input JSON object + * @param json The input JSON object * @return The AST for the SQL statement */ - std::unique_ptr ParseSQL(const nlohmann::json &sql); + std::unique_ptr ParseSQL(const nlohmann::json &json); /** * Parse a dynamic SQL statement. - * @param sql The input JSON object + * @param json The input JSON object * @return The AST for the dynamic SQL statement */ - std::unique_ptr ParseDynamicSQL(const nlohmann::json &sql); + std::unique_ptr ParseDynamicSQL(const nlohmann::json &json); /** * Parse a SQL expression to an expression AST. * @param sql The SQL expression string * @return The AST for the SQL expression */ - std::unique_ptr ParseExprSQL(const std::string &sql); + std::unique_ptr ParseExprFromSQL(const std::string &sql); /** - * Parse an expression. - * @param expr The expression + * Parse an abstract expression to an expression AST. + * @param expr The abstract expression * @return The AST for the expression */ - std::unique_ptr ParseExpr(common::ManagedPointer expr); + std::unique_ptr ParseExprFromAbstract( + common::ManagedPointer expr); private: /** The UDF AST context */ diff --git a/src/parser/udf/plpgsql_parser.cpp b/src/parser/udf/plpgsql_parser.cpp index c00f5fc9dd..f568e9f874 100644 --- a/src/parser/udf/plpgsql_parser.cpp +++ b/src/parser/udf/plpgsql_parser.cpp @@ -66,9 +66,7 @@ std::unique_ptr PLpgSQLParser::Parse(const std throw PARSER_EXCEPTION("PL/pgSQL parsing error"); } // The result is a list, we need to wrap it - const auto ast_json_str = - "{ \"" + std::string{K_FUNCTION_LIST} + "\" : " + std::string{result.plpgsql_funcs} + " }"; // NOLINT - + const auto ast_json_str = fmt::format("{{ \"{}\" : {} }}", K_FUNCTION_LIST, result.plpgsql_funcs); pg_query_free_plpgsql_parse_result(result); std::istringstream ss{ast_json_str}; @@ -83,23 +81,22 @@ std::unique_ptr PLpgSQLParser::Parse(const std // TODO(Kyle): This is a zip(), can we add our own generic // algorithms library somewhere for stuff like this? - std::size_t i{0}; + std::size_t i = 0; for (const auto &udf_name : param_names) { udf_ast_context_->SetVariableType(udf_name, param_types[i++]); } const auto function = function_list[0][K_PLPGSQL_FUNCTION]; - auto function_ast = - std::make_unique(ParseFunction(function), param_names, param_types); - return function_ast; + return std::make_unique(ParseFunction(function), param_names, param_types); } -std::unique_ptr PLpgSQLParser::ParseFunction(const nlohmann::json &function) { - const auto decl_list = function[K_DATUMS]; - const auto function_body = function[K_ACTION][K_PLPGSQL_STMT_BLOCK][K_BODY]; +std::unique_ptr PLpgSQLParser::ParseFunction(const nlohmann::json &json) { + const auto decl_list = json[K_DATUMS]; + NOISEPAGE_ASSERT(decl_list.is_array(), "Declaration list is not an array"); + + const auto function_body = json[K_ACTION][K_PLPGSQL_STMT_BLOCK][K_BODY]; std::vector> stmts{}; - NOISEPAGE_ASSERT(decl_list.is_array(), "Declaration list is not an array"); for (std::size_t i = 1UL; i < decl_list.size(); i++) { stmts.push_back(ParseDecl(decl_list[i])); } @@ -108,15 +105,15 @@ std::unique_ptr PLpgSQLParser::ParseFunction(const return std::make_unique(std::move(stmts)); } -std::unique_ptr PLpgSQLParser::ParseBlock(const nlohmann::json &block) { +std::unique_ptr PLpgSQLParser::ParseBlock(const nlohmann::json &json) { // TODO(boweic): Support statements size other than 1 - NOISEPAGE_ASSERT(block.is_array(), "Block isn't array"); - if (block.empty()) { + NOISEPAGE_ASSERT(json.is_array(), "Block isn't array"); + if (json.empty()) { throw PARSER_EXCEPTION("PL/pgSQL parser : Empty block is not supported"); } std::vector> statements{}; - for (const auto &statement : block) { + for (const auto &statement : json) { std::cout << statement << std::endl; const std::string &statement_type = statement.items().begin().key(); if (statement_type == K_PLPGSQL_STMT_RETURN) { @@ -124,7 +121,8 @@ std::unique_ptr PLpgSQLParser::ParseBlock(const nl if (statement[K_PLPGSQL_STMT_RETURN].empty()) { throw NOT_IMPLEMENTED_EXCEPTION("RETURN without expression not implemented."); } - auto expr = ParseExprSQL(statement[K_PLPGSQL_STMT_RETURN][K_EXPR][K_PLPGSQL_EXPR][K_QUERY].get()); + auto expr = + ParseExprFromSQL(statement[K_PLPGSQL_STMT_RETURN][K_EXPR][K_PLPGSQL_EXPR][K_QUERY].get()); statements.push_back(std::make_unique(std::move(expr))); } else if (statement_type == K_PLPGSQL_STMT_IF) { statements.push_back(ParseIf(statement[K_PLPGSQL_STMT_IF])); @@ -135,7 +133,7 @@ std::unique_ptr PLpgSQLParser::ParseBlock(const nl const auto &var_name = udf_ast_context_->GetLocalAtIndex(statement[K_PLPGSQL_STMT_ASSIGN][K_VARNO].get() - 1); auto lhs = std::make_unique(var_name); - auto rhs = ParseExprSQL(statement[K_PLPGSQL_STMT_ASSIGN][K_EXPR][K_PLPGSQL_EXPR][K_QUERY].get()); + auto rhs = ParseExprFromSQL(statement[K_PLPGSQL_STMT_ASSIGN][K_EXPR][K_PLPGSQL_EXPR][K_QUERY].get()); statements.push_back(std::make_unique(std::move(lhs), std::move(rhs))); } else if (statement_type == K_PLPGSQL_STMT_WHILE) { statements.push_back(ParseWhile(statement[K_PLPGSQL_STMT_WHILE])); @@ -155,22 +153,22 @@ std::unique_ptr PLpgSQLParser::ParseBlock(const nl return std::make_unique(std::move(statements)); } -std::unique_ptr PLpgSQLParser::ParseDecl(const nlohmann::json &decl) { - const auto &declaration_type = decl.items().begin().key(); +std::unique_ptr PLpgSQLParser::ParseDecl(const nlohmann::json &json) { + const auto &declaration_type = json.items().begin().key(); if (declaration_type == K_PLPGSQL_VAR) { - auto var_name = decl[K_PLPGSQL_VAR][K_REFNAME].get(); + auto var_name = json[K_PLPGSQL_VAR][K_REFNAME].get(); // Track the local variable (for assignment) udf_ast_context_->AddLocal(var_name); // Grab the type identifier from the PL/pgSQL parser const std::string type = StringUtils::Strip( - StringUtils::Lower(decl[K_PLPGSQL_VAR][K_DATATYPE][K_PLPGSQL_TYPE][K_TYPENAME].get())); + StringUtils::Lower(json[K_PLPGSQL_VAR][K_DATATYPE][K_PLPGSQL_TYPE][K_TYPENAME].get())); // Parse the initializer, if present std::unique_ptr initial{nullptr}; - if (decl[K_PLPGSQL_VAR].find(K_DEFAULT_VAL) != decl[K_PLPGSQL_VAR].end()) { - initial = ParseExprSQL(decl[K_PLPGSQL_VAR][K_DEFAULT_VAL][K_PLPGSQL_EXPR][K_QUERY].get()); + if (json[K_PLPGSQL_VAR].find(K_DEFAULT_VAL) != json[K_PLPGSQL_VAR].end()) { + initial = ParseExprFromSQL(json[K_PLPGSQL_VAR][K_DEFAULT_VAL][K_PLPGSQL_EXPR][K_QUERY].get()); } // Detemine if the variable has already been declared; @@ -209,7 +207,7 @@ std::unique_ptr PLpgSQLParser::ParseDecl(const nlo } if (declaration_type == K_PLPGSQL_ROW) { - const auto var_name = decl[K_PLPGSQL_ROW][K_REFNAME].get(); + const auto var_name = json[K_PLPGSQL_ROW][K_REFNAME].get(); NOISEPAGE_ASSERT(var_name == "*internal*", "Unexpected refname"); // TODO(Kyle): Support row types later @@ -221,42 +219,40 @@ std::unique_ptr PLpgSQLParser::ParseDecl(const nlo throw PARSER_EXCEPTION(fmt::format("Declaration type '{}' not supported", declaration_type)); } -std::unique_ptr PLpgSQLParser::ParseIf(const nlohmann::json &branch) { - auto cond_expr = ParseExprSQL(branch[K_COND][K_PLPGSQL_EXPR][K_QUERY].get()); - auto then_stmt = ParseBlock(branch[K_THEN_BODY]); - std::unique_ptr else_stmt = nullptr; - if (branch.find(K_ELSE_BODY) != branch.end()) { - else_stmt = ParseBlock(branch[K_ELSE_BODY]); - } +std::unique_ptr PLpgSQLParser::ParseIf(const nlohmann::json &json) { + auto cond_expr = ParseExprFromSQL(json[K_COND][K_PLPGSQL_EXPR][K_QUERY].get()); + auto then_stmt = ParseBlock(json[K_THEN_BODY]); + std::unique_ptr else_stmt = + json.contains(K_ELSE_BODY) ? ParseBlock(json[K_ELSE_BODY]) : nullptr; return std::make_unique(std::move(cond_expr), std::move(then_stmt), std::move(else_stmt)); } -std::unique_ptr PLpgSQLParser::ParseWhile(const nlohmann::json &loop) { - auto cond_expr = ParseExprSQL(loop[K_COND][K_PLPGSQL_EXPR][K_QUERY].get()); - auto body_stmt = ParseBlock(loop[K_BODY]); +std::unique_ptr PLpgSQLParser::ParseWhile(const nlohmann::json &json) { + auto cond_expr = ParseExprFromSQL(json[K_COND][K_PLPGSQL_EXPR][K_QUERY].get()); + auto body_stmt = ParseBlock(json[K_BODY]); return std::make_unique(std::move(cond_expr), std::move(body_stmt)); } -std::unique_ptr PLpgSQLParser::ParseForI(const nlohmann::json &loop) { - const auto name = loop[K_VAR][K_PLPGSQL_VAR][K_REFNAME].get(); - auto lower = ParseExprSQL(loop[K_LOWER][K_PLPGSQL_EXPR][K_QUERY]); - auto upper = ParseExprSQL(loop[K_UPPER][K_PLPGSQL_EXPR][K_QUERY]); - auto step = loop.contains(K_STEP) ? ParseExprSQL(loop[K_STEP][K_PLPGSQL_EXPR][K_QUERY]) - : ParseExprSQL(execution::ast::udf::ForIStmtAST::DEFAULT_STEP_EXPR); - auto body = ParseBlock(loop[K_BODY]); +std::unique_ptr PLpgSQLParser::ParseForI(const nlohmann::json &json) { + const auto name = json[K_VAR][K_PLPGSQL_VAR][K_REFNAME].get(); + auto lower = ParseExprFromSQL(json[K_LOWER][K_PLPGSQL_EXPR][K_QUERY]); + auto upper = ParseExprFromSQL(json[K_UPPER][K_PLPGSQL_EXPR][K_QUERY]); + auto step = json.contains(K_STEP) ? ParseExprFromSQL(json[K_STEP][K_PLPGSQL_EXPR][K_QUERY]) + : ParseExprFromSQL(execution::ast::udf::ForIStmtAST::DEFAULT_STEP_EXPR); + auto body = ParseBlock(json[K_BODY]); return std::make_unique(name, std::move(lower), std::move(upper), std::move(step), std::move(body)); } -std::unique_ptr PLpgSQLParser::ParseForS(const nlohmann::json &loop) { - const auto sql_query = loop[K_QUERY][K_PLPGSQL_EXPR][K_QUERY].get(); +std::unique_ptr PLpgSQLParser::ParseForS(const nlohmann::json &json) { + const auto sql_query = json[K_QUERY][K_PLPGSQL_EXPR][K_QUERY].get(); auto parse_result = PostgresParser::BuildParseTree(sql_query); if (parse_result == nullptr) { return nullptr; } - auto body_stmt = ParseBlock(loop[K_BODY]); - auto var_array = loop[K_ROW][K_PLPGSQL_ROW][K_FIELDS]; + auto body_stmt = ParseBlock(json[K_BODY]); + auto var_array = json[K_ROW][K_PLPGSQL_ROW][K_FIELDS]; std::vector variables{}; variables.reserve(var_array.size()); std::transform(var_array.cbegin(), var_array.cend(), std::back_inserter(variables), @@ -265,11 +261,11 @@ std::unique_ptr PLpgSQLParser::ParseForS(const nlo std::move(body_stmt)); } -std::unique_ptr PLpgSQLParser::ParseSQL(const nlohmann::json &sql) { +std::unique_ptr PLpgSQLParser::ParseSQL(const nlohmann::json &json) { // The query text - const auto sql_query = sql[K_SQLSTMT][K_PLPGSQL_EXPR][K_QUERY].get(); + const auto sql_query = json[K_SQLSTMT][K_PLPGSQL_EXPR][K_QUERY].get(); // The variable name (non-const for later std::move) - auto var_name = sql[K_ROW][K_PLPGSQL_ROW][K_FIELDS][0][K_NAME].get(); + auto var_name = json[K_ROW][K_PLPGSQL_ROW][K_FIELDS][0][K_NAME].get(); auto parse_result = PostgresParser::BuildParseTree(sql_query); if (parse_result == nullptr) { @@ -308,13 +304,13 @@ std::unique_ptr PLpgSQLParser::ParseSQL(const nloh std::move(query_params)); } -std::unique_ptr PLpgSQLParser::ParseDynamicSQL(const nlohmann::json &sql) { - auto sql_expr = ParseExprSQL(sql[K_QUERY][K_PLPGSQL_EXPR][K_QUERY].get()); - auto var_name = sql[K_ROW][K_PLPGSQL_ROW][K_FIELDS][0][K_NAME].get(); +std::unique_ptr PLpgSQLParser::ParseDynamicSQL(const nlohmann::json &json) { + auto sql_expr = ParseExprFromSQL(json[K_QUERY][K_PLPGSQL_EXPR][K_QUERY].get()); + auto var_name = json[K_ROW][K_PLPGSQL_ROW][K_FIELDS][0][K_NAME].get(); return std::make_unique(std::move(sql_expr), std::move(var_name)); } -std::unique_ptr PLpgSQLParser::ParseExprSQL(const std::string &sql) { +std::unique_ptr PLpgSQLParser::ParseExprFromSQL(const std::string &sql) { auto stmt_list = PostgresParser::BuildParseTree(sql); if (stmt_list == nullptr) { return nullptr; @@ -326,10 +322,10 @@ std::unique_ptr PLpgSQLParser::ParseExprSQL(const "Unsupported SQL Expr in UDF"); auto &select_list = stmt.CastManagedPointerTo()->GetSelectColumns(); NOISEPAGE_ASSERT(select_list.size() == 1, "Unsupported number of select columns in UDF"); - return PLpgSQLParser::ParseExpr(select_list[0]); + return PLpgSQLParser::ParseExprFromAbstract(select_list[0]); } -std::unique_ptr PLpgSQLParser::ParseExpr( +std::unique_ptr PLpgSQLParser::ParseExprFromAbstract( common::ManagedPointer expr) { if (expr->GetExpressionType() == parser::ExpressionType::COLUMN_VALUE) { auto cve = expr.CastManagedPointerTo(); @@ -342,8 +338,8 @@ std::unique_ptr PLpgSQLParser::ParseExpr( if ((parser::ExpressionUtil::IsOperatorExpression(expr->GetExpressionType()) && expr->GetChildrenSize() == 2) || (parser::ExpressionUtil::IsComparisonExpression(expr->GetExpressionType()))) { - return std::make_unique(expr->GetExpressionType(), ParseExpr(expr->GetChild(0)), - ParseExpr(expr->GetChild(1))); + return std::make_unique( + expr->GetExpressionType(), ParseExprFromAbstract(expr->GetChild(0)), ParseExprFromAbstract(expr->GetChild(1))); } // TODO(Kyle): I am not a fan of non-exhaustive switch statements; @@ -355,16 +351,16 @@ std::unique_ptr PLpgSQLParser::ParseExpr( std::vector> args{}; auto num_args = func_expr->GetChildrenSize(); for (size_t idx = 0; idx < num_args; ++idx) { - args.push_back(ParseExpr(func_expr->GetChild(idx))); + args.push_back(ParseExprFromAbstract(func_expr->GetChild(idx))); } return std::make_unique(func_expr->GetFuncName(), std::move(args)); } case parser::ExpressionType::VALUE_CONSTANT: return std::make_unique(expr->Copy()); case parser::ExpressionType::OPERATOR_IS_NOT_NULL: - return std::make_unique(false, ParseExpr(expr->GetChild(0))); + return std::make_unique(false, ParseExprFromAbstract(expr->GetChild(0))); case parser::ExpressionType::OPERATOR_IS_NULL: - return std::make_unique(true, ParseExpr(expr->GetChild(0))); + return std::make_unique(true, ParseExprFromAbstract(expr->GetChild(0))); default: throw PARSER_EXCEPTION("PL/pgSQL parser : Expression type not supported"); } From bf56e2a2ecdfa6010d7469f46178f206e5541709 Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Fri, 23 Jul 2021 11:23:37 -0400 Subject: [PATCH 081/139] small changes while trying to get iteration over query results to work, looks like there may be more than one issue present here --- script/testing/junit/sql/udf.sql | 46 +++++++++----- script/testing/junit/traces/udf.test | 28 +++++++-- src/execution/compiler/udf/udf_codegen.cpp | 9 +-- src/include/binder/bind_node_visitor.h | 1 + src/include/parser/parse_result.h | 13 ++-- src/include/parser/select_statement.h | 2 +- src/include/parser/udf/plpgsql_parse_result.h | 35 +++++++++++ src/include/parser/udf/plpgsql_parser.h | 9 +++ src/parser/udf/plpgsql_parser.cpp | 60 +++++++++---------- 9 files changed, 141 insertions(+), 62 deletions(-) create mode 100644 src/include/parser/udf/plpgsql_parse_result.h diff --git a/script/testing/junit/sql/udf.sql b/script/testing/junit/sql/udf.sql index f1ac82c2be..ef1ec84039 100644 --- a/script/testing/junit/sql/udf.sql +++ b/script/testing/junit/sql/udf.sql @@ -123,19 +123,35 @@ SELECT x, proc_while() FROM integers; -- ---------------------------------------------------------------------------- -- proc_fors() -CREATE TABLE temp(z INT); -INSERT INTO temp(z) VALUES (0), (1); - -CREATE FUNCTION proc_fors() RETURNS INT AS $$ \ -DECLARE \ - x INT := 0; \ - v RECORD; \ -BEGIN \ - FOR v IN (SELECT z FROM temp) LOOP \ - x = x + 1; \ - END LOOP; \ - RETURN x; \ -END \ -$$ LANGUAGE PLPGSQL; +-- CREATE TABLE tmp(z INT); +-- INSERT INTO tmp(z) VALUES (0), (1); + +-- -- Bind query result to a RECORD type +-- CREATE FUNCTION proc_fors_rec() RETURNS INT AS $$ \ +-- DECLARE \ +-- x INT := 0; \ +-- v RECORD; \ +-- BEGIN \ +-- FOR v IN (SELECT z FROM temp) LOOP \ +-- x = x + 1; \ +-- END LOOP; \ +-- RETURN x; \ +-- END \ +-- $$ LANGUAGE PLPGSQL; + +-- SELECT x, proc_fors_rec() FROM integers; + +-- -- Bind query result directly to INT type +-- CREATE FUNCTION proc_fors_var() RETURNS INT AS $$ \ +-- DECLARE \ +-- x INT := 0; \ +-- v INT; \ +-- BEGIN \ +-- FOR v IN (SELECT z FROM tmp) LOOP \ +-- x = x + 1; \ +-- END LOOP; \ +-- RETURN x; \ +-- END \ +-- $$ LANGUAGE PLPGSQL; -SELECT x, proc_fors() FROM integers; +-- SELECT x, proc_fors_var() FROM integers; diff --git a/script/testing/junit/traces/udf.test b/script/testing/junit/traces/udf.test index 90414bd5e2..935f0a637f 100644 --- a/script/testing/junit/traces/udf.test +++ b/script/testing/junit/traces/udf.test @@ -292,23 +292,41 @@ statement ok statement ok --- CREATE TABLE temp(z INT); +-- CREATE TABLE tmp(z INT); statement ok --- INSERT INTO temp(z) VALUES (0), (1); +-- INSERT INTO tmp(z) VALUES (0), (1); statement ok statement ok --- CREATE FUNCTION proc_fors() RETURNS INT AS $$ -- DECLARE \ +-- -- Bind query result to a RECORD type statement ok --- x INT := 0; -- BEGIN -- FOR v IN SELECT z FROM temp -- LOOP -- x = x + 1; -- END LOOP; -- RETURN x; -- END -- $$ LANGUAGE PLPGSQL; +-- CREATE FUNCTION proc_fors_rec() RETURNS INT AS $$ -- DECLARE \ + +statement ok +-- x INT := 0; -- v RECORD; -- BEGIN -- FOR v IN (SELECT z FROM temp) LOOP -- x = x + 1; -- END LOOP; -- RETURN x; -- END -- $$ LANGUAGE PLPGSQL; + +statement ok + + +statement ok +-- SELECT x, proc_fors_rec() FROM integers; + +statement ok + + +statement ok +-- -- Bind query result directly to INT type + +statement ok +-- CREATE FUNCTION proc_fors_var() RETURNS INT AS $$ -- DECLARE -- x INT := 0; -- v INT; -- BEGIN -- FOR v IN (SELECT z FROM tmp) LOOP -- x = x + 1; -- END LOOP; -- RETURN x; -- END -- $$ LANGUAGE PLPGSQL; statement ok statement ok --- SELECT x, proc_fors() FROM integers; +-- SELECT x, proc_fors_var() FROM integers; diff --git a/src/execution/compiler/udf/udf_codegen.cpp b/src/execution/compiler/udf/udf_codegen.cpp index a269f8bf52..1fd39005aa 100644 --- a/src/execution/compiler/udf/udf_codegen.cpp +++ b/src/execution/compiler/udf/udf_codegen.cpp @@ -353,6 +353,7 @@ void UdfCodegen::Visit(ast::udf::ForSStmtAST *ast) { accessor_->GetTxn(), common::ManagedPointer(accessor_), query, db_oid_, common::ManagedPointer(&stats), std::make_unique(), optimizer_timeout, nullptr); auto plan = optimizer_result->GetPlanNode(); + NOISEPAGE_ASSERT(plan->GetOutputSchema()->GetColumns().size() == 1, "UDF support for non-scalars is not implemented"); // Make a lambda that just writes into this std::vector var_idents{}; @@ -361,14 +362,11 @@ void UdfCodegen::Visit(ast::udf::ForSStmtAST *ast) { params.push_back(codegen_->MakeField( exec_ctx->As()->Name(), codegen_->PointerType(codegen_->BuiltinType(execution::ast::BuiltinType::Kind::ExecutionContext)))); - std::size_t i{0}; + std::size_t i = 0; for (const auto &var : ast->Variables()) { var_idents.push_back(SymbolTable().find(var)->second); auto var_ident = var_idents.back(); - NOISEPAGE_ASSERT(plan->GetOutputSchema()->GetColumns().size() == 1, - "UDF support for non-scalars is not implemented"); auto type = codegen_->TplType(execution::sql::GetTypeId(plan->GetOutputSchema()->GetColumn(i).GetType())); - fb_->Append(codegen_->Assign(codegen_->MakeExpr(var_ident), codegen_->ConstNull(plan->GetOutputSchema()->GetColumn(i).GetType()))); auto input = codegen_->MakeFreshIdentifier(var); @@ -379,7 +377,7 @@ void UdfCodegen::Visit(ast::udf::ForSStmtAST *ast) { execution::ast::LambdaExpr *lambda_expr{}; FunctionBuilder fn{codegen_, std::move(params), codegen_->BuiltinType(execution::ast::BuiltinType::Nil)}; { - std::size_t j{1}; + std::size_t j = 1; for (auto var : var_idents) { fn.Append(codegen_->Assign(codegen_->MakeExpr(var), fn.GetParameterByPosition(j))); j++; @@ -392,7 +390,6 @@ void UdfCodegen::Visit(ast::udf::ForSStmtAST *ast) { execution::util::RegionVector captures{codegen_->GetAstContext()->GetRegion()}; for (const auto &[name, identifier] : SymbolTable()) { - // TODO(Kyle): Why do we skip this particular identifier? if (name == "executionCtx") { continue; } diff --git a/src/include/binder/bind_node_visitor.h b/src/include/binder/bind_node_visitor.h index 851abea615..0cb9f23677 100644 --- a/src/include/binder/bind_node_visitor.h +++ b/src/include/binder/bind_node_visitor.h @@ -58,6 +58,7 @@ class BindNodeVisitor final : public SqlNodeVisitor { * @param udf_ast_context The AST context for the UDF. * @return The map of UDF parameters: * Column Name -> (Parameter Name, Parameter Index) + * @throws BinderException on failure to bind query */ std::unordered_map> BindAndGetUDFParams( common::ManagedPointer parse_result, diff --git a/src/include/parser/parse_result.h b/src/include/parser/parse_result.h index 8a82aa98c2..6013b0f489 100644 --- a/src/include/parser/parse_result.h +++ b/src/include/parser/parse_result.h @@ -52,10 +52,15 @@ class ParseResult { */ uint32_t NumStatements() const { return statements_.size(); } - /** - * @return the statement at a particular index - */ - common::ManagedPointer GetStatement(size_t idx) { return common::ManagedPointer(statements_[idx]); } + /** @return The statement at index `index`*/ + common::ManagedPointer GetStatement(std::size_t idx) { + return common::ManagedPointer(statements_[idx]); + } + + /** @return The statement at a index `index` */ + common::ManagedPointer GetStatement(std::size_t idx) const { + return common::ManagedPointer(statements_.at(idx).get()); + } /** * @return non-owning list of all the expressions contained in this parse result diff --git a/src/include/parser/select_statement.h b/src/include/parser/select_statement.h index 21c748a384..9bf1f165e7 100644 --- a/src/include/parser/select_statement.h +++ b/src/include/parser/select_statement.h @@ -355,7 +355,7 @@ class SelectStatement : public SQLStatement { std::unique_ptr Copy(); /** @return The columns targeted by SELECT */ - const std::vector> &GetSelectColumns() { return select_; } + const std::vector> &GetSelectColumns() const { return select_; } /** @return `true` if "SELECT DISTINCT", `false` otherwise */ bool IsSelectDistinct() const { return select_distinct_; } diff --git a/src/include/parser/udf/plpgsql_parse_result.h b/src/include/parser/udf/plpgsql_parse_result.h new file mode 100644 index 0000000000..39457ab584 --- /dev/null +++ b/src/include/parser/udf/plpgsql_parse_result.h @@ -0,0 +1,35 @@ +#pragma once + +#include + +#include "libpg_query/pg_query.h" + +namespace noisepage::parser::udf { + +/** + * The PLpgSQLParseResult class is a simple RAII + * wrapper for the parse result returned by libpq_query. + * + * NOTE: Could just do this with a std::unique_ptr with + * a default deleter, but this is more pleasant. + */ +class PLpgSQLParseResult { + public: + /** + * Construct a new PLpgSQLParseResult instance. + * @param result The raw result + */ + explicit PLpgSQLParseResult(PgQueryPlpgsqlParseResult &&result) : result_{result} {} + + /** Release resources from the parse result */ + ~PLpgSQLParseResult() { pg_query_free_plpgsql_parse_result(result_); } + + /** @return An immutable reference to the underlying result */ + const PgQueryPlpgsqlParseResult &operator*() const { return result_; } + + private: + /** The underlying parse result */ + PgQueryPlpgsqlParseResult result_; +}; + +} // namespace noisepage::parser::udf diff --git a/src/include/parser/udf/plpgsql_parser.h b/src/include/parser/udf/plpgsql_parser.h index 792342827f..bf9d660742 100644 --- a/src/include/parser/udf/plpgsql_parser.h +++ b/src/include/parser/udf/plpgsql_parser.h @@ -3,6 +3,7 @@ #include #include #include +#include #include #include "catalog/catalog_accessor.h" @@ -128,6 +129,14 @@ class PLpgSQLParser { std::unique_ptr ParseExprFromAbstract( common::ManagedPointer expr); + private: + /** + * Resolve a PL/pgSQL RECORD type from a SELECT statement. + * @param parse_result The result of parsing the SQL query + * @return The resolved record type + */ + std::vector> ResolveRecordType(const ParseResult *parse_result); + private: /** The UDF AST context */ common::ManagedPointer udf_ast_context_; diff --git a/src/parser/udf/plpgsql_parser.cpp b/src/parser/udf/plpgsql_parser.cpp index f568e9f874..240cc5a018 100644 --- a/src/parser/udf/plpgsql_parser.cpp +++ b/src/parser/udf/plpgsql_parser.cpp @@ -2,6 +2,7 @@ #include "binder/bind_node_visitor.h" #include "execution/ast/udf/udf_ast_nodes.h" +#include "parser/udf/plpgsql_parse_result.h" #include "parser/udf/plpgsql_parser.h" #include "parser/udf/string_utils.h" @@ -60,14 +61,12 @@ static constexpr const char DECL_TYPE_ID_RECORD[] = "record"; std::unique_ptr PLpgSQLParser::Parse(const std::vector ¶m_names, const std::vector ¶m_types, const std::string &func_body) { - auto result = pg_query_parse_plpgsql(func_body.c_str()); - if (result.error != nullptr) { - pg_query_free_plpgsql_parse_result(result); + auto result = PLpgSQLParseResult{pg_query_parse_plpgsql(func_body.c_str())}; + if ((*result).error != nullptr) { throw PARSER_EXCEPTION("PL/pgSQL parsing error"); } // The result is a list, we need to wrap it - const auto ast_json_str = fmt::format("{{ \"{}\" : {} }}", K_FUNCTION_LIST, result.plpgsql_funcs); - pg_query_free_plpgsql_parse_result(result); + const auto ast_json_str = fmt::format("{{ \"{}\" : {} }}", K_FUNCTION_LIST, (*result).plpgsql_funcs); std::istringstream ss{ast_json_str}; nlohmann::json ast_json{}; @@ -114,12 +113,11 @@ std::unique_ptr PLpgSQLParser::ParseBlock(const nl std::vector> statements{}; for (const auto &statement : json) { - std::cout << statement << std::endl; const std::string &statement_type = statement.items().begin().key(); if (statement_type == K_PLPGSQL_STMT_RETURN) { // TODO(Kyle): Handle RETURN without expression if (statement[K_PLPGSQL_STMT_RETURN].empty()) { - throw NOT_IMPLEMENTED_EXCEPTION("RETURN without expression not implemented."); + throw NOT_IMPLEMENTED_EXCEPTION("PL/pgSQL Parser : RETURN without expression not implemented."); } auto expr = ParseExprFromSQL(statement[K_PLPGSQL_STMT_RETURN][K_EXPR][K_PLPGSQL_EXPR][K_QUERY].get()); @@ -146,7 +144,7 @@ std::unique_ptr PLpgSQLParser::ParseBlock(const nl } else if (statement_type == K_PLPGSQL_STMT_DYNEXECUTE) { statements.push_back(ParseDynamicSQL(statement[K_PLPGSQL_STMT_DYNEXECUTE])); } else { - throw PARSER_EXCEPTION(fmt::format("Statement type '{}' not supported", statement_type)); + throw PARSER_EXCEPTION(fmt::format("PL/pgSQL Parser : statement type '{}' not supported", statement_type)); } } @@ -203,20 +201,19 @@ std::unique_ptr PLpgSQLParser::ParseDecl(const nlo return std::make_unique(var_name, type::TypeId::INVALID, std::move(initial)); } - throw PARSER_EXCEPTION(fmt::format("Unsupported type '{}' for variable '{}'", type, var_name)); + throw PARSER_EXCEPTION(fmt::format("PL/pgSQL Parser : unsupported type '{}' for variable '{}'", type, var_name)); } + // TODO(Kyle): Support row types later if (declaration_type == K_PLPGSQL_ROW) { const auto var_name = json[K_PLPGSQL_ROW][K_REFNAME].get(); NOISEPAGE_ASSERT(var_name == "*internal*", "Unexpected refname"); - - // TODO(Kyle): Support row types later udf_ast_context_->SetVariableType(var_name, type::TypeId::INVALID); return std::make_unique(var_name, type::TypeId::INVALID, nullptr); } // TODO(Kyle): Need to handle other types like row, table etc; - throw PARSER_EXCEPTION(fmt::format("Declaration type '{}' not supported", declaration_type)); + throw PARSER_EXCEPTION(fmt::format("PL/pgSQL Parser : declaration type '{}' not supported", declaration_type)); } std::unique_ptr PLpgSQLParser::ParseIf(const nlohmann::json &json) { @@ -272,32 +269,21 @@ std::unique_ptr PLpgSQLParser::ParseSQL(const nloh return nullptr; } + // Bind the query within the UDF body; if binding + // fails, we allow the BinderException to propogate binder::BindNodeVisitor visitor{accessor_, db_oid_}; - std::unordered_map> query_params{}; - try { - // TODO(Matt): I don't think the binder should need the database name. - // It's already bound in the ConnectionContext binder::BindNodeVisitor visitor(accessor_, db_oid_); - query_params = visitor.BindAndGetUDFParams(common::ManagedPointer{parse_result}, udf_ast_context_); - } catch (BinderException &b) { - return nullptr; - } + auto query_params = visitor.BindAndGetUDFParams(common::ManagedPointer{parse_result}, udf_ast_context_); // Check to see if a record type can be bound to this const auto type = udf_ast_context_->GetVariableType(var_name); if (!type.has_value()) { - throw PARSER_EXCEPTION("PL/pgSQL parser: variable was not declared"); + throw PARSER_EXCEPTION("PL/pgSQL parser : variable was not declared"); } if (type.value() == type::TypeId::INVALID) { - std::vector> elems{}; - const auto &select_columns = - parse_result->GetStatement(0).CastManagedPointerTo()->GetSelectColumns(); - elems.reserve(select_columns.size()); - std::transform(select_columns.cbegin(), select_columns.cend(), std::back_inserter(elems), - [](const common::ManagedPointer &column) { - return std::make_pair(column->GetAlias().GetName(), column->GetReturnValueType()); - }); - udf_ast_context_->SetRecordType(var_name, std::move(elems)); + // If the type is a RECORD type, derive the structure of + // the type from the columns of the SELECT statement + udf_ast_context_->SetRecordType(var_name, ResolveRecordType(parse_result.get())); } return std::make_unique(std::move(parse_result), std::move(var_name), @@ -362,8 +348,20 @@ std::unique_ptr PLpgSQLParser::ParseExprFromAbstra case parser::ExpressionType::OPERATOR_IS_NULL: return std::make_unique(true, ParseExprFromAbstract(expr->GetChild(0))); default: - throw PARSER_EXCEPTION("PL/pgSQL parser : Expression type not supported"); + throw PARSER_EXCEPTION("PL/pgSQL parser : expression type not supported"); } } +std::vector> PLpgSQLParser::ResolveRecordType(const ParseResult *parse_result) { + std::vector> fields{}; + const auto &select_columns = + parse_result->GetStatement(0).CastManagedPointerTo()->GetSelectColumns(); + fields.reserve(select_columns.size()); + std::transform(select_columns.cbegin(), select_columns.cend(), std::back_inserter(fields), + [](const common::ManagedPointer &column) { + return std::make_pair(column->GetAlias().GetName(), column->GetReturnValueType()); + }); + return fields; +} + } // namespace noisepage::parser::udf From 7b2d2e7e1126352a8ea90da61f9220e47b60cd5b Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Sat, 24 Jul 2021 00:20:05 -0400 Subject: [PATCH 082/139] more refactoring in udf code generation during attempts to resolve the bug in embedded query compilation --- script/testing/junit/sql/udf.sql | 8 + src/execution/compiler/udf/udf_codegen.cpp | 330 +++++++++--------- .../execution/compiler/udf/udf_codegen.h | 41 ++- src/parser/udf/plpgsql_parser.cpp | 6 +- 4 files changed, 219 insertions(+), 166 deletions(-) diff --git a/script/testing/junit/sql/udf.sql b/script/testing/junit/sql/udf.sql index ef1ec84039..a6bced5e35 100644 --- a/script/testing/junit/sql/udf.sql +++ b/script/testing/junit/sql/udf.sql @@ -106,6 +106,8 @@ SELECT x, proc_while() FROM integers; -- ---------------------------------------------------------------------------- -- proc_fori() +-- +-- TODO(Kyle): for-loop control flow (integer variant) is not supported -- CREATE FUNCTION proc_fori() RETURNS INT AS $$ \ -- DECLARE \ @@ -122,6 +124,8 @@ SELECT x, proc_while() FROM integers; -- ---------------------------------------------------------------------------- -- proc_fors() +-- +-- TODO(Kyle): for-loop control flow (query variant) is not supported -- CREATE TABLE tmp(z INT); -- INSERT INTO tmp(z) VALUES (0), (1); @@ -155,3 +159,7 @@ SELECT x, proc_while() FROM integers; -- $$ LANGUAGE PLPGSQL; -- SELECT x, proc_fors_var() FROM integers; + +CREATE FUNCTION agg_count() RETURNS INT AS $$ DECLARE v INT; BEGIN SELECT COUNT(z) INTO v FROM tmp; RETURN v; END $$ LANGUAGE PLPGSQL; +CREATE FUNCTION fun() RETURNS INT AS $$ DECLARE a INT; b INT; BEGIN SELECT COUNT(z), COUNT(z) INTO a, b FROM tmp; RETURN a + b; END $$ LANGUAGE PLPGSQL; + diff --git a/src/execution/compiler/udf/udf_codegen.cpp b/src/execution/compiler/udf/udf_codegen.cpp index 1fd39005aa..0d7932d112 100644 --- a/src/execution/compiler/udf/udf_codegen.cpp +++ b/src/execution/compiler/udf/udf_codegen.cpp @@ -337,35 +337,35 @@ void UdfCodegen::Visit(ast::udf::WhileStmtAST *ast) { void UdfCodegen::Visit(ast::udf::ForIStmtAST *ast) { throw NOT_IMPLEMENTED_EXCEPTION("ForIStmtAST Not Implemented"); } void UdfCodegen::Visit(ast::udf::ForSStmtAST *ast) { - // Once we encounter a for-statement we know we need an execution - // context because the loop always draws values from a query + // Executing a SQL query requires an execution context needs_exec_ctx_ = true; + execution::ast::Expr *exec_ctx = fb_->GetParameterByPosition(0); - const auto query = common::ManagedPointer(ast->Query()); - auto exec_ctx = fb_->GetParameterByPosition(0); + // Bind the embedded query + binder::BindNodeVisitor visitor{common::ManagedPointer{accessor_}, db_oid_}; + auto query_params = + visitor.BindAndGetUDFParams(common::ManagedPointer{ast->Query()}, common::ManagedPointer{udf_ast_context_}); - binder::BindNodeVisitor visitor{common::ManagedPointer(accessor_), db_oid_}; - auto query_params = visitor.BindAndGetUDFParams(query, common::ManagedPointer(udf_ast_context_)); + // Optimize the embedded query + auto optimize_result = OptimizeEmbeddedQuery(ast->Query()); + auto plan = optimize_result->GetPlanNode(); + if (plan->GetOutputSchema()->GetColumns().size() > 1) { + throw EXECUTION_EXCEPTION("PL/pgSQL Codegen : support for non-scalars is not implemented", + common::ErrorCode::ERRCODE_PLPGSQL_ERROR); + } - auto stats = optimizer::StatsStorage(); - const uint64_t optimizer_timeout = 1000000; - auto optimizer_result = trafficcop::TrafficCopUtil::Optimize( - accessor_->GetTxn(), common::ManagedPointer(accessor_), query, db_oid_, common::ManagedPointer(&stats), - std::make_unique(), optimizer_timeout, nullptr); - auto plan = optimizer_result->GetPlanNode(); - NOISEPAGE_ASSERT(plan->GetOutputSchema()->GetColumns().size() == 1, "UDF support for non-scalars is not implemented"); + // Construct a lambda that writes the output of the query + // into the identifiers within the UDF bound to the output - // Make a lambda that just writes into this - std::vector var_idents{}; - auto lam_var = codegen_->MakeFreshIdentifier("looplamb"); - execution::util::RegionVector params(codegen_->GetAstContext()->GetRegion()); + std::vector variable_identifiers{}; + execution::util::RegionVector params{codegen_->GetAstContext()->GetRegion()}; params.push_back(codegen_->MakeField( exec_ctx->As()->Name(), codegen_->PointerType(codegen_->BuiltinType(execution::ast::BuiltinType::Kind::ExecutionContext)))); std::size_t i = 0; for (const auto &var : ast->Variables()) { - var_idents.push_back(SymbolTable().find(var)->second); - auto var_ident = var_idents.back(); + variable_identifiers.push_back(SymbolTable().find(var)->second); + auto var_ident = variable_identifiers.back(); auto type = codegen_->TplType(execution::sql::GetTypeId(plan->GetOutputSchema()->GetColumn(i).GetType())); fb_->Append(codegen_->Assign(codegen_->MakeExpr(var_ident), codegen_->ConstNull(plan->GetOutputSchema()->GetColumn(i).GetType()))); @@ -374,11 +374,10 @@ void UdfCodegen::Visit(ast::udf::ForSStmtAST *ast) { i++; } - execution::ast::LambdaExpr *lambda_expr{}; FunctionBuilder fn{codegen_, std::move(params), codegen_->BuiltinType(execution::ast::BuiltinType::Nil)}; { std::size_t j = 1; - for (auto var : var_idents) { + for (auto var : variable_identifiers) { fn.Append(codegen_->Assign(codegen_->MakeExpr(var), fn.GetParameterByPosition(j))); j++; } @@ -396,8 +395,9 @@ void UdfCodegen::Visit(ast::udf::ForSStmtAST *ast) { captures.push_back(codegen_->MakeExpr(identifier)); } - lambda_expr = fn.FinishLambda(std::move(captures)); - lambda_expr->SetName(lam_var); + ast::LambdaExpr *lambda_expr = fn.FinishLambda(std::move(captures)); + const ast::Identifier lambda_identifier = codegen_->MakeFreshIdentifier("looplamb"); + lambda_expr->SetName(lambda_identifier); // We want to pass something down that will materialize the lambda // function into lambda_expr and will also feed in a lambda_expr to the compiler @@ -412,14 +412,15 @@ void UdfCodegen::Visit(ast::udf::ForSStmtAST *ast) { auto decls = exec_query->GetDecls(); aux_decls_.insert(aux_decls_.end(), decls.begin(), decls.end()); - fb_->Append(codegen_->DeclareVar(lam_var, codegen_->LambdaType(lambda_expr->GetFunctionLiteralExpr()->TypeRepr()), - lambda_expr)); + fb_->Append(codegen_->DeclareVar( + lambda_identifier, codegen_->LambdaType(lambda_expr->GetFunctionLiteralExpr()->TypeRepr()), lambda_expr)); auto query_state = codegen_->MakeFreshIdentifier("query_state"); fb_->Append(codegen_->DeclareVarNoInit(query_state, codegen_->MakeExpr(exec_query->GetQueryStateType()->Name()))); // Set its execution context to whatever exec context was passed in here fb_->Append(codegen_->CallBuiltin(execution::ast::Builtin::StartNewParams, {exec_ctx})); + std::vector>::iterator> sorted_vec{}; for (auto it = query_params.begin(); it != query_params.end(); it++) { sorted_vec.push_back(it); @@ -427,53 +428,22 @@ void UdfCodegen::Visit(ast::udf::ForSStmtAST *ast) { std::sort(sorted_vec.begin(), sorted_vec.end(), [](auto x, auto y) { return x->second < y->second; }); for (auto entry : sorted_vec) { - // TODO(Kyle): Order these const type::TypeId type = GetVariableType(entry->first); - execution::ast::Builtin builtin{}; - switch (type) { - case type::TypeId::BOOLEAN: - builtin = execution::ast::Builtin::AddParamBool; - break; - case type::TypeId::TINYINT: - builtin = execution::ast::Builtin::AddParamTinyInt; - break; - case type::TypeId::SMALLINT: - builtin = execution::ast::Builtin::AddParamSmallInt; - break; - case type::TypeId::INTEGER: - builtin = execution::ast::Builtin::AddParamInt; - break; - case type::TypeId::BIGINT: - builtin = execution::ast::Builtin::AddParamBigInt; - break; - case type::TypeId::DECIMAL: - builtin = execution::ast::Builtin::AddParamDouble; - break; - case type::TypeId::DATE: - builtin = execution::ast::Builtin::AddParamDate; - break; - case type::TypeId::TIMESTAMP: - builtin = execution::ast::Builtin::AddParamTimestamp; - break; - case type::TypeId::VARCHAR: - builtin = execution::ast::Builtin::AddParamString; - break; - default: - UNREACHABLE("Unsupported parameter type"); - } + const ast::Builtin builtin = AddParamBuiltinForParameterType(type); fb_->Append(codegen_->CallBuiltin(builtin, {exec_ctx, codegen_->MakeExpr(SymbolTable().at(entry->first))})); } fb_->Append(codegen_->Assign( codegen_->AccessStructMember(codegen_->MakeExpr(query_state), codegen_->MakeIdentifier("execCtx")), exec_ctx)); - auto fns = exec_query->GetFunctionNames(); - for (auto &sub_fn : fns) { - if (sub_fn.find("Run") != std::string::npos) { - fb_->Append(codegen_->Call(codegen_->GetAstContext()->GetIdentifier(sub_fn), - {codegen_->AddressOf(query_state), codegen_->MakeExpr(lam_var)})); + auto function_names = exec_query->GetFunctionNames(); + for (auto &function_name : function_names) { + if (function_name.find("Run") != std::string::npos) { + fb_->Append(codegen_->Call(codegen_->GetAstContext()->GetIdentifier(function_name), + {codegen_->AddressOf(query_state), codegen_->MakeExpr(lambda_identifier)})); } else { - fb_->Append(codegen_->Call(codegen_->GetAstContext()->GetIdentifier(sub_fn), {codegen_->AddressOf(query_state)})); + fb_->Append( + codegen_->Call(codegen_->GetAstContext()->GetIdentifier(function_name), {codegen_->AddressOf(query_state)})); } } @@ -487,44 +457,32 @@ void UdfCodegen::Visit(ast::udf::RetStmtAST *ast) { } void UdfCodegen::Visit(ast::udf::SQLStmtAST *ast) { - // As soon as we encounter an embedded SQL statement, - // we know we need an execution context + // Executing a SQL query requires an execution context needs_exec_ctx_ = true; - auto exec_ctx = fb_->GetParameterByPosition(0); - const auto query = common::ManagedPointer(ast->Query()); - - binder::BindNodeVisitor visitor(common::ManagedPointer(accessor_), db_oid_); - - const auto &query_params = ast->Parameters(); + ast::Expr *exec_ctx = fb_->GetParameterByPosition(0); - // NOTE(Kyle): Assumptions: - // - This is a valid optimizer timeout - // - No parameters are required for the call to Optimize() - - auto stats = optimizer::StatsStorage(); - const std::uint64_t optimizer_timeout = 1000000; - auto optimize_result = trafficcop::TrafficCopUtil::Optimize( - accessor_->GetTxn(), common::ManagedPointer(accessor_), query, db_oid_, common::ManagedPointer(&stats), - std::make_unique(), optimizer_timeout, nullptr); + // Optimize the query and generate get a reference to the plan + auto optimize_result = OptimizeEmbeddedQuery(ast->Query()); + auto plan = optimize_result->GetPlanNode(); // Make a lambda that just writes into this - auto lam_var = codegen_->MakeFreshIdentifier("lamb"); - auto plan = optimize_result->GetPlanNode(); - auto &cols = plan->GetOutputSchema()->GetColumns(); + // Populate the parameters for the lambda + execution::util::RegionVector lambda_parameters{codegen_->GetAstContext()->GetRegion()}; - execution::util::RegionVector params{codegen_->GetAstContext()->GetRegion()}; - params.push_back(codegen_->MakeField( + // The first parameter is always the execution context + lambda_parameters.push_back(codegen_->MakeField( exec_ctx->As()->Name(), codegen_->PointerType(codegen_->BuiltinType(execution::ast::BuiltinType::Kind::ExecutionContext)))); - std::size_t i{0}; + // Derive the remainder of the closure's signature from + // the output schema of the associated query + std::size_t i = 0; std::vector assignees{}; execution::util::RegionVector captures{codegen_->GetAstContext()->GetRegion()}; - for (auto &col : cols) { + for (const auto &col : plan->GetOutputSchema()->GetColumns()) { execution::ast::Expr *capture_var = codegen_->MakeExpr(SymbolTable().find(ast->Name())->second); - const type::TypeId type = GetVariableType(ast->Name()); - if (type == type::TypeId::INVALID) { + if (GetVariableType(ast->Name()) == type::TypeId::INVALID) { // Record type const auto struct_vars = GetRecordType(ast->Name()); if (captures.empty()) { @@ -538,22 +496,20 @@ void UdfCodegen::Visit(ast::udf::SQLStmtAST *ast) { } auto *tpl_type = codegen_->TplType(execution::sql::GetTypeId(col.GetType())); auto input_param = codegen_->MakeFreshIdentifier("input"); - params.push_back(codegen_->MakeField(input_param, tpl_type)); + lambda_parameters.push_back(codegen_->MakeField(input_param, tpl_type)); i++; } - execution::ast::LambdaExpr *lambda_expr{}; - FunctionBuilder fn{codegen_, std::move(params), codegen_->BuiltinType(execution::ast::BuiltinType::Nil)}; - { - for (auto j = 0UL; j < assignees.size(); ++j) { - auto capture_var = assignees[j]; - auto input_param = fn.GetParameterByPosition(j + 1); - fn.Append(codegen_->Assign(capture_var, input_param)); - } + FunctionBuilder fn{codegen_, std::move(lambda_parameters), codegen_->BuiltinType(execution::ast::BuiltinType::Nil)}; + for (std::size_t j = 0UL; j < assignees.size(); ++j) { + auto capture_var = assignees[j]; + auto input_param = fn.GetParameterByPosition(j + 1); + fn.Append(codegen_->Assign(capture_var, input_param)); } - lambda_expr = fn.FinishLambda(std::move(captures)); - lambda_expr->SetName(lam_var); + const ast::Identifier lambda_identifier = codegen_->MakeFreshIdentifier("lamb"); + ast::LambdaExpr *lambda_expr = fn.FinishLambda(std::move(captures)); + lambda_expr->SetName(lambda_identifier); // We want to pass something down that will materialize the lambda function // into lambda_expr and will also feed in a lambda_expr to the compiler @@ -567,95 +523,66 @@ void UdfCodegen::Visit(ast::udf::SQLStmtAST *ast) { auto decls = exec_query->GetDecls(); aux_decls_.insert(aux_decls_.end(), decls.begin(), decls.end()); - fb_->Append(codegen_->DeclareVar(lam_var, codegen_->LambdaType(lambda_expr->GetFunctionLiteralExpr()->TypeRepr()), - lambda_expr)); + fb_->Append(codegen_->DeclareVar( + lambda_identifier, codegen_->LambdaType(lambda_expr->GetFunctionLiteralExpr()->TypeRepr()), lambda_expr)); // Make query state auto query_state = codegen_->MakeFreshIdentifier("query_state"); fb_->Append(codegen_->DeclareVarNoInit(query_state, codegen_->MakeExpr(exec_query->GetQueryStateType()->Name()))); - // Set its execution context to whatever exec context was passed in here + // Set its execution context to whatever execution context was passed in here fb_->Append(codegen_->CallBuiltin(execution::ast::Builtin::StartNewParams, {exec_ctx})); - std::vector>::const_iterator> sorted_vec{}; - for (auto it = query_params.begin(); it != query_params.end(); it++) { - sorted_vec.push_back(it); - } - std::sort(sorted_vec.begin(), sorted_vec.end(), [](auto x, auto y) { return x->second.second < y->second.second; }); - for (auto entry : sorted_vec) { - // TODO(Kyle): Order these + const std::vector columns = ColumnsSortedByIndex(ast->Parameters()); + const std::vector parameters = ParametersSortedByIndex(ast->Parameters()); + for (std::size_t i = 0; i < columns.size(); ++i) { + const auto &column = columns[i]; + const auto ¶meter = parameters[i]; // TODO(Kyle): This IILE is cool and all... but way more // complex than I would like, all of the logic in this // function deserves a second look to refactor - auto [type, expr] = [=, &entry]() { - if (entry->second.first.length() > 0) { - const auto fields = GetRecordType(entry->second.first); - auto it = std::find_if(fields.cbegin(), fields.cend(), [=](auto p) { return p.first == entry->first; }); + auto [type, expr] = [=, &column, ¶meter]() { + if (parameter.length() > 0) { + const auto fields = GetRecordType(parameter); + auto it = std::find_if(fields.cbegin(), fields.cend(), [=](auto p) { return p.first == column; }); NOISEPAGE_ASSERT(it != fields.cend(), "Broken invariant"); return std::pair{ - it->second, codegen_->AccessStructMember(codegen_->MakeExpr(SymbolTable().at(entry->second.first)), - codegen_->MakeIdentifier(entry->first))}; + it->second, codegen_->AccessStructMember(codegen_->MakeExpr(SymbolTable().at(parameter)), + codegen_->MakeIdentifier(column))}; } - const type::TypeId type = GetVariableType(entry->first); - return std::pair{type, codegen_->MakeExpr(SymbolTable().at(entry->first))}; + const type::TypeId type = GetVariableType(column); + return std::pair{type, codegen_->MakeExpr(SymbolTable().at(column))}; }(); - execution::ast::Builtin builtin{}; - switch (type) { - case type::TypeId::BOOLEAN: - builtin = execution::ast::Builtin::AddParamBool; - break; - case type::TypeId::TINYINT: - builtin = execution::ast::Builtin::AddParamTinyInt; - break; - case type::TypeId::SMALLINT: - builtin = execution::ast::Builtin::AddParamSmallInt; - break; - case type::TypeId::INTEGER: - builtin = execution::ast::Builtin::AddParamInt; - break; - case type::TypeId::BIGINT: - builtin = execution::ast::Builtin::AddParamBigInt; - break; - case type::TypeId::DECIMAL: - builtin = execution::ast::Builtin::AddParamDouble; - break; - case type::TypeId::DATE: - builtin = execution::ast::Builtin::AddParamDate; - break; - case type::TypeId::TIMESTAMP: - builtin = execution::ast::Builtin::AddParamTimestamp; - break; - case type::TypeId::VARCHAR: - builtin = execution::ast::Builtin::AddParamString; - break; - default: - UNREACHABLE("Unsupported parameter type"); - } - fb_->Append(codegen_->CallBuiltin(builtin, {exec_ctx, expr})); + fb_->Append(codegen_->CallBuiltin(AddParamBuiltinForParameterType(type), {exec_ctx, expr})); } + // Load the execution context member of the query state fb_->Append(codegen_->Assign( codegen_->AccessStructMember(codegen_->MakeExpr(query_state), codegen_->MakeIdentifier("execCtx")), exec_ctx)); - for (auto &col : cols) { + // Generate code to assign to the closure captures + // from the output of the embedded query + const std::size_t n_columns = plan->GetOutputSchema()->GetColumns().size(); + for (const auto &col : plan->GetOutputSchema()->GetColumns()) { execution::ast::Expr *capture_var = codegen_->MakeExpr(SymbolTable().find(ast->Name())->second); - auto *lhs = capture_var; - if (cols.size() > 1) { - // Record struct type - lhs = codegen_->AccessStructMember(capture_var, codegen_->MakeIdentifier(col.GetName())); - } + execution::ast::Expr *lhs = (n_columns > 1) + ? codegen_->AccessStructMember(capture_var, codegen_->MakeIdentifier(col.GetName())) + : capture_var; fb_->Append(codegen_->Assign(lhs, codegen_->ConstNull(col.GetType()))); } - auto fns = exec_query->GetFunctionNames(); - for (auto &sub_fn : fns) { - if (sub_fn.find("Run") != std::string::npos) { - fb_->Append(codegen_->Call(codegen_->GetAstContext()->GetIdentifier(sub_fn), - {codegen_->AddressOf(query_state), codegen_->MakeExpr(lam_var)})); + // Manually append calls to each function from the compiled + // executable query (implementing the closure) to the builder + auto function_names = exec_query->GetFunctionNames(); + for (const auto &function_name : function_names) { + if (IsRunFunction(function_name)) { + fb_->Append(codegen_->Call(codegen_->GetAstContext()->GetIdentifier(function_name), + {codegen_->AddressOf(query_state), codegen_->MakeExpr(lambda_identifier)})); } else { - fb_->Append(codegen_->Call(codegen_->GetAstContext()->GetIdentifier(sub_fn), {codegen_->AddressOf(query_state)})); + fb_->Append( + codegen_->Call(codegen_->GetAstContext()->GetIdentifier(function_name), {codegen_->AddressOf(query_state)})); } } @@ -686,4 +613,79 @@ std::vector> UdfCodegen::GetRecordType(cons return type.value(); } +std::unique_ptr UdfCodegen::OptimizeEmbeddedQuery(parser::ParseResult *parsed_query) { + optimizer::StatsStorage stats{}; + const std::uint64_t optimizer_timeout = 1000000; + return trafficcop::TrafficCopUtil::Optimize( + accessor_->GetTxn(), common::ManagedPointer(accessor_), common::ManagedPointer(parsed_query), db_oid_, + common::ManagedPointer(&stats), std::make_unique(), optimizer_timeout, nullptr); +} + +// Static +bool UdfCodegen::IsRunFunction(const std::string &function_name) { + return function_name.find("Run") != std::string::npos; +} + +// Static +ast::Builtin UdfCodegen::AddParamBuiltinForParameterType(type::TypeId parameter_type) { + // TODO(Kyle): Could accomplish this same thing with a compile-time + // dispatch table, but honestly that would be overkill at this point + switch (parameter_type) { + case type::TypeId::BOOLEAN: + return execution::ast::Builtin::AddParamBool; + case type::TypeId::TINYINT: + return execution::ast::Builtin::AddParamTinyInt; + case type::TypeId::SMALLINT: + return execution::ast::Builtin::AddParamSmallInt; + case type::TypeId::INTEGER: + return execution::ast::Builtin::AddParamInt; + case type::TypeId::BIGINT: + return execution::ast::Builtin::AddParamBigInt; + case type::TypeId::DECIMAL: + return execution::ast::Builtin::AddParamDouble; + case type::TypeId::DATE: + return execution::ast::Builtin::AddParamDate; + case type::TypeId::TIMESTAMP: + return execution::ast::Builtin::AddParamTimestamp; + case type::TypeId::VARCHAR: + return execution::ast::Builtin::AddParamString; + default: + UNREACHABLE("Unsupported parameter type"); + } +} + +// Static +std::vector UdfCodegen::ParametersSortedByIndex( + const std::unordered_map> ¶meter_map) { + // TODO(Kyle): This temporary data structure is gross + std::unordered_map parameters{}; + for (const auto &entry : parameter_map) { + // Column Name -> (Parameter Name, Parameter Index) + parameters[entry.second.first] = entry.second.second; + } + std::vector result{}; + result.reserve(parameters.size()); + std::transform(parameters.cbegin(), parameters.cend(), std::back_inserter(result), + [](const std::pair &entry) -> std::string { return entry.first; }); + std::sort(result.begin(), result.end(), [¶meters](const std::string &a, const std::string &b) -> bool { + return parameters.at(a) < parameters.at(b); + }); + return result; +} + +// Static +std::vector UdfCodegen::ColumnsSortedByIndex( + const std::unordered_map> ¶meter_map) { + std::vector result{}; + result.reserve(parameter_map.size()); + std::transform(parameter_map.cbegin(), parameter_map.cend(), std::back_inserter(result), + [](const std::pair> &entry) -> std::string { + return entry.first; + }); + std::sort(result.begin(), result.end(), [¶meter_map](const std::string &a, const std::string &b) -> bool { + return parameter_map.at(a).second < parameter_map.at(b).second; + }); + return result; +} + } // namespace noisepage::execution::compiler::udf diff --git a/src/include/execution/compiler/udf/udf_codegen.h b/src/include/execution/compiler/udf/udf_codegen.h index c7da0457a1..2d679987e4 100644 --- a/src/include/execution/compiler/udf/udf_codegen.h +++ b/src/include/execution/compiler/udf/udf_codegen.h @@ -1,5 +1,6 @@ #pragma once +#include #include #include #include @@ -14,7 +15,11 @@ namespace noisepage::catalog { class CatalogAccessor; -} +} // namespace noisepage::catalog + +namespace noisepage::optimizer { +class OptimizeResult; +} // namespace noisepage::optimizer namespace noisepage::execution { @@ -248,6 +253,40 @@ class UdfCodegen : ast::udf::ASTNodeVisitor { */ std::vector> GetRecordType(const std::string &name) const; + /** + * Run the optimizer on an embedded SQL query. + * @param parsed_query The result of parsing the query + * @return The optimized result + */ + std::unique_ptr OptimizeEmbeddedQuery(parser::ParseResult *parsed_query); + + /** + * Determine the function identified by `name` is a top-level run function. + * @param function_name The name of the function + * @return `true` if the function is a top-level run + * function, `false` otherwise + */ + static bool IsRunFunction(const std::string &function_name); + + /** + * Get the builtin parameter-add function for the specified parameter type. + * @param parameter_type The parameter type + * @return The builtin function to add this parameter + */ + static ast::Builtin AddParamBuiltinForParameterType(type::TypeId parameter_type); + + /** + * Sort the query + */ + static std::vector ParametersSortedByIndex( + const std::unordered_map> ¶meter_map); + + /** + * + */ + static std::vector ColumnsSortedByIndex( + const std::unordered_map> ¶meter_map); + private: /** The string identifier for internal declarations */ constexpr static const char INTERNAL_DECL_ID[] = "*internal*"; diff --git a/src/parser/udf/plpgsql_parser.cpp b/src/parser/udf/plpgsql_parser.cpp index 240cc5a018..e800f9bdd1 100644 --- a/src/parser/udf/plpgsql_parser.cpp +++ b/src/parser/udf/plpgsql_parser.cpp @@ -259,6 +259,8 @@ std::unique_ptr PLpgSQLParser::ParseForS(const nlo } std::unique_ptr PLpgSQLParser::ParseSQL(const nlohmann::json &json) { + std::cout << json << std::endl; + // The query text const auto sql_query = json[K_SQLSTMT][K_PLPGSQL_EXPR][K_QUERY].get(); // The variable name (non-const for later std::move) @@ -331,12 +333,14 @@ std::unique_ptr PLpgSQLParser::ParseExprFromAbstra // TODO(Kyle): I am not a fan of non-exhaustive switch statements; // is there a way that we can refactor this logic to make it better? + std::cout << parser::ExpressionTypeToShortString(expr->GetExpressionType()) << std::endl; + switch (expr->GetExpressionType()) { case parser::ExpressionType::FUNCTION: { auto func_expr = expr.CastManagedPointerTo(); std::vector> args{}; auto num_args = func_expr->GetChildrenSize(); - for (size_t idx = 0; idx < num_args; ++idx) { + for (std::size_t idx = 0; idx < num_args; ++idx) { args.push_back(ParseExprFromAbstract(func_expr->GetChild(idx))); } return std::make_unique(func_expr->GetFuncName(), std::move(args)); From 253a578b72a0891528267734507c0d8a26bd14fd Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Sun, 25 Jul 2021 22:51:16 -0400 Subject: [PATCH 083/139] worked out the issue with embedded SQL, at least for certain SELECT queries, now need to determine which of the fixes is appropriate --- docs/design_udfs.md | 27 ++++++++ script/testing/junit/sql/udf.sql | 24 +++++-- script/testing/junit/traces/udf.test | 66 ++++++++++++------- .../compiler/compilation_context.cpp | 8 ++- .../expression/function_translator.cpp | 4 +- src/execution/compiler/function_builder.cpp | 9 ++- .../compiler/operator/operator_translator.cpp | 1 - src/execution/compiler/pipeline.cpp | 2 +- src/execution/compiler/udf/udf_codegen.cpp | 66 ++++++++----------- .../execution/compiler/compilation_context.h | 10 --- src/parser/udf/plpgsql_parser.cpp | 4 -- 11 files changed, 132 insertions(+), 89 deletions(-) diff --git a/docs/design_udfs.md b/docs/design_udfs.md index 67d0d895e6..44d1589219 100644 --- a/docs/design_udfs.md +++ b/docs/design_udfs.md @@ -8,6 +8,33 @@ This document describes important aspects of the design and implementation of us This section describes known limitations of our implementation of UDFs. +**Parallel Table Scans** + +Consider the following function: + +```sql +CREATE FUNCTION agg_count() RETURNS INT AS $$ +DECLARE +v INT; +BEGIN + SELECT COUNT(z) INTO v FROM tmp; + RETURN v; +END +$$ LANGUAGE PLPGSQL; +``` + +Currently, we fail to generate code for this function. Code generation fails while we attempt to generate code for the embedded SQL query `SELECT COUTN(z) INTO v FROM tmp;`. The plan tree for this query is straighforward: a static aggregation with a sequential table scan as its only child. However, we fail semantic analysis for the generated code because of the presence of the ouutput callback - a TPL closure that takes the output of the query and "writes" it into the variable `v` within the contextion of the function (a simplification, but close enough). When the callback is present, we add the closure itself as an additional parameter to the top-level pipeline functions: +- Wrapper +- Init +- Run +- Teardown + +This is an issue because we expect the callback function for a parallel table vector iterator to have a specific signature, and the presence of the closure violates this signature. + +There are a couple of fixes available for this problem, but the question of which one is "correct" is not straightforward. +- Do we just change the signature in semantic analysis to accept this? What are the correctness implications of this decision? Is it possible that invoking a closure in parallel will result in incorrect or undefined behavior? +- It doesn't seem like we _need_ to push the closure down as an argument to all of the functions that define the pipeline, only some of them. Specifically, the closure is only used (thus far) as a mechanism for pushing results out of the query back into the function in which it is embedded. Therefore, why do we add the closure as an argument to every pipeline function? It seems like this may have been just an "expedient" solution that isn't actually what is required or desired. + **Missing `RETURN`** In Postgres, a PL/pgSQL function that declares a return type but is missing a `RETURN` statement in the body of the function parses successfully, but results in a runtime error when the function is executed. Currently, we fail to parse such functions (which may be directly related to the issue below). diff --git a/script/testing/junit/sql/udf.sql b/script/testing/junit/sql/udf.sql index a6bced5e35..f5b5c0da06 100644 --- a/script/testing/junit/sql/udf.sql +++ b/script/testing/junit/sql/udf.sql @@ -21,7 +21,7 @@ BEGIN \ END \ $$ LANGUAGE PLPGSQL; -SELECT x, return_constant() FROM integers; +SELECT return_constant(); -- ---------------------------------------------------------------------------- -- return_input() @@ -67,7 +67,7 @@ BEGIN \ END \ $$ LANGUAGE PLPGSQL; -SELECT x, y, integer_decl() FROM integers; +SELECT integer_decl(); -- ---------------------------------------------------------------------------- -- conditional() @@ -102,7 +102,7 @@ BEGIN \ END \ $$ LANGUAGE PLPGSQL; -SELECT x, proc_while() FROM integers; +SELECT proc_while(); -- ---------------------------------------------------------------------------- -- proc_fori() @@ -122,6 +122,20 @@ SELECT x, proc_while() FROM integers; -- SELECT x, proc_fori() FROM integers; +-- ---------------------------------------------------------------------------- +-- sql_select_single_constant() + +CREATE FUNCTION sql_select_single_constant() RETURNS INT AS $$ \ +DECLARE \ + v INT; \ +BEGIN \ + SELECT 1 INTO v; \ + RETURN v; \ +END \ +$$ LANGUAGE PLPGSQL; + +SELECT sql_select_single_constant(); + -- ---------------------------------------------------------------------------- -- proc_fors() -- @@ -159,7 +173,3 @@ SELECT x, proc_while() FROM integers; -- $$ LANGUAGE PLPGSQL; -- SELECT x, proc_fors_var() FROM integers; - -CREATE FUNCTION agg_count() RETURNS INT AS $$ DECLARE v INT; BEGIN SELECT COUNT(z) INTO v FROM tmp; RETURN v; END $$ LANGUAGE PLPGSQL; -CREATE FUNCTION fun() RETURNS INT AS $$ DECLARE a INT; b INT; BEGIN SELECT COUNT(z), COUNT(z) INTO a, b FROM tmp; RETURN a + b; END $$ LANGUAGE PLPGSQL; - diff --git a/script/testing/junit/traces/udf.test b/script/testing/junit/traces/udf.test index 935f0a637f..722b4e59bd 100644 --- a/script/testing/junit/traces/udf.test +++ b/script/testing/junit/traces/udf.test @@ -55,15 +55,10 @@ CREATE FUNCTION return_constant() RETURNS INT AS $$ BEGIN RETURN 1; END $$ LANGU statement ok -query II rowsort -SELECT x, return_constant() FROM integers; +query I rowsort +SELECT return_constant(); ---- 1 -1 -2 -1 -3 -1 statement ok @@ -177,17 +172,9 @@ CREATE FUNCTION integer_decl() RETURNS INT AS $$ DECLARE x INT := 0; BEGIN RETUR statement ok -query III rowsort -SELECT x, y, integer_decl() FROM integers; +query I rowsort +SELECT integer_decl(); ---- -1 -1 -0 -2 -2 -0 -3 -3 0 @@ -247,14 +234,9 @@ CREATE FUNCTION proc_while() RETURNS INT AS $$ DECLARE x INT := 0; BEGIN WHILE x statement ok -query II rowsort -SELECT x, proc_while() FROM integers; +query I rowsort +SELECT proc_while(); ---- -1 -10 -2 -10 -3 10 @@ -267,6 +249,12 @@ statement ok statement ok -- proc_fori() +statement ok +-- + +statement ok +-- TODO(Kyle): for-loop control flow (integer variant) is not supported + statement ok @@ -282,12 +270,42 @@ statement ok statement ok +statement ok +-- ---------------------------------------------------------------------------- + +statement ok +-- sql_select_single_constant() + +statement ok + + +statement ok +CREATE FUNCTION sql_select_single_constant() RETURNS INT AS $$ DECLARE v INT; BEGIN SELECT 1 INTO v; RETURN v; END $$ LANGUAGE PLPGSQL; + +statement ok + + +query I rowsort +SELECT sql_select_single_constant(); +---- +1 + + +statement ok + + statement ok -- ---------------------------------------------------------------------------- statement ok -- proc_fors() +statement ok +-- + +statement ok +-- TODO(Kyle): for-loop control flow (query variant) is not supported + statement ok diff --git a/src/execution/compiler/compilation_context.cpp b/src/execution/compiler/compilation_context.cpp index 29f4a799cd..c82cc2a7ad 100644 --- a/src/execution/compiler/compilation_context.cpp +++ b/src/execution/compiler/compilation_context.cpp @@ -442,8 +442,12 @@ ExpressionTranslator *CompilationContext::LookupTranslator(const parser::Abstrac } std::string CompilationContext::GetFunctionPrefix() const { - return output_callback_ == nullptr ? "Query" + std::to_string(unique_id_) - : output_callback_->GetName().GetString() + "Query" + std::to_string(unique_id_); + // If an output callback is present, we prefix + // each function with the callback name + if (output_callback_ != nullptr) { + return fmt::format("{}Query{}", output_callback_->GetName().GetString(), std::to_string(unique_id_)); + } + return fmt::format("Query{}", std::to_string(unique_id_)); } util::RegionVector CompilationContext::QueryParams() const { diff --git a/src/execution/compiler/expression/function_translator.cpp b/src/execution/compiler/expression/function_translator.cpp index cf69088bcf..40c432407b 100644 --- a/src/execution/compiler/expression/function_translator.cpp +++ b/src/execution/compiler/expression/function_translator.cpp @@ -54,7 +54,7 @@ void FunctionTranslator::DefineHelperFunctions(util::RegionVectorGetAstContext().Get())); auto udf_decls = file->Declarations(); main_fn_ = udf_decls.back()->Name(); - size_t num_added = 0; + std::size_t num_added = 0; for (ast::Decl *udf_decl : udf_decls) { if (udf_decl->IsFunctionDecl()) { decls->insert(decls->begin() + num_added, udf_decl->As()); @@ -74,7 +74,7 @@ void FunctionTranslator::DefineHelperStructs(util::RegionVectorGetFile(), GetCodeGen()->GetAstContext()->GetNodeFactory(), nullptr, GetCodeGen()->GetAstContext().Get())); auto udf_decls = file->Declarations(); - size_t num_added = 0; + std::size_t num_added = 0; for (ast::Decl *udf_decl : udf_decls) { if (udf_decl->IsStructDecl()) { decls->insert(decls->begin() + num_added, udf_decl->As()); diff --git a/src/execution/compiler/function_builder.cpp b/src/execution/compiler/function_builder.cpp index ef968322c5..5f22d71382 100644 --- a/src/execution/compiler/function_builder.cpp +++ b/src/execution/compiler/function_builder.cpp @@ -5,6 +5,9 @@ namespace noisepage::execution::compiler { +// TODO(Kyle): We should refactor this two 2 distinct types: +// the regular old FunctionBuilder and a ClosureBuilder + FunctionBuilder::FunctionBuilder(CodeGen *codegen, ast::Identifier name, util::RegionVector &¶ms, ast::Expr *ret_type) : codegen_{codegen}, @@ -25,7 +28,11 @@ FunctionBuilder::FunctionBuilder(CodeGen *codegen, util::RegionVector, nullptr} {} -FunctionBuilder::~FunctionBuilder() { Finish(); } +FunctionBuilder::~FunctionBuilder() { + if (!IsLambda()) { + Finish(); + } +} ast::Expr *FunctionBuilder::GetParameterByPosition(const std::size_t param_idx) { if (param_idx < params_.size()) { diff --git a/src/execution/compiler/operator/operator_translator.cpp b/src/execution/compiler/operator/operator_translator.cpp index 6573a0a1ff..071ce93fd8 100644 --- a/src/execution/compiler/operator/operator_translator.cpp +++ b/src/execution/compiler/operator/operator_translator.cpp @@ -26,7 +26,6 @@ OperatorTranslator::OperatorTranslator(const planner::AbstractPlanNode &plan, Co pipeline->RegisterStep(this); // Prepare all output expressions. for (const auto &output_column : plan.GetOutputSchema()->GetColumns()) { - compilation_context->SetCurrentOp(this); compilation_context->Prepare(*output_column.GetExpr()); } } diff --git a/src/execution/compiler/pipeline.cpp b/src/execution/compiler/pipeline.cpp index 768f5db7ee..16b707515c 100644 --- a/src/execution/compiler/pipeline.cpp +++ b/src/execution/compiler/pipeline.cpp @@ -27,7 +27,7 @@ Pipeline::Pipeline(CompilationContext *ctx) compilation_context_(ctx), codegen_(compilation_context_->GetCodeGen()), state_var_(codegen_->MakeIdentifier("pipelineState")), - state_(codegen_->MakeIdentifier(fmt::format("P{}{}_State", ctx->GetFunctionPrefix(), id_)), + state_(codegen_->MakeIdentifier(fmt::format("{}_Pipeline{}_State", ctx->GetFunctionPrefix(), id_)), [this](CodeGen *codegen) { return codegen_->MakeExpr(state_var_); }), driver_(nullptr), parallelism_(Parallelism::Parallel), diff --git a/src/execution/compiler/udf/udf_codegen.cpp b/src/execution/compiler/udf/udf_codegen.cpp index 0d7932d112..5932ae2c8c 100644 --- a/src/execution/compiler/udf/udf_codegen.cpp +++ b/src/execution/compiler/udf/udf_codegen.cpp @@ -341,11 +341,6 @@ void UdfCodegen::Visit(ast::udf::ForSStmtAST *ast) { needs_exec_ctx_ = true; execution::ast::Expr *exec_ctx = fb_->GetParameterByPosition(0); - // Bind the embedded query - binder::BindNodeVisitor visitor{common::ManagedPointer{accessor_}, db_oid_}; - auto query_params = - visitor.BindAndGetUDFParams(common::ManagedPointer{ast->Query()}, common::ManagedPointer{udf_ast_context_}); - // Optimize the embedded query auto optimize_result = OptimizeEmbeddedQuery(ast->Query()); auto plan = optimize_result->GetPlanNode(); @@ -363,14 +358,13 @@ void UdfCodegen::Visit(ast::udf::ForSStmtAST *ast) { exec_ctx->As()->Name(), codegen_->PointerType(codegen_->BuiltinType(execution::ast::BuiltinType::Kind::ExecutionContext)))); std::size_t i = 0; - for (const auto &var : ast->Variables()) { - variable_identifiers.push_back(SymbolTable().find(var)->second); - auto var_ident = variable_identifiers.back(); - auto type = codegen_->TplType(execution::sql::GetTypeId(plan->GetOutputSchema()->GetColumn(i).GetType())); - fb_->Append(codegen_->Assign(codegen_->MakeExpr(var_ident), + for (const auto &variable_name : ast->Variables()) { + const ast::Identifier variable_identifier = SymbolTable().find(variable_name)->second; + variable_identifiers.push_back(variable_identifier); + ast::Expr *type = codegen_->TplType(execution::sql::GetTypeId(plan->GetOutputSchema()->GetColumn(i).GetType())); + fb_->Append(codegen_->Assign(codegen_->MakeExpr(variable_identifier), codegen_->ConstNull(plan->GetOutputSchema()->GetColumn(i).GetType()))); - auto input = codegen_->MakeFreshIdentifier(var); - params.push_back(codegen_->MakeField(input, type)); + params.push_back(codegen_->MakeField(codegen_->MakeFreshIdentifier(variable_name), type)); i++; } @@ -378,8 +372,7 @@ void UdfCodegen::Visit(ast::udf::ForSStmtAST *ast) { { std::size_t j = 1; for (auto var : variable_identifiers) { - fn.Append(codegen_->Assign(codegen_->MakeExpr(var), fn.GetParameterByPosition(j))); - j++; + fn.Append(codegen_->Assign(codegen_->MakeExpr(var), fn.GetParameterByPosition(j++))); } auto prev_fb = fb_; fb_ = &fn; @@ -387,6 +380,9 @@ void UdfCodegen::Visit(ast::udf::ForSStmtAST *ast) { fb_ = prev_fb; } + // Define the captures for the closure + // TODO(Kyle): We are capturing every variable in the symbol table, + // this seems like overkill and may lead to incorrect semantics? execution::util::RegionVector captures{codegen_->GetAstContext()->GetRegion()}; for (const auto &[name, identifier] : SymbolTable()) { if (name == "executionCtx") { @@ -396,12 +392,10 @@ void UdfCodegen::Visit(ast::udf::ForSStmtAST *ast) { } ast::LambdaExpr *lambda_expr = fn.FinishLambda(std::move(captures)); - const ast::Identifier lambda_identifier = codegen_->MakeFreshIdentifier("looplamb"); + const ast::Identifier lambda_identifier = codegen_->MakeFreshIdentifier("udfLambda"); lambda_expr->SetName(lambda_identifier); - // We want to pass something down that will materialize the lambda - // function into lambda_expr and will also feed in a lambda_expr to the compiler - // TODO(Kyle): Using a NULL plan metatdata here... + // Materialize the lambda into the lambda expression execution::exec::ExecutionSettings exec_settings{}; const std::string dummy_query{}; auto exec_query = execution::compiler::CompilationContext::Compile( @@ -409,28 +403,27 @@ void UdfCodegen::Visit(ast::udf::ForSStmtAST *ast) { common::ManagedPointer{}, common::ManagedPointer{&dummy_query}, lambda_expr, codegen_->GetAstContext()); - auto decls = exec_query->GetDecls(); - aux_decls_.insert(aux_decls_.end(), decls.begin(), decls.end()); + // Append all of the declarations from the compiled query - fb_->Append(codegen_->DeclareVar( - lambda_identifier, codegen_->LambdaType(lambda_expr->GetFunctionLiteralExpr()->TypeRepr()), lambda_expr)); + aux_decls_.insert(aux_decls_.end(), exec_query->GetDecls().cbegin(), exec_query->GetDecls().cend()); + // Add the closure and query state to the current function auto query_state = codegen_->MakeFreshIdentifier("query_state"); + fb_->Append(codegen_->DeclareVar( + lambda_identifier, codegen_->LambdaType(lambda_expr->GetFunctionLiteralExpr()->TypeRepr()), lambda_expr)); fb_->Append(codegen_->DeclareVarNoInit(query_state, codegen_->MakeExpr(exec_query->GetQueryStateType()->Name()))); - // Set its execution context to whatever exec context was passed in here + // Set its execution context to whatever execution context was passed in here fb_->Append(codegen_->CallBuiltin(execution::ast::Builtin::StartNewParams, {exec_ctx})); - std::vector>::iterator> sorted_vec{}; - for (auto it = query_params.begin(); it != query_params.end(); it++) { - sorted_vec.push_back(it); - } - - std::sort(sorted_vec.begin(), sorted_vec.end(), [](auto x, auto y) { return x->second < y->second; }); - for (auto entry : sorted_vec) { - const type::TypeId type = GetVariableType(entry->first); + // Derive the columns and parameter names from the query + binder::BindNodeVisitor visitor{common::ManagedPointer{accessor_}, db_oid_}; + auto query_params = + visitor.BindAndGetUDFParams(common::ManagedPointer{ast->Query()}, common::ManagedPointer{udf_ast_context_}); + for (const auto &column_name : ColumnsSortedByIndex(query_params)) { + const type::TypeId type = GetVariableType(column_name); const ast::Builtin builtin = AddParamBuiltinForParameterType(type); - fb_->Append(codegen_->CallBuiltin(builtin, {exec_ctx, codegen_->MakeExpr(SymbolTable().at(entry->first))})); + fb_->Append(codegen_->CallBuiltin(builtin, {exec_ctx, codegen_->MakeExpr(SymbolTable().at(column_name))})); } fb_->Append(codegen_->Assign( @@ -438,7 +431,7 @@ void UdfCodegen::Visit(ast::udf::ForSStmtAST *ast) { auto function_names = exec_query->GetFunctionNames(); for (auto &function_name : function_names) { - if (function_name.find("Run") != std::string::npos) { + if (IsRunFunction(function_name)) { fb_->Append(codegen_->Call(codegen_->GetAstContext()->GetIdentifier(function_name), {codegen_->AddressOf(query_state), codegen_->MakeExpr(lambda_identifier)})); } else { @@ -494,9 +487,8 @@ void UdfCodegen::Visit(ast::udf::SQLStmtAST *ast) { assignees.push_back(capture_var); captures.push_back(capture_var); } - auto *tpl_type = codegen_->TplType(execution::sql::GetTypeId(col.GetType())); - auto input_param = codegen_->MakeFreshIdentifier("input"); - lambda_parameters.push_back(codegen_->MakeField(input_param, tpl_type)); + lambda_parameters.push_back(codegen_->MakeField(codegen_->MakeFreshIdentifier("input"), + codegen_->TplType(execution::sql::GetTypeId(col.GetType())))); i++; } @@ -507,7 +499,7 @@ void UdfCodegen::Visit(ast::udf::SQLStmtAST *ast) { fn.Append(codegen_->Assign(capture_var, input_param)); } - const ast::Identifier lambda_identifier = codegen_->MakeFreshIdentifier("lamb"); + const ast::Identifier lambda_identifier = codegen_->MakeFreshIdentifier("udfLambda"); ast::LambdaExpr *lambda_expr = fn.FinishLambda(std::move(captures)); lambda_expr->SetName(lambda_identifier); diff --git a/src/include/execution/compiler/compilation_context.h b/src/include/execution/compiler/compilation_context.h index 435861f2ba..de9b1769fa 100644 --- a/src/include/execution/compiler/compilation_context.h +++ b/src/include/execution/compiler/compilation_context.h @@ -143,16 +143,6 @@ class CompilationContext { /** @return Query Id associated with the query */ query_id_t GetQueryId() const { return query_id_; } - /** - * @brief Set the current op. - */ - void SetCurrentOp(OperatorTranslator *current_op) { current_op_ = current_op; } - - /** - * @return The current op. - */ - OperatorTranslator *GetCurrentOp() const { return current_op_; } - private: // Private to force use of static Compile() function. explicit CompilationContext(ExecutableQuery *query, query_id_t query_id_, catalog::CatalogAccessor *accessor, diff --git a/src/parser/udf/plpgsql_parser.cpp b/src/parser/udf/plpgsql_parser.cpp index e800f9bdd1..100c894a3b 100644 --- a/src/parser/udf/plpgsql_parser.cpp +++ b/src/parser/udf/plpgsql_parser.cpp @@ -259,8 +259,6 @@ std::unique_ptr PLpgSQLParser::ParseForS(const nlo } std::unique_ptr PLpgSQLParser::ParseSQL(const nlohmann::json &json) { - std::cout << json << std::endl; - // The query text const auto sql_query = json[K_SQLSTMT][K_PLPGSQL_EXPR][K_QUERY].get(); // The variable name (non-const for later std::move) @@ -333,8 +331,6 @@ std::unique_ptr PLpgSQLParser::ParseExprFromAbstra // TODO(Kyle): I am not a fan of non-exhaustive switch statements; // is there a way that we can refactor this logic to make it better? - std::cout << parser::ExpressionTypeToShortString(expr->GetExpressionType()) << std::endl; - switch (expr->GetExpressionType()) { case parser::ExpressionType::FUNCTION: { auto func_expr = expr.CastManagedPointerTo(); From 6df7d14a5cdf508f91a4c7f86ba528391a2bd7bd Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Tue, 27 Jul 2021 10:32:45 -0400 Subject: [PATCH 084/139] minor tweaks in BindNodeVisitor logic for resolving ColumnValueExpressions from UDFs --- src/binder/bind_node_visitor.cpp | 36 ++++++++++++++------------ src/include/binder/bind_node_visitor.h | 8 ++++++ 2 files changed, 28 insertions(+), 16 deletions(-) diff --git a/src/binder/bind_node_visitor.cpp b/src/binder/bind_node_visitor.cpp index 19e250f1c5..04048cbf96 100644 --- a/src/binder/bind_node_visitor.cpp +++ b/src/binder/bind_node_visitor.cpp @@ -701,15 +701,13 @@ void BindNodeVisitor::Visit(common::ManagedPointerHasVariable(expr->GetColumnName())) { - const type::TypeId type = udf_ast_context_->GetVariableTypeFailFast(expr->GetColumnName()); - expr->SetReturnValueType(type); - std::size_t idx = 0; + if (BindingForUDF() && IsUDFVariable(expr->GetColumnName())) { + // If the variable is not present, add it if (udf_params_.count(expr->GetColumnName()) == 0) { udf_params_[expr->GetColumnName()] = std::make_pair("", udf_params_.size()); - idx = udf_params_.size() - 1; + expr->SetReturnValueType(udf_ast_context_->GetVariableTypeFailFast(expr->GetColumnName())); + expr->SetParamIdx(udf_params_.size() - 1); } - expr->SetParamIdx(idx); } else if (context_ == nullptr || !context_->SetColumnPosTuple(expr)) { throw BINDER_EXCEPTION(fmt::format("column \"{}\" does not exist", col_name), common::ErrorCode::ERRCODE_UNDEFINED_COLUMN); @@ -722,21 +720,23 @@ void BindNodeVisitor::Visit(common::ManagedPointerHasVariable(expr->GetTableName())) { + } else if (BindingForUDF() && IsUDFVariable(expr->GetTableName())) { const type::TypeId type = udf_ast_context_->GetVariableTypeFailFast(expr->GetTableName()); NOISEPAGE_ASSERT(type == type::TypeId::INVALID, "Must be a RECORD type"); const auto fields = udf_ast_context_->GetRecordTypeFailFast(expr->GetTableName()); - auto it = + auto field = std::find_if(fields.cbegin(), fields.cend(), [=](auto p) { return p.first == expr->GetColumnName(); }); - std::size_t idx = 0; - if (it != fields.cend()) { - if (udf_params_.count(expr->GetColumnName()) == 0) { - udf_params_[expr->GetColumnName()] = std::make_pair(expr->GetTableName(), udf_params_.size()); - idx = udf_params_.size() - 1; - } - expr->SetReturnValueType(it->second); - expr->SetParamIdx(idx); + if (field == fields.cend()) { + throw BINDER_EXCEPTION(fmt::format("RECORD type field '{}' not found", expr->GetColumnName()), + common::ErrorCode::ERRCODE_PLPGSQL_ERROR); + } + + // If the parameter is not present, add it + if (udf_params_.count(expr->GetColumnName()) == 0) { + udf_params_[expr->GetColumnName()] = std::make_pair(expr->GetTableName(), udf_params_.size()); + expr->SetReturnValueType(field->second); + expr->SetParamIdx(udf_params_.size() - 1); } } else if (context_ == nullptr || !context_->CheckNestedTableColumn(table_name, col_name, expr)) { throw BINDER_EXCEPTION(fmt::format("Invalid table reference {}", expr->GetTableName()), @@ -1143,4 +1143,8 @@ void BindNodeVisitor::ValidateAndCorrectInsertValues( bool BindNodeVisitor::BindingForUDF() const { return udf_ast_context_ != nullptr; } +bool BindNodeVisitor::IsUDFVariable(const std::string &identifier) const { + return udf_ast_context_->HasVariable(identifier); +} + } // namespace noisepage::binder diff --git a/src/include/binder/bind_node_visitor.h b/src/include/binder/bind_node_visitor.h index 0cb9f23677..1e5e90be04 100644 --- a/src/include/binder/bind_node_visitor.h +++ b/src/include/binder/bind_node_visitor.h @@ -161,6 +161,14 @@ class BindNodeVisitor final : public SqlNodeVisitor { /** @return `true` if we are binding within the context of a UDF, `false` otherwise */ bool BindingForUDF() const; + + /** + * Determine if the given identifier names a UDF variable. + * @param identifier The variable identifier + * @return `true` if the variable is declared in the UDF + * for which binding is performed, `false` otherwise + */ + bool IsUDFVariable(const std::string &identifier) const; }; } // namespace binder From 30ecd1d6c64101cf681bf0ac6d3af3ed7b0a9fa7 Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Wed, 28 Jul 2021 16:56:53 -0400 Subject: [PATCH 085/139] refactors across the binder and UDF code generation interface, attempting to make the logic more comprehensible so that it is not impossible to extend to support new features in the future --- script/testing/junit/sql/udf.sql | 2 + src/binder/bind_node_visitor.cpp | 68 ++-- .../compiler/compilation_context.cpp | 4 +- src/execution/compiler/udf/udf_codegen.cpp | 343 ++++++++++++------ src/include/binder/bind_node_visitor.h | 29 +- src/include/execution/ast/udf/udf_ast_nodes.h | 23 +- .../execution/compiler/compilation_context.h | 3 +- .../execution/compiler/udf/udf_codegen.h | 80 ++++ src/include/parser/udf/plpgsql_parser.h | 15 + src/include/parser/udf/variable_ref.h | 70 ++++ src/parser/udf/plpgsql_parser.cpp | 58 ++- 11 files changed, 520 insertions(+), 175 deletions(-) create mode 100644 src/include/parser/udf/variable_ref.h diff --git a/script/testing/junit/sql/udf.sql b/script/testing/junit/sql/udf.sql index f5b5c0da06..32694dbcbc 100644 --- a/script/testing/junit/sql/udf.sql +++ b/script/testing/junit/sql/udf.sql @@ -173,3 +173,5 @@ SELECT sql_select_single_constant(); -- $$ LANGUAGE PLPGSQL; -- SELECT x, proc_fors_var() FROM integers; + +CREATE FUNCTION fun() RETURNS INT AS $$ DECLARE x INT; y INT; BEGIN SELECT 1, 2 INTO x, y; RETURN x + y; END $$ LANGUAGE PLPGSQL; diff --git a/src/binder/bind_node_visitor.cpp b/src/binder/bind_node_visitor.cpp index 04048cbf96..ab572e7777 100644 --- a/src/binder/bind_node_visitor.cpp +++ b/src/binder/bind_node_visitor.cpp @@ -64,7 +64,7 @@ void BindNodeVisitor::BindNameToNode( BindNodeVisitor::~BindNodeVisitor() = default; -std::unordered_map> BindNodeVisitor::BindAndGetUDFParams( +std::vector BindNodeVisitor::BindAndGetUDFVariableRefs( common::ManagedPointer parse_result, common::ManagedPointer udf_ast_context) { NOISEPAGE_ASSERT(parse_result != nullptr, "We shouldn't be trying to bind something without a ParseResult."); @@ -74,7 +74,7 @@ std::unordered_map> BindNodeVis sherpa_->GetParseResult()->GetStatement(0)->Accept( common::ManagedPointer(this).CastManagedPointerTo()); // TODO(Kyle): This is strange, why are we returning this member by value? - return udf_params_; + return udf_variable_refs_; } void BindNodeVisitor::Visit(common::ManagedPointer node) { @@ -702,12 +702,8 @@ void BindNodeVisitor::Visit(common::ManagedPointerGetColumnName())) { - // If the variable is not present, add it - if (udf_params_.count(expr->GetColumnName()) == 0) { - udf_params_[expr->GetColumnName()] = std::make_pair("", udf_params_.size()); - expr->SetReturnValueType(udf_ast_context_->GetVariableTypeFailFast(expr->GetColumnName())); - expr->SetParamIdx(udf_params_.size() - 1); - } + // This expression refers to a PL/pgSQL variable + AddUDFVariableReference(expr, expr->GetColumnName()); } else if (context_ == nullptr || !context_->SetColumnPosTuple(expr)) { throw BINDER_EXCEPTION(fmt::format("column \"{}\" does not exist", col_name), common::ErrorCode::ERRCODE_UNDEFINED_COLUMN); @@ -721,23 +717,8 @@ void BindNodeVisitor::Visit(common::ManagedPointerGetTableName())) { - const type::TypeId type = udf_ast_context_->GetVariableTypeFailFast(expr->GetTableName()); - NOISEPAGE_ASSERT(type == type::TypeId::INVALID, "Must be a RECORD type"); - - const auto fields = udf_ast_context_->GetRecordTypeFailFast(expr->GetTableName()); - auto field = - std::find_if(fields.cbegin(), fields.cend(), [=](auto p) { return p.first == expr->GetColumnName(); }); - if (field == fields.cend()) { - throw BINDER_EXCEPTION(fmt::format("RECORD type field '{}' not found", expr->GetColumnName()), - common::ErrorCode::ERRCODE_PLPGSQL_ERROR); - } - - // If the parameter is not present, add it - if (udf_params_.count(expr->GetColumnName()) == 0) { - udf_params_[expr->GetColumnName()] = std::make_pair(expr->GetTableName(), udf_params_.size()); - expr->SetReturnValueType(field->second); - expr->SetParamIdx(udf_params_.size() - 1); - } + // This expression refers to a structural (RECORD) PL/pgSQL variable + AddUDFVariableReference(expr, expr->GetTableName(), expr->GetColumnName()); } else if (context_ == nullptr || !context_->CheckNestedTableColumn(table_name, col_name, expr)) { throw BINDER_EXCEPTION(fmt::format("Invalid table reference {}", expr->GetTableName()), common::ErrorCode::ERRCODE_UNDEFINED_TABLE); @@ -1147,4 +1128,41 @@ bool BindNodeVisitor::IsUDFVariable(const std::string &identifier) const { return udf_ast_context_->HasVariable(identifier); } +bool BindNodeVisitor::HaveUDFVariableRef(const std::string &identifier) const { + auto it = std::find_if(udf_variable_refs_.cbegin(), udf_variable_refs_.cend(), + [&identifier](const parser::udf::VariableRef &ref) { return ref.ColumnName() == identifier; }); + return it != udf_variable_refs_.cend(); +} + +void BindNodeVisitor::AddUDFVariableReference(common::ManagedPointer expr, + const std::string &table_name, const std::string &column_name) { + const type::TypeId type = udf_ast_context_->GetVariableTypeFailFast(table_name); + NOISEPAGE_ASSERT(type == type::TypeId::INVALID, "Must be a RECORD type"); + + // Locate the column name in the structure + const auto fields = udf_ast_context_->GetRecordTypeFailFast(table_name); + auto field = std::find_if(fields.cbegin(), fields.cend(), [=](auto p) { return p.first == expr->GetColumnName(); }); + if (field == fields.cend()) { + throw BINDER_EXCEPTION(fmt::format("RECORD type field '{}' not found", expr->GetColumnName()), + common::ErrorCode::ERRCODE_PLPGSQL_ERROR); + } + + if (!HaveUDFVariableRef(column_name)) { + const std::size_t index = udf_variable_refs_.size(); + udf_variable_refs_.emplace_back(table_name, column_name, index); + expr->SetReturnValueType(field->second); + expr->SetParamIdx(index); + } +} + +void BindNodeVisitor::AddUDFVariableReference(common::ManagedPointer expr, + const std::string &column_name) { + if (!HaveUDFVariableRef(column_name)) { + const std::size_t index = udf_variable_refs_.size(); + udf_variable_refs_.emplace_back(column_name, index); + expr->SetReturnValueType(udf_ast_context_->GetVariableTypeFailFast(expr->GetColumnName())); + expr->SetParamIdx(index); + } +} + } // namespace noisepage::binder diff --git a/src/execution/compiler/compilation_context.cpp b/src/execution/compiler/compilation_context.cpp index c82cc2a7ad..95f8e06745 100644 --- a/src/execution/compiler/compilation_context.cpp +++ b/src/execution/compiler/compilation_context.cpp @@ -206,8 +206,8 @@ void CompilationContext::GeneratePlan(const planner::AbstractPlanNode &plan, std::unique_ptr CompilationContext::Compile( const planner::AbstractPlanNode &plan, const exec::ExecutionSettings &exec_settings, catalog::CatalogAccessor *accessor, CompilationMode mode, std::optional override_qid, - common::ManagedPointer plan_meta_data, common::ManagedPointer query_text, - ast::LambdaExpr *output_callback, common::ManagedPointer context) { + common::ManagedPointer plan_meta_data, ast::LambdaExpr *output_callback, + common::ManagedPointer context) { // The query for which we're generating code auto query = std::make_unique(plan, exec_settings, accessor->GetTxn()->StartTime(), context.Get()); if (override_qid.has_value()) { diff --git a/src/execution/compiler/udf/udf_codegen.cpp b/src/execution/compiler/udf/udf_codegen.cpp index 5932ae2c8c..aff0c2e58d 100644 --- a/src/execution/compiler/udf/udf_codegen.cpp +++ b/src/execution/compiler/udf/udf_codegen.cpp @@ -1,33 +1,26 @@ #include "execution/compiler/udf/udf_codegen.h" +#include "binder/bind_node_visitor.h" +#include "catalog/catalog_accessor.h" #include "common/error/error_code.h" #include "common/error/exception.h" - -#include "binder/bind_node_visitor.h" - #include "execution/ast/ast.h" #include "execution/ast/ast_clone.h" #include "execution/ast/context.h" -#include "planner/plannodes/output_schema.h" - +#include "execution/ast/udf/udf_ast_nodes.h" #include "execution/compiler/compilation_context.h" #include "execution/compiler/executable_query.h" #include "execution/compiler/if.h" #include "execution/compiler/loop.h" #include "execution/exec/execution_settings.h" - -#include "catalog/catalog_accessor.h" #include "optimizer/cost_model/trivial_cost_model.h" #include "optimizer/statistics/stats_storage.h" - -#include "traffic_cop/traffic_cop_util.h" - #include "parser/expression/constant_value_expression.h" #include "parser/postgresparser.h" - -#include "execution/ast/udf/udf_ast_nodes.h" - +#include "parser/udf/variable_ref.h" #include "planner/plannodes/abstract_plan_node.h" +#include "planner/plannodes/output_schema.h" +#include "traffic_cop/traffic_cop_util.h" namespace noisepage::execution::compiler::udf { @@ -400,14 +393,13 @@ void UdfCodegen::Visit(ast::udf::ForSStmtAST *ast) { const std::string dummy_query{}; auto exec_query = execution::compiler::CompilationContext::Compile( *plan, exec_settings, accessor_, execution::compiler::CompilationMode::OneShot, std::nullopt, - common::ManagedPointer{}, common::ManagedPointer{&dummy_query}, - lambda_expr, codegen_->GetAstContext()); + common::ManagedPointer{}, lambda_expr, codegen_->GetAstContext()); // Append all of the declarations from the compiled query + auto decls = exec_query->GetDecls(); + aux_decls_.insert(aux_decls_.end(), decls.cbegin(), decls.cend()); - aux_decls_.insert(aux_decls_.end(), exec_query->GetDecls().cbegin(), exec_query->GetDecls().cend()); - - // Add the closure and query state to the current function + // Declare the closure and the query state in the current function auto query_state = codegen_->MakeFreshIdentifier("query_state"); fb_->Append(codegen_->DeclareVar( lambda_identifier, codegen_->LambdaType(lambda_expr->GetFunctionLiteralExpr()->TypeRepr()), lambda_expr)); @@ -418,13 +410,9 @@ void UdfCodegen::Visit(ast::udf::ForSStmtAST *ast) { // Derive the columns and parameter names from the query binder::BindNodeVisitor visitor{common::ManagedPointer{accessor_}, db_oid_}; - auto query_params = - visitor.BindAndGetUDFParams(common::ManagedPointer{ast->Query()}, common::ManagedPointer{udf_ast_context_}); - for (const auto &column_name : ColumnsSortedByIndex(query_params)) { - const type::TypeId type = GetVariableType(column_name); - const ast::Builtin builtin = AddParamBuiltinForParameterType(type); - fb_->Append(codegen_->CallBuiltin(builtin, {exec_ctx, codegen_->MakeExpr(SymbolTable().at(column_name))})); - } + const auto variable_refs = + visitor.BindAndGetUDFVariableRefs(common::ManagedPointer{ast->Query()}, common::ManagedPointer{udf_ast_context_}); + CodegenAddParameters(exec_ctx, variable_refs); fb_->Append(codegen_->Assign( codegen_->AccessStructMember(codegen_->MakeExpr(query_state), codegen_->MakeIdentifier("execCtx")), exec_ctx)); @@ -458,112 +446,44 @@ void UdfCodegen::Visit(ast::udf::SQLStmtAST *ast) { auto optimize_result = OptimizeEmbeddedQuery(ast->Query()); auto plan = optimize_result->GetPlanNode(); - // Make a lambda that just writes into this - - // Populate the parameters for the lambda - execution::util::RegionVector lambda_parameters{codegen_->GetAstContext()->GetRegion()}; - - // The first parameter is always the execution context - lambda_parameters.push_back(codegen_->MakeField( - exec_ctx->As()->Name(), - codegen_->PointerType(codegen_->BuiltinType(execution::ast::BuiltinType::Kind::ExecutionContext)))); - - // Derive the remainder of the closure's signature from - // the output schema of the associated query - std::size_t i = 0; - std::vector assignees{}; - execution::util::RegionVector captures{codegen_->GetAstContext()->GetRegion()}; - for (const auto &col : plan->GetOutputSchema()->GetColumns()) { - execution::ast::Expr *capture_var = codegen_->MakeExpr(SymbolTable().find(ast->Name())->second); - if (GetVariableType(ast->Name()) == type::TypeId::INVALID) { - // Record type - const auto struct_vars = GetRecordType(ast->Name()); - if (captures.empty()) { - captures.push_back(capture_var); - } - capture_var = codegen_->AccessStructMember(capture_var, codegen_->MakeIdentifier(struct_vars[i].first)); - assignees.push_back(capture_var); - } else { - assignees.push_back(capture_var); - captures.push_back(capture_var); - } - lambda_parameters.push_back(codegen_->MakeField(codegen_->MakeFreshIdentifier("input"), - codegen_->TplType(execution::sql::GetTypeId(col.GetType())))); - i++; - } - - FunctionBuilder fn{codegen_, std::move(lambda_parameters), codegen_->BuiltinType(execution::ast::BuiltinType::Nil)}; - for (std::size_t j = 0UL; j < assignees.size(); ++j) { - auto capture_var = assignees[j]; - auto input_param = fn.GetParameterByPosition(j + 1); - fn.Append(codegen_->Assign(capture_var, input_param)); - } - + // Construct a lambda that writes the output of the query + // into the bound variables, as defined by the function body + ast::LambdaExpr *lambda_expr = MakeLambda(plan, ast->Variables()); const ast::Identifier lambda_identifier = codegen_->MakeFreshIdentifier("udfLambda"); - ast::LambdaExpr *lambda_expr = fn.FinishLambda(std::move(captures)); lambda_expr->SetName(lambda_identifier); - // We want to pass something down that will materialize the lambda function - // into lambda_expr and will also feed in a lambda_expr to the compiler + // Generate code for the embedded query, utilizing the generated closure as the output callback execution::exec::ExecutionSettings exec_settings{}; - const std::string dummy_query{}; auto exec_query = execution::compiler::CompilationContext::Compile( *plan, exec_settings, accessor_, execution::compiler::CompilationMode::OneShot, std::nullopt, - common::ManagedPointer{}, common::ManagedPointer(&dummy_query), - lambda_expr, codegen_->GetAstContext()); + common::ManagedPointer{}, lambda_expr, codegen_->GetAstContext()); + // Append all declarations from the compiled query auto decls = exec_query->GetDecls(); - aux_decls_.insert(aux_decls_.end(), decls.begin(), decls.end()); + aux_decls_.insert(aux_decls_.end(), decls.cbegin(), decls.cend()); + // Declare the closure and the query state in the current function + auto query_state = codegen_->MakeFreshIdentifier("query_state"); fb_->Append(codegen_->DeclareVar( lambda_identifier, codegen_->LambdaType(lambda_expr->GetFunctionLiteralExpr()->TypeRepr()), lambda_expr)); - - // Make query state - auto query_state = codegen_->MakeFreshIdentifier("query_state"); fb_->Append(codegen_->DeclareVarNoInit(query_state, codegen_->MakeExpr(exec_query->GetQueryStateType()->Name()))); // Set its execution context to whatever execution context was passed in here fb_->Append(codegen_->CallBuiltin(execution::ast::Builtin::StartNewParams, {exec_ctx})); - const std::vector columns = ColumnsSortedByIndex(ast->Parameters()); - const std::vector parameters = ParametersSortedByIndex(ast->Parameters()); - for (std::size_t i = 0; i < columns.size(); ++i) { - const auto &column = columns[i]; - const auto ¶meter = parameters[i]; - - // TODO(Kyle): This IILE is cool and all... but way more - // complex than I would like, all of the logic in this - // function deserves a second look to refactor - auto [type, expr] = [=, &column, ¶meter]() { - if (parameter.length() > 0) { - const auto fields = GetRecordType(parameter); - auto it = std::find_if(fields.cbegin(), fields.cend(), [=](auto p) { return p.first == column; }); - NOISEPAGE_ASSERT(it != fields.cend(), "Broken invariant"); - return std::pair{ - it->second, codegen_->AccessStructMember(codegen_->MakeExpr(SymbolTable().at(parameter)), - codegen_->MakeIdentifier(column))}; - } - const type::TypeId type = GetVariableType(column); - return std::pair{type, codegen_->MakeExpr(SymbolTable().at(column))}; - }(); - - fb_->Append(codegen_->CallBuiltin(AddParamBuiltinForParameterType(type), {exec_ctx, expr})); - } + // Determine the column references in the query (if any) + // that depend on variables in the UDF definition + binder::BindNodeVisitor visitor{common::ManagedPointer{accessor_}, db_oid_}; + const auto variable_refs = + visitor.BindAndGetUDFVariableRefs(common::ManagedPointer{ast->Query()}, common::ManagedPointer{udf_ast_context_}); + CodegenAddParameters(exec_ctx, variable_refs); // Load the execution context member of the query state fb_->Append(codegen_->Assign( codegen_->AccessStructMember(codegen_->MakeExpr(query_state), codegen_->MakeIdentifier("execCtx")), exec_ctx)); - // Generate code to assign to the closure captures - // from the output of the embedded query - const std::size_t n_columns = plan->GetOutputSchema()->GetColumns().size(); - for (const auto &col : plan->GetOutputSchema()->GetColumns()) { - execution::ast::Expr *capture_var = codegen_->MakeExpr(SymbolTable().find(ast->Name())->second); - execution::ast::Expr *lhs = (n_columns > 1) - ? codegen_->AccessStructMember(capture_var, codegen_->MakeIdentifier(col.GetName())) - : capture_var; - fb_->Append(codegen_->Assign(lhs, codegen_->ConstNull(col.GetType()))); - } + // Initialize the captures + CodegenCaptureAssignments(plan, ast->Variables()); // Manually append calls to each function from the compiled // executable query (implementing the closure) to the builder @@ -587,6 +507,209 @@ void UdfCodegen::Visit(ast::udf::MemberExprAST *ast) { dst_ = codegen_->AccessStructMember(object, codegen_->MakeIdentifier(ast->FieldName())); } +/* ---------------------------------------------------------------------------- + Code Generation Helpers +---------------------------------------------------------------------------- */ + +ast::LambdaExpr *UdfCodegen::MakeLambda(common::ManagedPointer plan, + const std::vector &variables) { + return GetVariableType(variables.front()) == type::TypeId::INVALID ? MakeLambdaBindingToRecord(plan, variables) + : MakeLambdaBindingToNonRecord(plan, variables); +} + +ast::LambdaExpr *UdfCodegen::MakeLambdaBindingToRecord(common::ManagedPointer plan, + const std::vector &variables) { + // bind results to a single RECORD variable + NOISEPAGE_ASSERT(variables.size() == 1, "Broken invariant"); + + const std::string &record_name = variables.front(); + const auto record_type = GetRecordType(record_name); + + const auto n_fields = record_type.size(); + const auto n_columns = plan->GetOutputSchema()->GetColumns().size(); + if (n_fields != n_columns) { + throw EXECUTION_EXCEPTION( + fmt::format("Attempt to bind {} query outputs to record type with {} fields", n_columns, n_fields), + common::ErrorCode::ERRCODE_PLPGSQL_ERROR); + } + + // The lambda accepts all columns of the query output schema as parameters + util::RegionVector parameters{codegen_->GetAstContext()->GetRegion()}; + + // The first parameter is always the execution context + ast::Expr *exec_ctx = fb_->GetParameterByPosition(0); + parameters.push_back(codegen_->MakeField( + exec_ctx->As()->Name(), + codegen_->PointerType(codegen_->BuiltinType(execution::ast::BuiltinType::Kind::ExecutionContext)))); + + // The lambda only captures the RECORD variable to which all results are bound + ast::Expr *capture = codegen_->MakeExpr(SymbolTable().find(record_name)->second); + util::RegionVector captures{codegen_->GetAstContext()->GetRegion()}; + + // While the closure only captures a single variable, we still need + // to generate code for an assignment to each field memeber + std::vector assignees{}; + assignees.reserve(n_columns); + + for (std::size_t i = 0; i < n_columns; ++i) { + const auto &column = plan->GetOutputSchema()->GetColumn(i); + assignees.push_back(codegen_->AccessStructMember(capture, codegen_->MakeIdentifier(record_type[i].first))); + parameters.push_back(codegen_->MakeField(codegen_->MakeFreshIdentifier("input"), + codegen_->TplType(sql::GetTypeId(column.GetType())))); + } + + FunctionBuilder builder{codegen_, std::move(parameters), codegen_->BuiltinType(execution::ast::BuiltinType::Nil)}; + for (std::size_t i = 0UL; i < assignees.size(); ++i) { + auto *assignee = assignees.at(i); + auto input_parameter = builder.GetParameterByPosition(i + 1); + builder.Append(codegen_->Assign(assignee, input_parameter)); + } + + return builder.FinishLambda(std::move(captures)); +} + +ast::LambdaExpr *UdfCodegen::MakeLambdaBindingToNonRecord(common::ManagedPointer plan, + const std::vector &variables) { + // bind results to one or more non-RECORD variables + const auto n_variables = variables.size(); + const auto n_columns = plan->GetOutputSchema()->GetColumns().size(); + if (n_variables != n_columns) { + throw EXECUTION_EXCEPTION(fmt::format("Attempt to bind {} query outputs to {} variables", n_columns, n_variables), + common::ErrorCode::ERRCODE_PLPGSQL_ERROR); + } + + // The lambda accepts all columns of the query output schema as parameters + util::RegionVector parameters{codegen_->GetAstContext()->GetRegion()}; + // The lambda captures the variables to which results are bound from the enclosing scope + util::RegionVector captures{codegen_->GetAstContext()->GetRegion()}; + + // The first parameter is always the execution context + ast::Expr *exec_ctx = fb_->GetParameterByPosition(0); + parameters.push_back(codegen_->MakeField( + exec_ctx->As()->Name(), + codegen_->PointerType(codegen_->BuiltinType(execution::ast::BuiltinType::Kind::ExecutionContext)))); + + // Populate the remainder of the parameters and captures + for (std::size_t i = 0; i < n_columns; ++i) { + const auto &variable = variables.at(i); + const auto &column = plan->GetOutputSchema()->GetColumn(i); + captures.push_back(codegen_->MakeExpr(SymbolTable().find(variable)->second)); + parameters.push_back(codegen_->MakeField(codegen_->MakeFreshIdentifier("input"), + codegen_->TplType(sql::GetTypeId(column.GetType())))); + } + + // Begin construction of the function that implements the closure + FunctionBuilder builder{codegen_, std::move(parameters), codegen_->BuiltinType(execution::ast::BuiltinType::Nil)}; + + // Generate an assignment from each input parameter to the associated capture + for (std::size_t i = 0UL; i < captures.size(); ++i) { + auto *capture = captures[i]; + auto input_parameter = builder.GetParameterByPosition(i + 1); + builder.Append(codegen_->Assign(capture, input_parameter)); + } + + return builder.FinishLambda(std::move(captures)); +} + +void UdfCodegen::CodegenAddParameters(ast::Expr *exec_ctx, const std::vector &variable_refs) { + for (const auto &variable_ref : variable_refs) { + if (variable_ref.IsScalar()) { + CodegenAddScalarParameter(exec_ctx, variable_ref); + } else { + CodegenAddTableParameter(exec_ctx, variable_ref); + } + } +} + +void UdfCodegen::CodegenAddScalarParameter(ast::Expr *exec_ctx, const parser::udf::VariableRef &variable_ref) { + NOISEPAGE_ASSERT(variable_ref.IsScalar(), "Broken invariant"); + const auto &name = variable_ref.ColumnName(); + const type::TypeId type = GetVariableType(name); + ast::Expr *expr = codegen_->MakeExpr(SymbolTable().at(name)); + fb_->Append(codegen_->CallBuiltin(AddParamBuiltinForParameterType(type), {exec_ctx, expr})); +} + +void UdfCodegen::CodegenAddTableParameter(ast::Expr *exec_ctx, const parser::udf::VariableRef &variable_ref) { + NOISEPAGE_ASSERT(!variable_ref.IsScalar(), "Broken invariant"); + + const auto &record_name = variable_ref.TableName(); + const auto &field_name = variable_ref.ColumnName(); + + const auto fields = GetRecordType(record_name); + auto it = std::find_if( + fields.cbegin(), fields.cend(), + [&field_name](const std::pair &field) -> bool { return field.first == field_name; }); + if (it == fields.cend()) { + throw EXECUTION_EXCEPTION(fmt::format("Field '{}' not found in record '{}'", field_name, record_name), + common::ErrorCode::ERRCODE_PLPGSQL_ERROR); + } + + const type::TypeId type = it->second; + ast::Expr *expr = codegen_->AccessStructMember(codegen_->MakeExpr(SymbolTable().at(record_name)), + codegen_->MakeIdentifier(field_name)); + fb_->Append(codegen_->CallBuiltin(AddParamBuiltinForParameterType(type), {exec_ctx, expr})); +} + +void UdfCodegen::CodegenCaptureAssignments(common::ManagedPointer plan, + const std::vector &bound_variables) { + if (bound_variables.empty()) { + // Nothing to do + return; + } + + if (GetVariableType(bound_variables.front()) == type::TypeId::INVALID) { + CodegenCaptureAssignmentToRecord(plan, bound_variables.front()); + } else { + CodegenCaptureAssignmentToScalars(plan, bound_variables); + } +} + +void UdfCodegen::CodegenCaptureAssignmentToScalars(common::ManagedPointer plan, + const std::vector &bound_variables) { + const auto n_columns = plan->GetOutputSchema()->GetColumns().size(); + const auto n_variables = bound_variables.size(); + if (n_columns != n_variables) { + throw EXECUTION_EXCEPTION( + fmt::format("Attempt to bind {} query results to {} scalar variables", n_columns, n_variables), + common::ErrorCode::ERRCODE_PLPGSQL_ERROR); + } + + for (std::size_t i = 0; i < n_columns; ++i) { + const auto &column = plan->GetOutputSchema()->GetColumn(i); + const auto &variable = bound_variables.at(i); + execution::ast::Expr *capture = codegen_->MakeExpr(SymbolTable().find(variable)->second); + fb_->Append(codegen_->Assign(capture, codegen_->ConstNull(column.GetType()))); + } +} + +void UdfCodegen::CodegenCaptureAssignmentToRecord(common::ManagedPointer plan, + const std::string &record_name) { + NOISEPAGE_ASSERT(GetVariableType(record_name) == type::TypeId::INVALID, "Broken invariant"); + const auto n_columns = plan->GetOutputSchema()->GetColumns().size(); + const auto fields = GetRecordType(record_name); + const auto n_fields = fields.size(); + if (n_columns != n_fields) { + // NOTE(Kyle): This should be impossible, the structure of the + // record type is derived from the output schema of the query + throw EXECUTION_EXCEPTION( + fmt::format("Attempt to bind {} query results to record with {} fields", n_columns, n_fields), + common::ErrorCode::ERRCODE_PLPGSQL_ERROR); + } + + ast::Expr *record = codegen_->MakeExpr(SymbolTable().find(record_name)->second); + for (std::size_t i = 0; i < n_columns; ++i) { + const auto &column = plan->GetOutputSchema()->GetColumn(i); + const auto &field = fields.at(i); + NOISEPAGE_ASSERT(column.GetName() == field.first, "Broken invariant"); + ast::Expr *capture = codegen_->AccessStructMember(record, codegen_->MakeIdentifier(field.first)); + fb_->Append(codegen_->Assign(capture, codegen_->ConstNull(column.GetType()))); + } +} + +/* ---------------------------------------------------------------------------- + General Utilities +---------------------------------------------------------------------------- */ + type::TypeId UdfCodegen::GetVariableType(const std::string &name) const { auto type = udf_ast_context_->GetVariableType(name); if (!type.has_value()) { diff --git a/src/include/binder/bind_node_visitor.h b/src/include/binder/bind_node_visitor.h index 1e5e90be04..bae79698a3 100644 --- a/src/include/binder/bind_node_visitor.h +++ b/src/include/binder/bind_node_visitor.h @@ -11,6 +11,7 @@ #include "execution/ast/udf/udf_ast_context.h" #include "parser/postgresparser.h" #include "parser/select_statement.h" +#include "parser/udf/variable_ref.h" #include "type/type_id.h" namespace noisepage { @@ -60,7 +61,7 @@ class BindNodeVisitor final : public SqlNodeVisitor { * Column Name -> (Parameter Name, Parameter Index) * @throws BinderException on failure to bind query */ - std::unordered_map> BindAndGetUDFParams( + std::vector BindAndGetUDFVariableRefs( common::ManagedPointer parse_result, common::ManagedPointer udf_ast_context); @@ -121,7 +122,7 @@ class BindNodeVisitor final : public SqlNodeVisitor { /** Context for UDF AST */ common::ManagedPointer udf_ast_context_{}; /** Parameters for UDF */ - std::unordered_map> udf_params_; + std::vector udf_variable_refs_; /** Catalog accessor */ const common::ManagedPointer catalog_accessor_; @@ -169,6 +170,30 @@ class BindNodeVisitor final : public SqlNodeVisitor { * for which binding is performed, `false` otherwise */ bool IsUDFVariable(const std::string &identifier) const; + + /** + * Determine if the given identifier names a variable + * reference that is already tracked. + * @param identifier The variable identifier + */ + bool HaveUDFVariableRef(const std::string &identifier) const; + + /** + * Add a UDF variable reference to the internal tracker. + * @param expr The expression + * @param table_name The name of the table associated with the reference + * @param column_name The name of the column associated with the reference + */ + void AddUDFVariableReference(common::ManagedPointer expr, + const std::string &table_name, const std::string &column_name); + + /** + * Add a UDF variable reference to the internal tracker. + * @param expr The expression + * @param column_name The name of the column associated with the reference + */ + void AddUDFVariableReference(common::ManagedPointer expr, + const std::string &column_name); }; } // namespace binder diff --git a/src/include/execution/ast/udf/udf_ast_nodes.h b/src/include/execution/ast/udf/udf_ast_nodes.h index 42c0b82d3d..811dfeaa68 100644 --- a/src/include/execution/ast/udf/udf_ast_nodes.h +++ b/src/include/execution/ast/udf/udf_ast_nodes.h @@ -631,12 +631,11 @@ class SQLStmtAST : public StmtAST { /** * Construct a new SQLStmtAST instance. * @param query The result of parsing the SQL query - * @param name The name of the variable to which results of the query are bound - * @param parameters The parameters to the query + * @param variables The collection of identifiers of variables + * to which results of the query are bound */ - SQLStmtAST(std::unique_ptr &&query, std::string name, - std::unordered_map> &¶meters) - : query_{std::move(query)}, name_{std::move(name)}, parameters_(std::move(parameters)) {} + SQLStmtAST(std::unique_ptr &&query, std::vector &&variables) + : query_{std::move(query)}, variables_{std::move(variables)} {} /** * AST visitor pattern. @@ -650,21 +649,15 @@ class SQLStmtAST : public StmtAST { /** @return The result of parsing the SQL query */ const parser::ParseResult *Query() const { return query_.get(); } - /** @return The variable name to which results are bound */ - const std::string &Name() const { return name_; } - - /** @return The parameters to the query */ - const std::unordered_map> &Parameters() const { return parameters_; } + /** @return The variable names to which results are bound */ + const std::vector &Variables() const { return variables_; } private: /** The result of parsing the SQL query */ std::unique_ptr query_; - /** The variable name to which results of the query are bound */ - std::string name_; - - /** The parameters to the query */ - std::unordered_map> parameters_; + /** The names of the variables to which results are bound */ + std::vector variables_; }; /** diff --git a/src/include/execution/compiler/compilation_context.h b/src/include/execution/compiler/compilation_context.h index de9b1769fa..26618353e5 100644 --- a/src/include/execution/compiler/compilation_context.h +++ b/src/include/execution/compiler/compilation_context.h @@ -64,8 +64,7 @@ class CompilationContext { catalog::CatalogAccessor *accessor, CompilationMode mode = CompilationMode::Interleaved, std::optional override_qid = std::nullopt, common::ManagedPointer plan_meta_data = nullptr, - common::ManagedPointer query_text = nullptr, ast::LambdaExpr *output_callback = nullptr, - common::ManagedPointer context = nullptr); + ast::LambdaExpr *output_callback = nullptr, common::ManagedPointer context = nullptr); /** * Register a pipeline in this context. diff --git a/src/include/execution/compiler/udf/udf_codegen.h b/src/include/execution/compiler/udf/udf_codegen.h index 2d679987e4..54e9e901af 100644 --- a/src/include/execution/compiler/udf/udf_codegen.h +++ b/src/include/execution/compiler/udf/udf_codegen.h @@ -12,6 +12,7 @@ #include "execution/compiler/codegen.h" #include "execution/compiler/function_builder.h" #include "execution/functions/function_context.h" +#include "planner/plannodes/abstract_join_plan_node.h" namespace noisepage::catalog { class CatalogAccessor; @@ -21,6 +22,10 @@ namespace noisepage::optimizer { class OptimizeResult; } // namespace noisepage::optimizer +namespace noisepage::parser::udf { +class VariableRef; +} // namespace noisepage::parser::udf + namespace noisepage::execution { // Forward declarations @@ -224,6 +229,81 @@ class UdfCodegen : ast::udf::ASTNodeVisitor { static const char *GetReturnParamString(); private: + /** + * Construct a lambda expression that writes the output of the query + * represented by `plan` into the variables identified by `variables`. + * @param plan The query plan + * @param variables The names of the variables to which results are bound + * @return The lambda expression + */ + ast::LambdaExpr *MakeLambda(common::ManagedPointer plan, + const std::vector &variables); + + /** + * Construct a lambda expression that writes the output of the query + * represented by `plan` into a single RECORD-type variable. + * @param plan The query plan + * @param variables The names of the variables to which results are bound + * @return The lambda expression + */ + ast::LambdaExpr *MakeLambdaBindingToRecord(common::ManagedPointer plan, + const std::vector &variables); + + /** + * Construct a lambda expression that writes the output of the query + * represented by `plan` into one or more non-RECORD variables. + * @param plan The query plan + * @param variables The names of the variables to which results are bound + * @return The lambda expression + */ + ast::LambdaExpr *MakeLambdaBindingToNonRecord(common::ManagedPointer plan, + const std::vector &variables); + + /** + * Generate code to add query parameters to the execution context. + * @param exec_ctx The execution context expression + * @param variable_refs The collection of variable references + */ + void CodegenAddParameters(ast::Expr *exec_ctx, const std::vector &variable_refs); + + /** + * Generate code to add a scalar parameter to the execution context. + * @param exec_ctx The execution context + * @param variable_ref The variable reference + */ + void CodegenAddScalarParameter(ast::Expr *exec_ctx, const parser::udf::VariableRef &variable_ref); + + /** + * Generate code to add a non-scalar parameter to the execution context. + * @param exec_ctx The execution context + * @param variable_ref The variable reference + */ + void CodegenAddTableParameter(ast::Expr *exec_ctx, const parser::udf::VariableRef &variable_ref); + + /** + * Generate code to perform assignment to captured variables. + * @param plan The query plan + * @param bound_variables The variables to which results of the query are bound + */ + void CodegenCaptureAssignments(common::ManagedPointer plan, + const std::vector &bound_variables); + + /** + * Generate code to perform assignment to captured variables. + * @param plan The query plan + * @param bound_variables The name(s) of the scalar variables to which results of the query are bound + */ + void CodegenCaptureAssignmentToScalars(common::ManagedPointer plan, + const std::vector &bound_variables); + + /** + * Generate code to perform assignment to captured variables. + * @param plan The query plan + * @param record_name The name of the record variable to which results of the query are bound + */ + void CodegenCaptureAssignmentToRecord(common::ManagedPointer plan, + const std::string &record_name); + /** * Translate a SQL type to its corresponding catalog type. * @param type The SQL type of interest diff --git a/src/include/parser/udf/plpgsql_parser.h b/src/include/parser/udf/plpgsql_parser.h index bf9d660742..ae3cd5b8d2 100644 --- a/src/include/parser/udf/plpgsql_parser.h +++ b/src/include/parser/udf/plpgsql_parser.h @@ -130,6 +130,21 @@ class PLpgSQLParser { common::ManagedPointer expr); private: + /** + * Determine if all variables in `names` are declared in the function. + * @param names The collection of variable identifiers + * @return `true` if all variables are declared, `false` otherwise + */ + bool AllVariablesDeclared(const std::vector &names) const; + + /** + * Determine if any of the variables in `names` refer to a RECORD type. + * @param names The collection of variable identifiers + * @return `true` if any of the variables in `names` refer + * to a RECORD type previously declared, `false` otherwise + */ + bool ContainsRecordType(const std::vector &names) const; + /** * Resolve a PL/pgSQL RECORD type from a SELECT statement. * @param parse_result The result of parsing the SQL query diff --git a/src/include/parser/udf/variable_ref.h b/src/include/parser/udf/variable_ref.h new file mode 100644 index 0000000000..bcde6b1451 --- /dev/null +++ b/src/include/parser/udf/variable_ref.h @@ -0,0 +1,70 @@ +#pragma once + +#include +#include + +#include "common/macros.h" + +namespace noisepage::parser::udf { + +/** + * The VariableRefType enumeration defines the + * valid types of variable references. + */ +enum class VariableRefType { SCALAR, TABLE }; + +/** + * The VariableRef type represents a UDF variable reference + * within a SQL query. It is used during binding to identify + * and track the query parameters that must be read from the + * UDF environment prior to query execution. + */ +class VariableRef { + public: + /** + * Construct a new VariableRef instance for a TABLE reference. + * @param table_name The name of the table + * @param column_name The name of the column + * @param index The index + */ + VariableRef(std::string table_name, std::string column_name, std::size_t index) + : type_{VariableRefType::TABLE}, + table_name_{std::move(table_name)}, + column_name_{std::move(column_name)}, + index_{index} {} + + /** + * Construct a new VariableRef instance for a SCALAR reference. + * @param column_name The name of the column + * @param index The index + */ + VariableRef(std::string column_name, std::size_t index) + : type_{VariableRefType::SCALAR}, table_name_{}, column_name_{std::move(column_name)}, index_{index} {} + + /** @return `true` if this is a SCALAR variable reference, `false` otherwise */ + bool IsScalar() const { return type_ == VariableRefType::SCALAR; } + + /** @return The table name of the variable reference */ + const std::string &TableName() const { + NOISEPAGE_ASSERT(!IsScalar(), "SCALAR variable references do not have associated table names"); + return table_name_; + } + + /** @return The column name of the variable reference */ + const std::string &ColumnName() const { return column_name_; } + + /** @return The index of the variable reference */ + std::size_t Index() const { return index_; } + + private: + /** The type of this variable reference */ + const VariableRefType type_; + /** The table name associated with this variable reference (may be empty) */ + const std::string table_name_; + /** The column name associated with this variable reference */ + const std::string column_name_; + /** The index of the reference in the query */ + const std::size_t index_; +}; + +} // namespace noisepage::parser::udf diff --git a/src/parser/udf/plpgsql_parser.cpp b/src/parser/udf/plpgsql_parser.cpp index 100c894a3b..484b0bbe47 100644 --- a/src/parser/udf/plpgsql_parser.cpp +++ b/src/parser/udf/plpgsql_parser.cpp @@ -249,11 +249,16 @@ std::unique_ptr PLpgSQLParser::ParseForS(const nlo return nullptr; } auto body_stmt = ParseBlock(json[K_BODY]); - auto var_array = json[K_ROW][K_PLPGSQL_ROW][K_FIELDS]; + auto variable_array = json[K_ROW][K_PLPGSQL_ROW][K_FIELDS]; std::vector variables{}; - variables.reserve(var_array.size()); - std::transform(var_array.cbegin(), var_array.cend(), std::back_inserter(variables), + variables.reserve(variable_array.size()); + std::transform(variable_array.cbegin(), variable_array.cend(), std::back_inserter(variables), [](const nlohmann::json &var) { return var[K_NAME].get(); }); + + if (!AllVariablesDeclared(variables)) { + throw PARSER_EXCEPTION("PLpgSQL parser : variable was not declared"); + } + return std::make_unique(std::move(variables), std::move(parse_result), std::move(body_stmt)); } @@ -261,33 +266,37 @@ std::unique_ptr PLpgSQLParser::ParseForS(const nlo std::unique_ptr PLpgSQLParser::ParseSQL(const nlohmann::json &json) { // The query text const auto sql_query = json[K_SQLSTMT][K_PLPGSQL_EXPR][K_QUERY].get(); - // The variable name (non-const for later std::move) - auto var_name = json[K_ROW][K_PLPGSQL_ROW][K_FIELDS][0][K_NAME].get(); - auto parse_result = PostgresParser::BuildParseTree(sql_query); if (parse_result == nullptr) { return nullptr; } - // Bind the query within the UDF body; if binding - // fails, we allow the BinderException to propogate - binder::BindNodeVisitor visitor{accessor_, db_oid_}; - auto query_params = visitor.BindAndGetUDFParams(common::ManagedPointer{parse_result}, udf_ast_context_); + auto variable_array = json[K_ROW][K_PLPGSQL_ROW][K_FIELDS]; + std::vector variables{}; + variables.reserve(variable_array.size()); + std::transform(variable_array.cbegin(), variable_array.cend(), std::back_inserter(variables), + [](const nlohmann::json &var) -> std::string { return var[K_NAME].get(); }); - // Check to see if a record type can be bound to this - const auto type = udf_ast_context_->GetVariableType(var_name); - if (!type.has_value()) { + // Ensure all variables to which results are bound are declared + if (!AllVariablesDeclared(variables)) { throw PARSER_EXCEPTION("PL/pgSQL parser : variable was not declared"); } - if (type.value() == type::TypeId::INVALID) { - // If the type is a RECORD type, derive the structure of - // the type from the columns of the SELECT statement - udf_ast_context_->SetRecordType(var_name, ResolveRecordType(parse_result.get())); + // Two possibilities for binding of results: + // - Exactly one RECORD variable + // - One or more non-RECORD variables + + if (ContainsRecordType(variables)) { + if (variables.size() > 1) { + throw PARSER_EXCEPTION("Binding of query results is ambiguous"); + } + // There is only a single result variable and it is a RECORD; + // derive the structure of the RECORD from the SELECT columns + const auto &name = variables.front(); + udf_ast_context_->SetRecordType(name, ResolveRecordType(parse_result.get())); } - return std::make_unique(std::move(parse_result), std::move(var_name), - std::move(query_params)); + return std::make_unique(std::move(parse_result), std::move(variables)); } std::unique_ptr PLpgSQLParser::ParseDynamicSQL(const nlohmann::json &json) { @@ -352,6 +361,17 @@ std::unique_ptr PLpgSQLParser::ParseExprFromAbstra } } +bool PLpgSQLParser::AllVariablesDeclared(const std::vector &names) const { + return std::all_of(names.cbegin(), names.cend(), + [this](const std::string &name) -> bool { return udf_ast_context_->HasVariable(name); }); +} + +bool PLpgSQLParser::ContainsRecordType(const std::vector &names) const { + return std::any_of(names.cbegin(), names.cend(), [this](const std::string &name) -> bool { + return udf_ast_context_->GetVariableType(name) == type::TypeId::INVALID; + }); +} + std::vector> PLpgSQLParser::ResolveRecordType(const ParseResult *parse_result) { std::vector> fields{}; const auto &select_columns = From 24ce125b07d6bd2a8d8fcd87888ea78f95ca947c Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Wed, 28 Jul 2021 17:13:39 -0400 Subject: [PATCH 086/139] integration tests for binding multiple query results to udf variables --- script/testing/junit/sql/udf.sql | 17 +++++++++++++++-- script/testing/junit/traces/udf.test | 24 ++++++++++++++++++++++++ 2 files changed, 39 insertions(+), 2 deletions(-) diff --git a/script/testing/junit/sql/udf.sql b/script/testing/junit/sql/udf.sql index 32694dbcbc..233f09ddc2 100644 --- a/script/testing/junit/sql/udf.sql +++ b/script/testing/junit/sql/udf.sql @@ -136,6 +136,21 @@ $$ LANGUAGE PLPGSQL; SELECT sql_select_single_constant(); +-- ---------------------------------------------------------------------------- +-- sql_select_mutliple_constants() + +CREATE FUNCTION sql_select_mutliple_constants() RETURNS INT AS $$ \ +DECLARE \ + x INT; \ + y INT; \ +BEGIN \ + SELECT 1, 2 INTO x, y; \ + RETURN x + y; \ +END \ +$$ LANGUAGE PLPGSQL; + +SELECT sql_select_mutliple_constants(); + -- ---------------------------------------------------------------------------- -- proc_fors() -- @@ -173,5 +188,3 @@ SELECT sql_select_single_constant(); -- $$ LANGUAGE PLPGSQL; -- SELECT x, proc_fors_var() FROM integers; - -CREATE FUNCTION fun() RETURNS INT AS $$ DECLARE x INT; y INT; BEGIN SELECT 1, 2 INTO x, y; RETURN x + y; END $$ LANGUAGE PLPGSQL; diff --git a/script/testing/junit/traces/udf.test b/script/testing/junit/traces/udf.test index 722b4e59bd..45b6934175 100644 --- a/script/testing/junit/traces/udf.test +++ b/script/testing/junit/traces/udf.test @@ -294,6 +294,30 @@ SELECT sql_select_single_constant(); statement ok +statement ok +-- ---------------------------------------------------------------------------- + +statement ok +-- sql_select_mutliple_constants() + +statement ok + + +statement ok +CREATE FUNCTION sql_select_mutliple_constants() RETURNS INT AS $$ DECLARE x INT; y INT; BEGIN SELECT 1, 2 INTO x, y; RETURN x + y; END $$ LANGUAGE PLPGSQL; + +statement ok + + +query I rowsort +SELECT sql_select_mutliple_constants(); +---- +3 + + +statement ok + + statement ok -- ---------------------------------------------------------------------------- From 964430df9d852af7e7d654a5129840292f936d70 Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Wed, 28 Jul 2021 22:11:59 -0400 Subject: [PATCH 087/139] refactor udf code generation for embedded queries to pull out logic common to both regular queries and for loops --- script/testing/util/db_server.py | 2 +- src/execution/compiler/function_builder.cpp | 58 +++++---- src/execution/compiler/udf/udf_codegen.cpp | 120 +++++++----------- .../execution/compiler/function_builder.h | 92 +++++++------- .../execution/compiler/udf/udf_codegen.h | 18 +-- 5 files changed, 135 insertions(+), 155 deletions(-) diff --git a/script/testing/util/db_server.py b/script/testing/util/db_server.py index 0e3809c9f9..c77b8547a1 100644 --- a/script/testing/util/db_server.py +++ b/script/testing/util/db_server.py @@ -164,7 +164,7 @@ def stop_db(self, is_dry_run=False): finally: unix_socket = os.path.join("/tmp/", f".s.PGSQL.{self.db_port}") if os.path.exists(unix_socket): - os.remove(unix_socket) + # os.remove(unix_socket) LOG.info(f"Removing: {unix_socket}") self.print_db_logs() exit_code = self.db_process.returncode diff --git a/src/execution/compiler/function_builder.cpp b/src/execution/compiler/function_builder.cpp index 5f22d71382..c96613a26b 100644 --- a/src/execution/compiler/function_builder.cpp +++ b/src/execution/compiler/function_builder.cpp @@ -1,35 +1,36 @@ #include "execution/compiler/function_builder.h" #include "execution/ast/ast_node_factory.h" +#include "execution/ast/context.h" #include "execution/compiler/codegen.h" namespace noisepage::execution::compiler { -// TODO(Kyle): We should refactor this two 2 distinct types: -// the regular old FunctionBuilder and a ClosureBuilder - FunctionBuilder::FunctionBuilder(CodeGen *codegen, ast::Identifier name, util::RegionVector &¶ms, ast::Expr *ret_type) - : codegen_{codegen}, + : type_{FunctionType::FUNCTION}, + codegen_{codegen}, name_{name}, params_{std::move(params)}, - ret_type_{ret_type}, + captures_{codegen_->GetAstContext()->GetRegion()}, + return_type_{ret_type}, start_{codegen->GetPosition()}, statements_{codegen->MakeEmptyBlock()}, - is_lambda_{false}, decl_{std::in_place_type, nullptr} {} -FunctionBuilder::FunctionBuilder(CodeGen *codegen, util::RegionVector &¶ms, ast::Expr *ret_type) - : codegen_{codegen}, +FunctionBuilder::FunctionBuilder(CodeGen *codegen, util::RegionVector &¶ms, + util::RegionVector &&captures, ast::Expr *return_type) + : type_{FunctionType::CLOSURE}, + codegen_{codegen}, params_{std::move(params)}, - ret_type_{ret_type}, + captures_{std::move(captures)}, + return_type_{return_type}, start_{codegen->GetPosition()}, statements_{codegen->MakeEmptyBlock()}, - is_lambda_{true}, decl_{std::in_place_type, nullptr} {} FunctionBuilder::~FunctionBuilder() { - if (!IsLambda()) { + if (type_ == FunctionType::FUNCTION) { Finish(); } } @@ -53,7 +54,8 @@ void FunctionBuilder::Append(ast::Expr *expr) { Append(codegen_->GetFactory()->N void FunctionBuilder::Append(ast::VariableDecl *decl) { Append(codegen_->GetFactory()->NewDeclStmt(decl)); } ast::FunctionDecl *FunctionBuilder::Finish(ast::Expr *ret) { - NOISEPAGE_ASSERT(!is_lambda_, "Attempt to call Finish() on a FunctionDecl that is a lambda"); + NOISEPAGE_ASSERT(type_ == FunctionType::FUNCTION, + "Attempt to call FunctionBuilder::Finish on non-function-type builder"); NOISEPAGE_ASSERT(std::holds_alternative(decl_), "Broken invariant"); auto *declaration = std::get(decl_); if (declaration != nullptr) { @@ -65,26 +67,26 @@ ast::FunctionDecl *FunctionBuilder::Finish(ast::Expr *ret) { "with an explicit return expression, or use the factory to manually append a return " "statement and call FunctionBuilder::Finish() with a null return."); - // Add the return. + // Add the return if (!statements_->IsEmpty() && !statements_->GetLast()->IsReturnStmt()) { Append(codegen_->GetFactory()->NewReturnStmt(codegen_->GetPosition(), ret)); } - // Finalize everything. + // Finalize everything statements_->SetRightBracePosition(codegen_->GetPosition()); - // Build the function's type. - auto func_type = codegen_->GetFactory()->NewFunctionType(start_, std::move(params_), ret_type_); + // Build the function's type + auto func_type = codegen_->GetFactory()->NewFunctionType(start_, std::move(params_), return_type_); - // Create the declaration. + // Create the declaration auto func_lit = codegen_->GetFactory()->NewFunctionLitExpr(func_type, statements_); decl_ = codegen_->GetFactory()->NewFunctionDecl(start_, name_, func_lit); return std::get(decl_); } -noisepage::execution::ast::LambdaExpr *FunctionBuilder::FinishLambda(util::RegionVector &&captures, - ast::Expr *ret) { - NOISEPAGE_ASSERT(is_lambda_, "Attempt to call FinishLambda() on a FunctionDecl that is not a lambda"); +noisepage::execution::ast::LambdaExpr *FunctionBuilder::FinishClosure(ast::Expr *ret) { + NOISEPAGE_ASSERT(type_ == FunctionType::CLOSURE, + "Attempt to call FuncionBuilder::FinishClosure on non-closure-type builder"); NOISEPAGE_ASSERT(std::holds_alternative(decl_), "Broken invariant"); auto *declaration = std::get(decl_); if (declaration != nullptr) { @@ -92,21 +94,21 @@ noisepage::execution::ast::LambdaExpr *FunctionBuilder::FinishLambda(util::Regio } NOISEPAGE_ASSERT(ret == nullptr || statements_->IsEmpty() || !statements_->GetLast()->IsReturnStmt(), - "Double-return at end of function. You should either call FunctionBuilder::Finish() " + "Double-return at end of function. You should either call FunctionBuilder::FinishClosure() " "with an explicit return expression, or use the factory to manually append a return " - "statement and call FunctionBuilder::Finish() with a null return."); - // Add the return. + "statement and call FunctionBuilder::FinishClosure() with a null return."); + // Add the return if (!statements_->IsEmpty() && !statements_->GetLast()->IsReturnStmt()) { Append(codegen_->GetFactory()->NewReturnStmt(codegen_->GetPosition(), ret)); } - // Finalize everything. + // Finalize everything statements_->SetRightBracePosition(codegen_->GetPosition()); - // Build the function's type. - auto func_type = codegen_->GetFactory()->NewFunctionType(start_, std::move(params_), ret_type_); + // Build the function's type + auto func_type = codegen_->GetFactory()->NewFunctionType(start_, std::move(params_), return_type_); - // Create the declaration. + // Create the declaration auto func_lit = codegen_->GetFactory()->NewFunctionLitExpr(func_type, statements_); - decl_ = codegen_->GetFactory()->NewLambdaExpr(start_, func_lit, std::move(captures)); + decl_ = codegen_->GetFactory()->NewLambdaExpr(start_, func_lit, std::move(captures_)); return std::get(decl_); } diff --git a/src/execution/compiler/udf/udf_codegen.cpp b/src/execution/compiler/udf/udf_codegen.cpp index aff0c2e58d..beaa307c3d 100644 --- a/src/execution/compiler/udf/udf_codegen.cpp +++ b/src/execution/compiler/udf/udf_codegen.cpp @@ -342,49 +342,21 @@ void UdfCodegen::Visit(ast::udf::ForSStmtAST *ast) { common::ErrorCode::ERRCODE_PLPGSQL_ERROR); } - // Construct a lambda that writes the output of the query - // into the identifiers within the UDF bound to the output + // Start construction of the lambda expression + auto builder = StartLambda(plan, ast->Variables()); - std::vector variable_identifiers{}; - execution::util::RegionVector params{codegen_->GetAstContext()->GetRegion()}; - params.push_back(codegen_->MakeField( - exec_ctx->As()->Name(), - codegen_->PointerType(codegen_->BuiltinType(execution::ast::BuiltinType::Kind::ExecutionContext)))); - std::size_t i = 0; - for (const auto &variable_name : ast->Variables()) { - const ast::Identifier variable_identifier = SymbolTable().find(variable_name)->second; - variable_identifiers.push_back(variable_identifier); - ast::Expr *type = codegen_->TplType(execution::sql::GetTypeId(plan->GetOutputSchema()->GetColumn(i).GetType())); - fb_->Append(codegen_->Assign(codegen_->MakeExpr(variable_identifier), - codegen_->ConstNull(plan->GetOutputSchema()->GetColumn(i).GetType()))); - params.push_back(codegen_->MakeField(codegen_->MakeFreshIdentifier(variable_name), type)); - i++; - } - - FunctionBuilder fn{codegen_, std::move(params), codegen_->BuiltinType(execution::ast::BuiltinType::Nil)}; + // Generate code for closure capture assignment + CodegenCaptureAssignments(plan, ast->Variables()); + + // Generate code for the loop body { - std::size_t j = 1; - for (auto var : variable_identifiers) { - fn.Append(codegen_->Assign(codegen_->MakeExpr(var), fn.GetParameterByPosition(j++))); - } - auto prev_fb = fb_; - fb_ = &fn; + auto cached_builder = fb_; + fb_ = builder.get(); ast->Body()->Accept(this); - fb_ = prev_fb; + fb_ = cached_builder; } - // Define the captures for the closure - // TODO(Kyle): We are capturing every variable in the symbol table, - // this seems like overkill and may lead to incorrect semantics? - execution::util::RegionVector captures{codegen_->GetAstContext()->GetRegion()}; - for (const auto &[name, identifier] : SymbolTable()) { - if (name == "executionCtx") { - continue; - } - captures.push_back(codegen_->MakeExpr(identifier)); - } - - ast::LambdaExpr *lambda_expr = fn.FinishLambda(std::move(captures)); + ast::LambdaExpr *lambda_expr = builder->FinishClosure(); const ast::Identifier lambda_identifier = codegen_->MakeFreshIdentifier("udfLambda"); lambda_expr->SetName(lambda_identifier); @@ -448,14 +420,15 @@ void UdfCodegen::Visit(ast::udf::SQLStmtAST *ast) { // Construct a lambda that writes the output of the query // into the bound variables, as defined by the function body - ast::LambdaExpr *lambda_expr = MakeLambda(plan, ast->Variables()); + auto builder = StartLambda(plan, ast->Variables()); + ast::LambdaExpr *lambda_expr = builder->FinishClosure(); const ast::Identifier lambda_identifier = codegen_->MakeFreshIdentifier("udfLambda"); lambda_expr->SetName(lambda_identifier); // Generate code for the embedded query, utilizing the generated closure as the output callback - execution::exec::ExecutionSettings exec_settings{}; - auto exec_query = execution::compiler::CompilationContext::Compile( - *plan, exec_settings, accessor_, execution::compiler::CompilationMode::OneShot, std::nullopt, + exec::ExecutionSettings exec_settings{}; + auto exec_query = compiler::CompilationContext::Compile( + *plan, exec_settings, accessor_, compiler::CompilationMode::OneShot, std::nullopt, common::ManagedPointer{}, lambda_expr, codegen_->GetAstContext()); // Append all declarations from the compiled query @@ -469,7 +442,7 @@ void UdfCodegen::Visit(ast::udf::SQLStmtAST *ast) { fb_->Append(codegen_->DeclareVarNoInit(query_state, codegen_->MakeExpr(exec_query->GetQueryStateType()->Name()))); // Set its execution context to whatever execution context was passed in here - fb_->Append(codegen_->CallBuiltin(execution::ast::Builtin::StartNewParams, {exec_ctx})); + fb_->Append(codegen_->CallBuiltin(ast::Builtin::StartNewParams, {exec_ctx})); // Determine the column references in the query (if any) // that depend on variables in the UDF definition @@ -498,7 +471,7 @@ void UdfCodegen::Visit(ast::udf::SQLStmtAST *ast) { } } - fb_->Append(codegen_->CallBuiltin(execution::ast::Builtin::FinishNewParams, {exec_ctx})); + fb_->Append(codegen_->CallBuiltin(ast::Builtin::FinishNewParams, {exec_ctx})); } void UdfCodegen::Visit(ast::udf::MemberExprAST *ast) { @@ -511,14 +484,14 @@ void UdfCodegen::Visit(ast::udf::MemberExprAST *ast) { Code Generation Helpers ---------------------------------------------------------------------------- */ -ast::LambdaExpr *UdfCodegen::MakeLambda(common::ManagedPointer plan, - const std::vector &variables) { - return GetVariableType(variables.front()) == type::TypeId::INVALID ? MakeLambdaBindingToRecord(plan, variables) - : MakeLambdaBindingToNonRecord(plan, variables); +std::unique_ptr UdfCodegen::StartLambda(common::ManagedPointer plan, + const std::vector &variables) { + return GetVariableType(variables.front()) == type::TypeId::INVALID ? StartLambdaBindingToRecord(plan, variables) + : StartLambdaBindingToNonRecord(plan, variables); } -ast::LambdaExpr *UdfCodegen::MakeLambdaBindingToRecord(common::ManagedPointer plan, - const std::vector &variables) { +std::unique_ptr UdfCodegen::StartLambdaBindingToRecord( + common::ManagedPointer plan, const std::vector &variables) { // bind results to a single RECORD variable NOISEPAGE_ASSERT(variables.size() == 1, "Broken invariant"); @@ -534,17 +507,17 @@ ast::LambdaExpr *UdfCodegen::MakeLambdaBindingToRecord(common::ManagedPointer parameters{codegen_->GetAstContext()->GetRegion()}; + util::RegionVector parameters{codegen_->GetAstContext()->GetRegion()}; // The first parameter is always the execution context ast::Expr *exec_ctx = fb_->GetParameterByPosition(0); - parameters.push_back(codegen_->MakeField( - exec_ctx->As()->Name(), - codegen_->PointerType(codegen_->BuiltinType(execution::ast::BuiltinType::Kind::ExecutionContext)))); + parameters.push_back( + codegen_->MakeField(exec_ctx->As()->Name(), + codegen_->PointerType(codegen_->BuiltinType(ast::BuiltinType::Kind::ExecutionContext)))); // The lambda only captures the RECORD variable to which all results are bound ast::Expr *capture = codegen_->MakeExpr(SymbolTable().find(record_name)->second); - util::RegionVector captures{codegen_->GetAstContext()->GetRegion()}; + util::RegionVector captures{codegen_->GetAstContext()->GetRegion()}; // While the closure only captures a single variable, we still need // to generate code for an assignment to each field memeber @@ -558,18 +531,18 @@ ast::LambdaExpr *UdfCodegen::MakeLambdaBindingToRecord(common::ManagedPointerTplType(sql::GetTypeId(column.GetType())))); } - FunctionBuilder builder{codegen_, std::move(parameters), codegen_->BuiltinType(execution::ast::BuiltinType::Nil)}; + auto builder = std::make_unique(codegen_, std::move(parameters), std::move(captures), + codegen_->BuiltinType(ast::BuiltinType::Nil)); for (std::size_t i = 0UL; i < assignees.size(); ++i) { auto *assignee = assignees.at(i); - auto input_parameter = builder.GetParameterByPosition(i + 1); - builder.Append(codegen_->Assign(assignee, input_parameter)); + auto input_parameter = builder->GetParameterByPosition(i + 1); + builder->Append(codegen_->Assign(assignee, input_parameter)); } - - return builder.FinishLambda(std::move(captures)); + return builder; } -ast::LambdaExpr *UdfCodegen::MakeLambdaBindingToNonRecord(common::ManagedPointer plan, - const std::vector &variables) { +std::unique_ptr UdfCodegen::StartLambdaBindingToNonRecord( + common::ManagedPointer plan, const std::vector &variables) { // bind results to one or more non-RECORD variables const auto n_variables = variables.size(); const auto n_columns = plan->GetOutputSchema()->GetColumns().size(); @@ -579,14 +552,14 @@ ast::LambdaExpr *UdfCodegen::MakeLambdaBindingToNonRecord(common::ManagedPointer } // The lambda accepts all columns of the query output schema as parameters - util::RegionVector parameters{codegen_->GetAstContext()->GetRegion()}; + util::RegionVector parameters{codegen_->GetAstContext()->GetRegion()}; // The lambda captures the variables to which results are bound from the enclosing scope - util::RegionVector captures{codegen_->GetAstContext()->GetRegion()}; + util::RegionVector captures{codegen_->GetAstContext()->GetRegion()}; // The first parameter is always the execution context ast::Expr *exec_ctx = fb_->GetParameterByPosition(0); parameters.push_back(codegen_->MakeField( - exec_ctx->As()->Name(), + exec_ctx->As()->Name(), codegen_->PointerType(codegen_->BuiltinType(execution::ast::BuiltinType::Kind::ExecutionContext)))); // Populate the remainder of the parameters and captures @@ -598,17 +571,20 @@ ast::LambdaExpr *UdfCodegen::MakeLambdaBindingToNonRecord(common::ManagedPointer codegen_->TplType(sql::GetTypeId(column.GetType())))); } + // Clone the captures for assignment within the closure body + const std::vector assignees{captures.cbegin(), captures.cend()}; + // Begin construction of the function that implements the closure - FunctionBuilder builder{codegen_, std::move(parameters), codegen_->BuiltinType(execution::ast::BuiltinType::Nil)}; + auto builder = std::make_unique(codegen_, std::move(parameters), std::move(captures), + codegen_->BuiltinType(execution::ast::BuiltinType::Nil)); // Generate an assignment from each input parameter to the associated capture - for (std::size_t i = 0UL; i < captures.size(); ++i) { - auto *capture = captures[i]; - auto input_parameter = builder.GetParameterByPosition(i + 1); - builder.Append(codegen_->Assign(capture, input_parameter)); + for (std::size_t i = 0UL; i < assignees.size(); ++i) { + ast::Expr *capture = assignees.at(i); + auto input_parameter = builder->GetParameterByPosition(i + 1); + builder->Append(codegen_->Assign(capture, input_parameter)); } - - return builder.FinishLambda(std::move(captures)); + return builder; } void UdfCodegen::CodegenAddParameters(ast::Expr *exec_ctx, const std::vector &variable_refs) { diff --git a/src/include/execution/compiler/function_builder.h b/src/include/execution/compiler/function_builder.h index 1fe63aadc6..b3af42f0e2 100644 --- a/src/include/execution/compiler/function_builder.h +++ b/src/include/execution/compiler/function_builder.h @@ -4,6 +4,7 @@ #include #include +#include "common/macros.h" #include "execution/ast/identifier.h" #include "execution/compiler/ast_fwd.h" #include "execution/util/region_containers.h" @@ -12,6 +13,9 @@ namespace noisepage::execution::compiler { class CodeGen; +/** Enumerates the function types */ +enum class FunctionType { FUNCTION, CLOSURE }; + /** * Helper class to build TPL functions. */ @@ -21,101 +25,99 @@ class FunctionBuilder { public: /** - * Create a builder for a function with the provided name, return type, and arguments. - * @param codegen The code generation instance. - * @param name The name of the function. - * @param params The parameters to the function. - * @param ret_type The return type representation of the function. + * Construct a new FunctionBuilder instance for a "vanilla" function. + * @param codegen The code generation instance + * @param name The function name + * @param params The function parameters + * @param return_type The return type representation of the function */ FunctionBuilder(CodeGen *codegen, ast::Identifier name, util::RegionVector &¶ms, - ast::Expr *ret_type); + ast::Expr *return_type); /** - * Create a builder for a function with the provided return type and arguments. - * @param codegen The code generation instance. - * @param params The parameters to the function. - * @param ret_type The return type representation of the function. + * Construct a new FunctionBuilder instance for a closure. + * @param codegen The code generation instance + * @param params The function parameters + * @param closures The function closures + * @param return_type The return type representation of the function */ - FunctionBuilder(CodeGen *codegen, util::RegionVector &¶ms, ast::Expr *ret_type); + FunctionBuilder(CodeGen *codegen, util::RegionVector &¶ms, + util::RegionVector &&captures, ast::Expr *return_type); - /** - * Destructor. - */ + /** Destructor; invokes FunctionBuilder::Finish() */ ~FunctionBuilder(); - /** - * @return A reference to a function parameter by its ordinal position. - */ + /** @return A reference to a function parameter by its ordinal position */ ast::Expr *GetParameterByPosition(std::size_t param_idx); /** * Append a statement to the list of statements in this function. - * @param stmt The statement to append. + * @param stmt The statement to append */ void Append(ast::Stmt *stmt); /** * Append an expression as a statement to the list of statements in this function. - * @param expr The expression to append as a statement. + * @param expr The expression to append as a statement */ void Append(ast::Expr *expr); /** * Append a variable declaration as a statement to the list of statements in this function. - * @param decl The declaration to append to the statement. + * @param decl The declaration to append to the statement */ void Append(ast::VariableDecl *decl); /** - * Finish constructing the function. - * @param ret The value to return from the function. Use a null pointer to return nothing. - * @return The build function declaration. + * Finish construction of the function. + * @param ret The function return value; use `nullptr` for `nil` return + * @return The finished declaration */ ast::FunctionDecl *Finish(ast::Expr *ret = nullptr); /** - * Finish constructing the lambda. - * @param captures The lambda captures - * @param ret The return value, if present - * @return The lambda expression + * Finish construction of the closure. + * @param ret The function return value; use `nullptr` for `nil` return + * @return The finished expression */ - noisepage::execution::ast::LambdaExpr *FinishLambda(util::RegionVector &&captures, - ast::Expr *ret = nullptr); + ast::LambdaExpr *FinishClosure(ast::Expr *ret = nullptr); - /** - * @return The final constructed function, or nullptr if the builder - * hasn't been constructed through FunctionBuilder::Finish(). - */ - ast::FunctionDecl *GetConstructedFunction() const { return std::get(decl_); } + /** @return The final constructed function */ + ast::FunctionDecl *GetFinishedFunction() const { + NOISEPAGE_ASSERT(type_ == FunctionType::FUNCTION, "Attempt to get function from non-function-type builder"); + return std::get(decl_); + } - /** - * @return The final constructed lambda, or nullptr if the builder - * hasn't been constructed through FunctionBuilder::FinishLambda(). - */ - ast::LambdaExpr *GetConstructedLambda() const { return std::get(decl_); } + /** @return The final constructed closure */ + ast::LambdaExpr *GetFinishedClosure() const { + NOISEPAGE_ASSERT(type_ == FunctionType::CLOSURE, "Attempt to get closure from non-closure-type builder"); + return std::get(decl_); + } /** @return The code generator instance. */ CodeGen *GetCodeGen() const { return codegen_; } - /** @return `true` if the function represents a lambda, `false` otherwise. */ - bool IsLambda() const { return is_lambda_; } + /** @return `true` if the function is a closure, `false` otherwise */ + bool IsClosure() const { return type_ == FunctionType::CLOSURE; } private: + /** The type of the function */ + FunctionType type_; /** The code generation instance */ CodeGen *codegen_; /** The function's name */ ast::Identifier name_; /** The function's arguments */ util::RegionVector params_; + /** The captures for the closure (if applicable) */ + util::RegionVector captures_; /** The return type of the function */ - ast::Expr *ret_type_; + ast::Expr *return_type_; /** The start and stop position of statements in the function */ SourcePosition start_; /** The list of generated statements making up the function */ ast::BlockStmt *statements_; - /** `true` if this function is a lambda, `false` otherwise */ - bool is_lambda_; - /** The cached function declaration. Constructed once in Finish() */ + /** The cached, completed function; constructed once in Finish() */ std::variant decl_; }; diff --git a/src/include/execution/compiler/udf/udf_codegen.h b/src/include/execution/compiler/udf/udf_codegen.h index 54e9e901af..46fa9a362a 100644 --- a/src/include/execution/compiler/udf/udf_codegen.h +++ b/src/include/execution/compiler/udf/udf_codegen.h @@ -234,30 +234,30 @@ class UdfCodegen : ast::udf::ASTNodeVisitor { * represented by `plan` into the variables identified by `variables`. * @param plan The query plan * @param variables The names of the variables to which results are bound - * @return The lambda expression + * @return The builder used to construct the expression (unfinished) */ - ast::LambdaExpr *MakeLambda(common::ManagedPointer plan, - const std::vector &variables); + std::unique_ptr StartLambda(common::ManagedPointer plan, + const std::vector &variables); /** * Construct a lambda expression that writes the output of the query * represented by `plan` into a single RECORD-type variable. * @param plan The query plan * @param variables The names of the variables to which results are bound - * @return The lambda expression + * @return The builder used to construct the expression (unfinished) */ - ast::LambdaExpr *MakeLambdaBindingToRecord(common::ManagedPointer plan, - const std::vector &variables); + std::unique_ptr StartLambdaBindingToRecord(common::ManagedPointer plan, + const std::vector &variables); /** * Construct a lambda expression that writes the output of the query * represented by `plan` into one or more non-RECORD variables. * @param plan The query plan * @param variables The names of the variables to which results are bound - * @return The lambda expression + * @return The builder used to construct the expression (unfinished) */ - ast::LambdaExpr *MakeLambdaBindingToNonRecord(common::ManagedPointer plan, - const std::vector &variables); + std::unique_ptr StartLambdaBindingToNonRecord(common::ManagedPointer plan, + const std::vector &variables); /** * Generate code to add query parameters to the execution context. From c3e81bc42569c002d4e9fbb6fdd66f6b7b06ee2f Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Wed, 28 Jul 2021 22:20:33 -0400 Subject: [PATCH 088/139] fix some clang tidy errors --- src/execution/compiler/function_builder.cpp | 4 ++-- src/execution/sql/ddl_executors.cpp | 2 +- .../execution/compiler/compilation_context.h | 3 --- src/include/parser/udf/plpgsql_parser.h | 14 ++------------ 4 files changed, 5 insertions(+), 18 deletions(-) diff --git a/src/execution/compiler/function_builder.cpp b/src/execution/compiler/function_builder.cpp index c96613a26b..41e27b11fa 100644 --- a/src/execution/compiler/function_builder.cpp +++ b/src/execution/compiler/function_builder.cpp @@ -7,13 +7,13 @@ namespace noisepage::execution::compiler { FunctionBuilder::FunctionBuilder(CodeGen *codegen, ast::Identifier name, util::RegionVector &¶ms, - ast::Expr *ret_type) + ast::Expr *return_type) : type_{FunctionType::FUNCTION}, codegen_{codegen}, name_{name}, params_{std::move(params)}, captures_{codegen_->GetAstContext()->GetRegion()}, - return_type_{ret_type}, + return_type_{return_type}, start_{codegen->GetPosition()}, statements_{codegen->MakeEmptyBlock()}, decl_{std::in_place_type, nullptr} {} diff --git a/src/execution/sql/ddl_executors.cpp b/src/execution/sql/ddl_executors.cpp index 41a654665f..4dba2e0b8a 100644 --- a/src/execution/sql/ddl_executors.cpp +++ b/src/execution/sql/ddl_executors.cpp @@ -73,7 +73,7 @@ bool DDLExecutors::CreateFunctionExecutor(const common::ManagedPointerGetDatabaseOid()}; + parser::udf::PLpgSQLParser udf_parser{common::ManagedPointer{&udf_ast_context}}; std::unique_ptr ast{}; try { ast = udf_parser.Parse(node->GetFunctionParameterNames(), param_type_ids, body); diff --git a/src/include/execution/compiler/compilation_context.h b/src/include/execution/compiler/compilation_context.h index 26618353e5..d92a51ae90 100644 --- a/src/include/execution/compiler/compilation_context.h +++ b/src/include/execution/compiler/compilation_context.h @@ -200,9 +200,6 @@ class CompilationContext { // Whether pipeline metrics are enabled. bool pipeline_metrics_enabled_; - - // The current operator. - OperatorTranslator *current_op_; }; } // namespace noisepage::execution::compiler diff --git a/src/include/parser/udf/plpgsql_parser.h b/src/include/parser/udf/plpgsql_parser.h index ae3cd5b8d2..4f53072c92 100644 --- a/src/include/parser/udf/plpgsql_parser.h +++ b/src/include/parser/udf/plpgsql_parser.h @@ -31,12 +31,9 @@ class PLpgSQLParser { /** * Construct a new PLpgSQLParser instance. * @param udf_ast_context The AST context - * @param accessor The accessor to use during parsing - * @param db_oid The database OID */ - PLpgSQLParser(common::ManagedPointer udf_ast_context, - const common::ManagedPointer accessor, catalog::db_oid_t db_oid) - : udf_ast_context_(udf_ast_context), accessor_(accessor), db_oid_(db_oid) {} + explicit PLpgSQLParser(common::ManagedPointer udf_ast_context) + : udf_ast_context_{udf_ast_context} {} /** * Parse source PL/pgSQL to an abstract syntax tree. @@ -155,13 +152,6 @@ class PLpgSQLParser { private: /** The UDF AST context */ common::ManagedPointer udf_ast_context_; - - /** The catalog accessor */ - const common::ManagedPointer accessor_; - - /** The OID for the database with which the function is associated */ - catalog::db_oid_t db_oid_; - /** The function symbol table */ std::unordered_map symbol_table_; }; From aa2dfb422349ab2ccd7d002f4c2732c27e3b8ad3 Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Thu, 29 Jul 2021 14:33:16 -0400 Subject: [PATCH 089/139] able to bind query results to scalars in both queries and query-variant for-loops --- script/testing/junit/sql/udf.sql | 17 +- script/testing/junit/traces/udf.test | 20 +- script/testing/util/db_server.py | 2 +- src/execution/compiler/udf/udf_codegen.cpp | 241 ++++++++++++++---- .../execution/compiler/udf/udf_codegen.h | 74 ++++-- 5 files changed, 277 insertions(+), 77 deletions(-) diff --git a/script/testing/junit/sql/udf.sql b/script/testing/junit/sql/udf.sql index 233f09ddc2..d6ede89888 100644 --- a/script/testing/junit/sql/udf.sql +++ b/script/testing/junit/sql/udf.sql @@ -153,12 +153,25 @@ SELECT sql_select_mutliple_constants(); -- ---------------------------------------------------------------------------- -- proc_fors() --- --- TODO(Kyle): for-loop control flow (query variant) is not supported -- CREATE TABLE tmp(z INT); -- INSERT INTO tmp(z) VALUES (0), (1); +-- Select constant into a scalar variable +CREATE FUNCTION proc_fors_constant_var() RETURNS INT AS $$ \ +DECLARE \ + v INT; \ + x INT := 0; \ +BEGIN \ + FOR v IN SELECT 1 LOOP \ + x = x + 1; \ + END LOOP; \ + RETURN x; \ +END \ +$$ LANGUAGE PLPGSQL; + +SELECT proc_fors_constant_var(); + -- -- Bind query result to a RECORD type -- CREATE FUNCTION proc_fors_rec() RETURNS INT AS $$ \ -- DECLARE \ diff --git a/script/testing/junit/traces/udf.test b/script/testing/junit/traces/udf.test index 45b6934175..d0a7042ac3 100644 --- a/script/testing/junit/traces/udf.test +++ b/script/testing/junit/traces/udf.test @@ -325,19 +325,31 @@ statement ok -- proc_fors() statement ok --- + + +statement ok +-- CREATE TABLE tmp(z INT); statement ok --- TODO(Kyle): for-loop control flow (query variant) is not supported +-- INSERT INTO tmp(z) VALUES (0), (1); statement ok statement ok --- CREATE TABLE tmp(z INT); +-- Select constant into a scalar variable statement ok --- INSERT INTO tmp(z) VALUES (0), (1); +CREATE FUNCTION proc_fors_constant_var() RETURNS INT AS $$ DECLARE v INT; x INT := 0; BEGIN FOR v IN SELECT 1 LOOP x = x + 1; END LOOP; RETURN x; END $$ LANGUAGE PLPGSQL; + +statement ok + + +query I rowsort +SELECT proc_fors_constant_var(); +---- +1 + statement ok diff --git a/script/testing/util/db_server.py b/script/testing/util/db_server.py index c77b8547a1..0e3809c9f9 100644 --- a/script/testing/util/db_server.py +++ b/script/testing/util/db_server.py @@ -164,7 +164,7 @@ def stop_db(self, is_dry_run=False): finally: unix_socket = os.path.join("/tmp/", f".s.PGSQL.{self.db_port}") if os.path.exists(unix_socket): - # os.remove(unix_socket) + os.remove(unix_socket) LOG.info(f"Removing: {unix_socket}") self.print_db_logs() exit_code = self.db_process.returncode diff --git a/src/execution/compiler/udf/udf_codegen.cpp b/src/execution/compiler/udf/udf_codegen.cpp index beaa307c3d..e43396eba2 100644 --- a/src/execution/compiler/udf/udf_codegen.cpp +++ b/src/execution/compiler/udf/udf_codegen.cpp @@ -54,14 +54,6 @@ const char *UdfCodegen::GetReturnParamString() { return "return_val"; } void UdfCodegen::GenerateUDF(ast::udf::AbstractAST *ast) { ast->Accept(this); } -void UdfCodegen::Visit(ast::udf::AbstractAST *ast) { - throw NOT_IMPLEMENTED_EXCEPTION("UdfCodegen::Visit(AbstractAST*)"); -} - -void UdfCodegen::Visit(ast::udf::DynamicSQLStmtAST *ast) { - throw NOT_IMPLEMENTED_EXCEPTION("UdfCodegen::Visit(DynamicSQLStmtAST*)"); -} - catalog::type_oid_t UdfCodegen::GetCatalogTypeOidFromSQLType(execution::ast::BuiltinType::Kind type) { switch (type) { case execution::ast::BuiltinType::Kind::Integer: { @@ -84,6 +76,18 @@ execution::ast::File *UdfCodegen::Finish() { return file; } +/* ---------------------------------------------------------------------------- + Code Generation: "Simple" Constructs +---------------------------------------------------------------------------- */ + +void UdfCodegen::Visit(ast::udf::AbstractAST *ast) { + throw NOT_IMPLEMENTED_EXCEPTION("UdfCodegen::Visit(AbstractAST*)"); +} + +void UdfCodegen::Visit(ast::udf::DynamicSQLStmtAST *ast) { + throw NOT_IMPLEMENTED_EXCEPTION("UdfCodegen::Visit(DynamicSQLStmtAST*)"); +} + void UdfCodegen::Visit(ast::udf::CallExprAST *ast) { std::vector args_ast{}; std::vector args_ast_region_vec{}; @@ -327,8 +331,28 @@ void UdfCodegen::Visit(ast::udf::WhileStmtAST *ast) { loop.EndLoop(); } +void UdfCodegen::Visit(ast::udf::RetStmtAST *ast) { + ast->Return()->Accept(reinterpret_cast(this)); + auto ret_expr = dst_; + fb_->Append(codegen_->Return(ret_expr)); +} + +void UdfCodegen::Visit(ast::udf::MemberExprAST *ast) { + ast->Object()->Accept(reinterpret_cast(this)); + auto object = dst_; + dst_ = codegen_->AccessStructMember(object, codegen_->MakeIdentifier(ast->FieldName())); +} + +/* ---------------------------------------------------------------------------- + Code Generation: Integer-Variant For-Loops +---------------------------------------------------------------------------- */ + void UdfCodegen::Visit(ast::udf::ForIStmtAST *ast) { throw NOT_IMPLEMENTED_EXCEPTION("ForIStmtAST Not Implemented"); } +/* ---------------------------------------------------------------------------- + Code Generation: Query-Variant For-Loops +---------------------------------------------------------------------------- */ + void UdfCodegen::Visit(ast::udf::ForSStmtAST *ast) { // Executing a SQL query requires an execution context needs_exec_ctx_ = true; @@ -345,8 +369,8 @@ void UdfCodegen::Visit(ast::udf::ForSStmtAST *ast) { // Start construction of the lambda expression auto builder = StartLambda(plan, ast->Variables()); - // Generate code for closure capture assignment - CodegenCaptureAssignments(plan, ast->Variables()); + // Generate code for variable initialization + CodegenBoundVariableInit(plan, ast->Variables()); // Generate code for the loop body { @@ -390,7 +414,7 @@ void UdfCodegen::Visit(ast::udf::ForSStmtAST *ast) { codegen_->AccessStructMember(codegen_->MakeExpr(query_state), codegen_->MakeIdentifier("execCtx")), exec_ctx)); auto function_names = exec_query->GetFunctionNames(); - for (auto &function_name : function_names) { + for (const auto &function_name : function_names) { if (IsRunFunction(function_name)) { fb_->Append(codegen_->Call(codegen_->GetAstContext()->GetIdentifier(function_name), {codegen_->AddressOf(query_state), codegen_->MakeExpr(lambda_identifier)})); @@ -403,12 +427,127 @@ void UdfCodegen::Visit(ast::udf::ForSStmtAST *ast) { fb_->Append(codegen_->CallBuiltin(execution::ast::Builtin::FinishNewParams, {exec_ctx})); } -void UdfCodegen::Visit(ast::udf::RetStmtAST *ast) { - ast->Return()->Accept(reinterpret_cast(this)); - auto ret_expr = dst_; - fb_->Append(codegen_->Return(ret_expr)); +std::unique_ptr UdfCodegen::StartLambda(common::ManagedPointer plan, + const std::vector &variables) { + return GetVariableType(variables.front()) == type::TypeId::INVALID ? StartLambdaBindingToRecord(plan, variables) + : StartLambdaBindingToScalars(plan, variables); } +std::unique_ptr UdfCodegen::StartLambdaBindingToRecord( + common::ManagedPointer plan, const std::vector &variables) { + // bind results to a single RECORD variable + NOISEPAGE_ASSERT(variables.size() == 1, "Broken invariant"); + + const std::string &record_name = variables.front(); + const auto record_type = GetRecordType(record_name); + + const auto n_fields = record_type.size(); + const auto n_columns = plan->GetOutputSchema()->GetColumns().size(); + if (n_fields != n_columns) { + throw EXECUTION_EXCEPTION( + fmt::format("Attempt to bind {} query outputs to record type with {} fields", n_columns, n_fields), + common::ErrorCode::ERRCODE_PLPGSQL_ERROR); + } + + // The lambda accepts all columns of the query output schema as parameters + util::RegionVector parameters{codegen_->GetAstContext()->GetRegion()}; + + // The first parameter is always the execution context + ast::Expr *exec_ctx = fb_->GetParameterByPosition(0); + parameters.push_back( + codegen_->MakeField(exec_ctx->As()->Name(), + codegen_->PointerType(codegen_->BuiltinType(ast::BuiltinType::Kind::ExecutionContext)))); + + // The lambda captures all variables in the symbol table + // NOTE(Kyle): It might be possible / preferable to make this more conservative + util::RegionVector captures{codegen_->GetAstContext()->GetRegion()}; + for (const auto &[name, identifier] : SymbolTable()) { + if (name != "executionCtx") { + captures.push_back(codegen_->MakeExpr(identifier)); + } + } + + // While the closure only captures a single variable, we still need + // to generate code for an assignment to each field memeber + std::vector assignees{}; + assignees.reserve(n_columns); + + ast::Expr *record = codegen_->MakeExpr(SymbolTable().find(record_name)->second); + for (std::size_t i = 0UL; i < n_columns; ++i) { + const auto &column = plan->GetOutputSchema()->GetColumn(i); + assignees.push_back(codegen_->AccessStructMember(record, codegen_->MakeIdentifier(record_type[i].first))); + parameters.push_back(codegen_->MakeField(codegen_->MakeFreshIdentifier("input"), + codegen_->TplType(sql::GetTypeId(column.GetType())))); + } + + auto builder = std::make_unique(codegen_, std::move(parameters), std::move(captures), + codegen_->BuiltinType(ast::BuiltinType::Nil)); + for (std::size_t i = 0UL; i < assignees.size(); ++i) { + auto *assignee = assignees.at(i); + auto input_parameter = builder->GetParameterByPosition(i + 1); + builder->Append(codegen_->Assign(assignee, input_parameter)); + } + return builder; +} + +std::unique_ptr UdfCodegen::StartLambdaBindingToScalars( + common::ManagedPointer plan, const std::vector &variables) { + // bind results to one or more non-RECORD variables + const auto n_variables = variables.size(); + const auto n_columns = plan->GetOutputSchema()->GetColumns().size(); + if (n_variables != n_columns) { + throw EXECUTION_EXCEPTION(fmt::format("Attempt to bind {} query outputs to {} variables", n_columns, n_variables), + common::ErrorCode::ERRCODE_PLPGSQL_ERROR); + } + + // The lambda accepts all columns of the query output schema as parameters + util::RegionVector parameters{codegen_->GetAstContext()->GetRegion()}; + + // The lambda captures all variables in the symbol table + // NOTE(Kyle): It might be possible / preferable to make this more conservative + util::RegionVector captures{codegen_->GetAstContext()->GetRegion()}; + for (const auto &[name, identifier] : SymbolTable()) { + if (name != "executionCtx") { + captures.push_back(codegen_->MakeExpr(identifier)); + } + } + + // The first parameter is always the execution context + ast::Expr *exec_ctx = fb_->GetParameterByPosition(0); + parameters.push_back(codegen_->MakeField( + exec_ctx->As()->Name(), + codegen_->PointerType(codegen_->BuiltinType(execution::ast::BuiltinType::Kind::ExecutionContext)))); + + // Assignees are those captures that are written in the closure + std::vector assignees{}; + assignees.reserve(n_columns); + + // Populate the parameters and capture assignees + for (std::size_t i = 0UL; i < n_columns; ++i) { + const auto &variable = variables.at(i); + const auto &column = plan->GetOutputSchema()->GetColumn(i); + assignees.push_back(codegen_->MakeExpr(SymbolTable().find(variable)->second)); + parameters.push_back(codegen_->MakeField(codegen_->MakeFreshIdentifier("input"), + codegen_->TplType(sql::GetTypeId(column.GetType())))); + } + + // Begin construction of the function that implements the closure + auto builder = std::make_unique(codegen_, std::move(parameters), std::move(captures), + codegen_->BuiltinType(execution::ast::BuiltinType::Nil)); + + // Generate an assignment from each input parameter to the associated capture + for (std::size_t i = 0UL; i < assignees.size(); ++i) { + ast::Expr *capture = assignees.at(i); + auto input_parameter = builder->GetParameterByPosition(i + 1); + builder->Append(codegen_->Assign(capture, input_parameter)); + } + return builder; +} + +/* ---------------------------------------------------------------------------- + Code Generation: SQL Statements +---------------------------------------------------------------------------- */ + void UdfCodegen::Visit(ast::udf::SQLStmtAST *ast) { // Executing a SQL query requires an execution context needs_exec_ctx_ = true; @@ -420,8 +559,7 @@ void UdfCodegen::Visit(ast::udf::SQLStmtAST *ast) { // Construct a lambda that writes the output of the query // into the bound variables, as defined by the function body - auto builder = StartLambda(plan, ast->Variables()); - ast::LambdaExpr *lambda_expr = builder->FinishClosure(); + ast::LambdaExpr *lambda_expr = MakeLambda(plan, ast->Variables()); const ast::Identifier lambda_identifier = codegen_->MakeFreshIdentifier("udfLambda"); lambda_expr->SetName(lambda_identifier); @@ -456,7 +594,7 @@ void UdfCodegen::Visit(ast::udf::SQLStmtAST *ast) { codegen_->AccessStructMember(codegen_->MakeExpr(query_state), codegen_->MakeIdentifier("execCtx")), exec_ctx)); // Initialize the captures - CodegenCaptureAssignments(plan, ast->Variables()); + CodegenBoundVariableInit(plan, ast->Variables()); // Manually append calls to each function from the compiled // executable query (implementing the closure) to the builder @@ -474,24 +612,14 @@ void UdfCodegen::Visit(ast::udf::SQLStmtAST *ast) { fb_->Append(codegen_->CallBuiltin(ast::Builtin::FinishNewParams, {exec_ctx})); } -void UdfCodegen::Visit(ast::udf::MemberExprAST *ast) { - ast->Object()->Accept(reinterpret_cast(this)); - auto object = dst_; - dst_ = codegen_->AccessStructMember(object, codegen_->MakeIdentifier(ast->FieldName())); +ast::LambdaExpr *UdfCodegen::MakeLambda(common::ManagedPointer plan, + const std::vector &variables) { + return GetVariableType(variables.front()) == type::TypeId::INVALID ? MakeLambdaBindingToRecord(plan, variables) + : MakeLambdaBindingToScalars(plan, variables); } -/* ---------------------------------------------------------------------------- - Code Generation Helpers ----------------------------------------------------------------------------- */ - -std::unique_ptr UdfCodegen::StartLambda(common::ManagedPointer plan, - const std::vector &variables) { - return GetVariableType(variables.front()) == type::TypeId::INVALID ? StartLambdaBindingToRecord(plan, variables) - : StartLambdaBindingToNonRecord(plan, variables); -} - -std::unique_ptr UdfCodegen::StartLambdaBindingToRecord( - common::ManagedPointer plan, const std::vector &variables) { +ast::LambdaExpr *UdfCodegen::MakeLambdaBindingToRecord(common::ManagedPointer plan, + const std::vector &variables) { // bind results to a single RECORD variable NOISEPAGE_ASSERT(variables.size() == 1, "Broken invariant"); @@ -531,18 +659,19 @@ std::unique_ptr UdfCodegen::StartLambdaBindingToRecord( codegen_->TplType(sql::GetTypeId(column.GetType())))); } - auto builder = std::make_unique(codegen_, std::move(parameters), std::move(captures), - codegen_->BuiltinType(ast::BuiltinType::Nil)); + FunctionBuilder builder{codegen_, std::move(parameters), std::move(captures), + codegen_->BuiltinType(ast::BuiltinType::Nil)}; for (std::size_t i = 0UL; i < assignees.size(); ++i) { auto *assignee = assignees.at(i); - auto input_parameter = builder->GetParameterByPosition(i + 1); - builder->Append(codegen_->Assign(assignee, input_parameter)); + auto input_parameter = builder.GetParameterByPosition(i + 1); + builder.Append(codegen_->Assign(assignee, input_parameter)); } - return builder; + + return builder.FinishClosure(); } -std::unique_ptr UdfCodegen::StartLambdaBindingToNonRecord( - common::ManagedPointer plan, const std::vector &variables) { +ast::LambdaExpr *UdfCodegen::MakeLambdaBindingToScalars(common::ManagedPointer plan, + const std::vector &variables) { // bind results to one or more non-RECORD variables const auto n_variables = variables.size(); const auto n_columns = plan->GetOutputSchema()->GetColumns().size(); @@ -575,18 +704,22 @@ std::unique_ptr UdfCodegen::StartLambdaBindingToNonRecord( const std::vector assignees{captures.cbegin(), captures.cend()}; // Begin construction of the function that implements the closure - auto builder = std::make_unique(codegen_, std::move(parameters), std::move(captures), - codegen_->BuiltinType(execution::ast::BuiltinType::Nil)); + FunctionBuilder builder{codegen_, std::move(parameters), std::move(captures), + codegen_->BuiltinType(execution::ast::BuiltinType::Nil)}; // Generate an assignment from each input parameter to the associated capture for (std::size_t i = 0UL; i < assignees.size(); ++i) { ast::Expr *capture = assignees.at(i); - auto input_parameter = builder->GetParameterByPosition(i + 1); - builder->Append(codegen_->Assign(capture, input_parameter)); + auto input_parameter = builder.GetParameterByPosition(i + 1); + builder.Append(codegen_->Assign(capture, input_parameter)); } - return builder; + return builder.FinishClosure(); } +/* ---------------------------------------------------------------------------- + Common Code Generation Helpers +---------------------------------------------------------------------------- */ + void UdfCodegen::CodegenAddParameters(ast::Expr *exec_ctx, const std::vector &variable_refs) { for (const auto &variable_ref : variable_refs) { if (variable_ref.IsScalar()) { @@ -626,22 +759,22 @@ void UdfCodegen::CodegenAddTableParameter(ast::Expr *exec_ctx, const parser::udf fb_->Append(codegen_->CallBuiltin(AddParamBuiltinForParameterType(type), {exec_ctx, expr})); } -void UdfCodegen::CodegenCaptureAssignments(common::ManagedPointer plan, - const std::vector &bound_variables) { +void UdfCodegen::CodegenBoundVariableInit(common::ManagedPointer plan, + const std::vector &bound_variables) { if (bound_variables.empty()) { // Nothing to do return; } if (GetVariableType(bound_variables.front()) == type::TypeId::INVALID) { - CodegenCaptureAssignmentToRecord(plan, bound_variables.front()); + CodegenBoundVariableInitForRecord(plan, bound_variables.front()); } else { - CodegenCaptureAssignmentToScalars(plan, bound_variables); + CodegenBoundVariableInitForScalars(plan, bound_variables); } } -void UdfCodegen::CodegenCaptureAssignmentToScalars(common::ManagedPointer plan, - const std::vector &bound_variables) { +void UdfCodegen::CodegenBoundVariableInitForScalars(common::ManagedPointer plan, + const std::vector &bound_variables) { const auto n_columns = plan->GetOutputSchema()->GetColumns().size(); const auto n_variables = bound_variables.size(); if (n_columns != n_variables) { @@ -658,8 +791,8 @@ void UdfCodegen::CodegenCaptureAssignmentToScalars(common::ManagedPointer plan, - const std::string &record_name) { +void UdfCodegen::CodegenBoundVariableInitForRecord(common::ManagedPointer plan, + const std::string &record_name) { NOISEPAGE_ASSERT(GetVariableType(record_name) == type::TypeId::INVALID, "Broken invariant"); const auto n_columns = plan->GetOutputSchema()->GetColumns().size(); const auto fields = GetRecordType(record_name); diff --git a/src/include/execution/compiler/udf/udf_codegen.h b/src/include/execution/compiler/udf/udf_codegen.h index 46fa9a362a..e29342f468 100644 --- a/src/include/execution/compiler/udf/udf_codegen.h +++ b/src/include/execution/compiler/udf/udf_codegen.h @@ -229,35 +229,77 @@ class UdfCodegen : ast::udf::ASTNodeVisitor { static const char *GetReturnParamString(); private: + /* -------------------------------------------------------------------------- + Code Generation: For-S Loops + -------------------------------------------------------------------------- */ + /** - * Construct a lambda expression that writes the output of the query + * Begin construction of a lambda that writes the output of the query * represented by `plan` into the variables identified by `variables`. * @param plan The query plan * @param variables The names of the variables to which results are bound - * @return The builder used to construct the expression (unfinished) + * @return The unfinished function builder for the lambda */ std::unique_ptr StartLambda(common::ManagedPointer plan, const std::vector &variables); /** - * Construct a lambda expression that writes the output of the query + * Begin construction of a lambda that writes the output of the query * represented by `plan` into a single RECORD-type variable. * @param plan The query plan * @param variables The names of the variables to which results are bound - * @return The builder used to construct the expression (unfinished) + * @return The unfinished function builder for the lambda */ std::unique_ptr StartLambdaBindingToRecord(common::ManagedPointer plan, const std::vector &variables); + /** + * Begin construction of a lambda that writes the output of the query + * represented by `plan` into one or more non-RECORD variables. + * @param plan The query plan + * @param variables The names of the variables to which results are bound + * @return The unfinished function builder for the lambda + */ + std::unique_ptr StartLambdaBindingToScalars(common::ManagedPointer plan, + const std::vector &variables); + + /* -------------------------------------------------------------------------- + Code Generation: SQL Statements + -------------------------------------------------------------------------- */ + + /** + * Construct a lambda expression that writes the output of the query + * represented by `plan` into the variables identified by `variables`. + * @param plan The query plan + * @param variables The names of the variables to which results are bound + * @return The finished lambda expression + */ + ast::LambdaExpr *MakeLambda(common::ManagedPointer plan, + const std::vector &variables); + + /** + * Construct a lambda expression that writes the output of the query + * represented by `plan` into a single RECORD-type variable. + * @param plan The query plan + * @param variables The names of the variables to which results are bound + * @return The finished lambda expression + */ + ast::LambdaExpr *MakeLambdaBindingToRecord(common::ManagedPointer plan, + const std::vector &variables); + /** * Construct a lambda expression that writes the output of the query * represented by `plan` into one or more non-RECORD variables. * @param plan The query plan * @param variables The names of the variables to which results are bound - * @return The builder used to construct the expression (unfinished) + * @return The finished lambda expression */ - std::unique_ptr StartLambdaBindingToNonRecord(common::ManagedPointer plan, - const std::vector &variables); + ast::LambdaExpr *MakeLambdaBindingToScalars(common::ManagedPointer plan, + const std::vector &variables); + + /* -------------------------------------------------------------------------- + Code Generation: Common + -------------------------------------------------------------------------- */ /** * Generate code to add query parameters to the execution context. @@ -281,28 +323,28 @@ class UdfCodegen : ast::udf::ASTNodeVisitor { void CodegenAddTableParameter(ast::Expr *exec_ctx, const parser::udf::VariableRef &variable_ref); /** - * Generate code to perform assignment to captured variables. + * Generate code to initialize bound variables. * @param plan The query plan * @param bound_variables The variables to which results of the query are bound */ - void CodegenCaptureAssignments(common::ManagedPointer plan, - const std::vector &bound_variables); + void CodegenBoundVariableInit(common::ManagedPointer plan, + const std::vector &bound_variables); /** - * Generate code to perform assignment to captured variables. + * Generate code to initialize bound scalar variables. * @param plan The query plan * @param bound_variables The name(s) of the scalar variables to which results of the query are bound */ - void CodegenCaptureAssignmentToScalars(common::ManagedPointer plan, - const std::vector &bound_variables); + void CodegenBoundVariableInitForScalars(common::ManagedPointer plan, + const std::vector &bound_variables); /** - * Generate code to perform assignment to captured variables. + * Generate code to initialize a bound record variable. * @param plan The query plan * @param record_name The name of the record variable to which results of the query are bound */ - void CodegenCaptureAssignmentToRecord(common::ManagedPointer plan, - const std::string &record_name); + void CodegenBoundVariableInitForRecord(common::ManagedPointer plan, + const std::string &record_name); /** * Translate a SQL type to its corresponding catalog type. From d1c0b66f79b7c34aa64777bf1ae4901ffce23e26 Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Thu, 29 Jul 2021 15:23:58 -0400 Subject: [PATCH 090/139] update error handling to not crash the system on code generation failure, integration test for multiple rows in query-variant for-loop --- script/testing/junit/sql/udf.sql | 20 ++++++++++++++-- script/testing/junit/traces/udf.test | 22 +++++++++++++++-- src/execution/compiler/udf/udf_codegen.cpp | 24 ++++++++++--------- src/execution/sql/ddl_executors.cpp | 19 +++++++++++---- .../execution/compiler/udf/udf_codegen.h | 11 ++++++--- 5 files changed, 73 insertions(+), 23 deletions(-) diff --git a/script/testing/junit/sql/udf.sql b/script/testing/junit/sql/udf.sql index d6ede89888..9a9deb1ba9 100644 --- a/script/testing/junit/sql/udf.sql +++ b/script/testing/junit/sql/udf.sql @@ -172,6 +172,22 @@ $$ LANGUAGE PLPGSQL; SELECT proc_fors_constant_var(); +-- Select multiple constants in scalar variables +CREATE FUNCTION proc_fors_constant_vars() RETURNS INT AS $$ \ +DECLARE \ + x INT; \ + y INT; \ + z INT := 0; \ +BEGIN \ + FOR x, y IN SELECT 1, 2 LOOP \ + z = z + 1; \ + END LOOP; \ + RETURN z; \ +END \ +$$ LANGUAGE PLPGSQL; + +SELECT proc_fors_constant_vars(); + -- -- Bind query result to a RECORD type -- CREATE FUNCTION proc_fors_rec() RETURNS INT AS $$ \ -- DECLARE \ @@ -185,7 +201,7 @@ SELECT proc_fors_constant_var(); -- END \ -- $$ LANGUAGE PLPGSQL; --- SELECT x, proc_fors_rec() FROM integers; +-- SELECT proc_fors_rec() FROM integers; -- -- Bind query result directly to INT type -- CREATE FUNCTION proc_fors_var() RETURNS INT AS $$ \ @@ -200,4 +216,4 @@ SELECT proc_fors_constant_var(); -- END \ -- $$ LANGUAGE PLPGSQL; --- SELECT x, proc_fors_var() FROM integers; +-- SELECT proc_fors_var() FROM integers; diff --git a/script/testing/junit/traces/udf.test b/script/testing/junit/traces/udf.test index d0a7042ac3..afcca9e69d 100644 --- a/script/testing/junit/traces/udf.test +++ b/script/testing/junit/traces/udf.test @@ -354,6 +354,24 @@ SELECT proc_fors_constant_var(); statement ok +statement ok +-- Select multiple constants in scalar variables + +statement ok +CREATE FUNCTION proc_fors_constant_vars() RETURNS INT AS $$ DECLARE x INT; y INT; z INT := 0; BEGIN FOR x, y IN SELECT 1, 2 LOOP z = z + 1; END LOOP; RETURN z; END $$ LANGUAGE PLPGSQL; + +statement ok + + +query I rowsort +SELECT proc_fors_constant_vars(); +---- +1 + + +statement ok + + statement ok -- -- Bind query result to a RECORD type @@ -367,7 +385,7 @@ statement ok statement ok --- SELECT x, proc_fors_rec() FROM integers; +-- SELECT proc_fors_rec() FROM integers; statement ok @@ -382,5 +400,5 @@ statement ok statement ok --- SELECT x, proc_fors_var() FROM integers; +-- SELECT proc_fors_var() FROM integers; diff --git a/src/execution/compiler/udf/udf_codegen.cpp b/src/execution/compiler/udf/udf_codegen.cpp index e43396eba2..ad2a4f7d45 100644 --- a/src/execution/compiler/udf/udf_codegen.cpp +++ b/src/execution/compiler/udf/udf_codegen.cpp @@ -358,13 +358,13 @@ void UdfCodegen::Visit(ast::udf::ForSStmtAST *ast) { needs_exec_ctx_ = true; execution::ast::Expr *exec_ctx = fb_->GetParameterByPosition(0); + // Bind the embedded query; must do this prior to attempting + // to optimize to ensure correctness + const auto variable_refs = BindQueryAndGetVariableRefs(ast->Query()); + // Optimize the embedded query auto optimize_result = OptimizeEmbeddedQuery(ast->Query()); auto plan = optimize_result->GetPlanNode(); - if (plan->GetOutputSchema()->GetColumns().size() > 1) { - throw EXECUTION_EXCEPTION("PL/pgSQL Codegen : support for non-scalars is not implemented", - common::ErrorCode::ERRCODE_PLPGSQL_ERROR); - } // Start construction of the lambda expression auto builder = StartLambda(plan, ast->Variables()); @@ -404,10 +404,6 @@ void UdfCodegen::Visit(ast::udf::ForSStmtAST *ast) { // Set its execution context to whatever execution context was passed in here fb_->Append(codegen_->CallBuiltin(execution::ast::Builtin::StartNewParams, {exec_ctx})); - // Derive the columns and parameter names from the query - binder::BindNodeVisitor visitor{common::ManagedPointer{accessor_}, db_oid_}; - const auto variable_refs = - visitor.BindAndGetUDFVariableRefs(common::ManagedPointer{ast->Query()}, common::ManagedPointer{udf_ast_context_}); CodegenAddParameters(exec_ctx, variable_refs); fb_->Append(codegen_->Assign( @@ -553,6 +549,10 @@ void UdfCodegen::Visit(ast::udf::SQLStmtAST *ast) { needs_exec_ctx_ = true; ast::Expr *exec_ctx = fb_->GetParameterByPosition(0); + // Bind the embedded query; must do this prior to attempting + // to optimize to ensure correctness + const auto variable_refs = BindQueryAndGetVariableRefs(ast->Query()); + // Optimize the query and generate get a reference to the plan auto optimize_result = OptimizeEmbeddedQuery(ast->Query()); auto plan = optimize_result->GetPlanNode(); @@ -584,9 +584,6 @@ void UdfCodegen::Visit(ast::udf::SQLStmtAST *ast) { // Determine the column references in the query (if any) // that depend on variables in the UDF definition - binder::BindNodeVisitor visitor{common::ManagedPointer{accessor_}, db_oid_}; - const auto variable_refs = - visitor.BindAndGetUDFVariableRefs(common::ManagedPointer{ast->Query()}, common::ManagedPointer{udf_ast_context_}); CodegenAddParameters(exec_ctx, variable_refs); // Load the execution context member of the query state @@ -837,6 +834,11 @@ std::vector> UdfCodegen::GetRecordType(cons return type.value(); } +std::vector UdfCodegen::BindQueryAndGetVariableRefs(parser::ParseResult *query) { + binder::BindNodeVisitor visitor{common::ManagedPointer{accessor_}, db_oid_}; + return visitor.BindAndGetUDFVariableRefs(common::ManagedPointer{query}, common::ManagedPointer{udf_ast_context_}); +} + std::unique_ptr UdfCodegen::OptimizeEmbeddedQuery(parser::ParseResult *parsed_query) { optimizer::StatsStorage stats{}; const std::uint64_t optimizer_timeout = 1000000; diff --git a/src/execution/sql/ddl_executors.cpp b/src/execution/sql/ddl_executors.cpp index 4dba2e0b8a..03c34b95e1 100644 --- a/src/execution/sql/ddl_executors.cpp +++ b/src/execution/sql/ddl_executors.cpp @@ -77,8 +77,8 @@ bool DDLExecutors::CreateFunctionExecutor(const common::ManagedPointer ast{}; try { ast = udf_parser.Parse(node->GetFunctionParameterNames(), param_type_ids, body); - } catch (const ParserException &e) { - PARSER_LOG_ERROR(e.what()); + } catch (const ParserException &parser_error) { + PARSER_LOG_ERROR(parser_error.what()); return false; } @@ -107,8 +107,17 @@ bool DDLExecutors::CreateFunctionExecutor(const common::ManagedPointerGetReturnType())))}; // Run UDF code generation - auto *file = compiler::udf::UdfCodegen::Run(accessor.Get(), &fb, &udf_ast_context, &codegen, node->GetDatabaseOid(), - ast.get()); + ast::File *file; + try { + file = compiler::udf::UdfCodegen::Run(accessor.Get(), &fb, &udf_ast_context, &codegen, node->GetDatabaseOid(), + ast.get()); + } catch (const BinderException &binder_error) { + EXECUTION_LOG_ERROR(binder_error.what()); + return false; + } catch (const ExecutionException &execution_error) { + EXECUTION_LOG_ERROR(execution_error.what()); + return false; + } { sema::Sema type_check{codegen.GetAstContext().Get()}; @@ -121,7 +130,7 @@ bool DDLExecutors::CreateFunctionExecutor(const common::ManagedPointer types{}; types.reserve(node->GetFunctionParameterTypes().size()); diff --git a/src/include/execution/compiler/udf/udf_codegen.h b/src/include/execution/compiler/udf/udf_codegen.h index e29342f468..136b773c2e 100644 --- a/src/include/execution/compiler/udf/udf_codegen.h +++ b/src/include/execution/compiler/udf/udf_codegen.h @@ -71,9 +71,7 @@ class UdfCodegen : ast::udf::ASTNodeVisitor { UdfCodegen(catalog::CatalogAccessor *accessor, FunctionBuilder *fb, ast::udf::UdfAstContext *udf_ast_context, CodeGen *codegen, catalog::db_oid_t db_oid); - /** - * Destroy the UDF code generation context. - */ + /** Destroy the UDF code generation context. */ ~UdfCodegen() override = default; /** @@ -375,6 +373,13 @@ class UdfCodegen : ast::udf::ASTNodeVisitor { */ std::vector> GetRecordType(const std::string &name) const; + /** + * Bind the query and return the variable references. + * @param query The parsed query + * @return The collection of variable references + */ + std::vector BindQueryAndGetVariableRefs(parser::ParseResult *query); + /** * Run the optimizer on an embedded SQL query. * @param parsed_query The result of parsing the query From ec93afc1f36a91f2a3de9c0d9ec4f5c09d6df469 Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Sun, 1 Aug 2021 22:54:41 -0400 Subject: [PATCH 091/139] major refactor, primarily in pipeline implementation to account for different pipeline function signatures when output callback is present, tests passing, still want to revisit again because the implementation is over-complicated --- script/testing/junit/sql/udf.sql | 20 +- .../compiler/compilation_context.cpp | 2 +- src/execution/compiler/executable_query.cpp | 19 + src/execution/compiler/function_builder.cpp | 8 + .../operator/hash_join_translator.cpp | 4 +- .../compiler/operator/output_translator.cpp | 23 +- src/execution/compiler/pipeline.cpp | 623 +++++++++++------- src/execution/compiler/udf/udf_codegen.cpp | 81 ++- src/execution/sema/sema_builtin.cpp | 9 +- .../execution/compiler/compilation_context.h | 36 +- .../execution/compiler/executable_query.h | 17 +- .../execution/compiler/function_builder.h | 6 + .../compiler/operator/output_translator.h | 23 +- src/include/execution/compiler/pipeline.h | 201 +++--- .../execution/compiler/udf/udf_codegen.h | 28 +- src/include/execution/sema/error_message.h | 2 +- src/include/execution/vm/module.h | 2 +- 17 files changed, 695 insertions(+), 409 deletions(-) diff --git a/script/testing/junit/sql/udf.sql b/script/testing/junit/sql/udf.sql index 9a9deb1ba9..261f14cd40 100644 --- a/script/testing/junit/sql/udf.sql +++ b/script/testing/junit/sql/udf.sql @@ -174,16 +174,16 @@ SELECT proc_fors_constant_var(); -- Select multiple constants in scalar variables CREATE FUNCTION proc_fors_constant_vars() RETURNS INT AS $$ \ -DECLARE \ - x INT; \ - y INT; \ - z INT := 0; \ -BEGIN \ - FOR x, y IN SELECT 1, 2 LOOP \ - z = z + 1; \ - END LOOP; \ - RETURN z; \ -END \ +DECLARE \ + x INT; \ + y INT; \ + z INT := 0; \ +BEGIN \ + FOR x, y IN SELECT 1, 2 LOOP \ + z = z + 1; \ + END LOOP; \ + RETURN z; \ +END \ $$ LANGUAGE PLPGSQL; SELECT proc_fors_constant_vars(); diff --git a/src/execution/compiler/compilation_context.cpp b/src/execution/compiler/compilation_context.cpp index 95f8e06745..f6c5f857be 100644 --- a/src/execution/compiler/compilation_context.cpp +++ b/src/execution/compiler/compilation_context.cpp @@ -188,7 +188,7 @@ void CompilationContext::GeneratePlan(const planner::AbstractPlanNode &plan, } main_builder.DeclareAll(pipeline_decls); } - pipeline->GeneratePipeline(&main_builder, query_id_t{unique_id_}, output_callback_); + pipeline->GeneratePipeline(&main_builder); } // Register the tear-down function. diff --git a/src/execution/compiler/executable_query.cpp b/src/execution/compiler/executable_query.cpp index 7bd5a9a225..7e7e92f25a 100644 --- a/src/execution/compiler/executable_query.cpp +++ b/src/execution/compiler/executable_query.cpp @@ -10,6 +10,7 @@ #include "execution/compiler/compiler.h" #include "execution/exec/execution_context.h" #include "execution/sema/error_reporter.h" +#include "execution/vm/bytecode_function_info.h" #include "execution/vm/module.h" #include "loggers/execution_logger.h" #include "self_driving/modeling/operating_unit.h" @@ -64,6 +65,11 @@ void ExecutableQuery::Fragment::Run(byte query_state[], vm::ExecutionMode mode) } } +std::optional ExecutableQuery::Fragment::GetFunctionMetadata(const std::string &name) const { + const auto *metadata = module_->GetFuncInfoByName(name); + return (metadata == nullptr) ? std::nullopt : std::make_optional(metadata); +} + const vm::ModuleMetadata &ExecutableQuery::Fragment::GetModuleMetadata() const { return module_->GetMetadata(); } //===----------------------------------------------------------------------===// @@ -206,6 +212,19 @@ std::vector ExecutableQuery::GetFunctionNames() const { return function_names; } +std::vector ExecutableQuery::GetFunctionMetadata() const { + std::vector function_meta{}; + for (const auto &f : fragments_) { + const auto function_names = f->GetFunctions(); + for (const auto &function_name : function_names) { + auto meta = f->GetFunctionMetadata(function_name); + NOISEPAGE_ASSERT(meta.has_value(), "Broken invariant"); + function_meta.push_back(meta.value()); + } + } + return function_meta; +} + std::vector ExecutableQuery::GetDecls() const { std::vector decls{}; for (const auto &f : fragments_) { diff --git a/src/execution/compiler/function_builder.cpp b/src/execution/compiler/function_builder.cpp index 41e27b11fa..8e1872a591 100644 --- a/src/execution/compiler/function_builder.cpp +++ b/src/execution/compiler/function_builder.cpp @@ -42,6 +42,14 @@ ast::Expr *FunctionBuilder::GetParameterByPosition(const std::size_t param_idx) return nullptr; } +std::vector FunctionBuilder::GetParameters() const { + std::vector parameters{}; + parameters.reserve(params_.size()); + std::transform(params_.cbegin(), params_.cend(), std::back_inserter(parameters), + [this](const ast::FieldDecl *p) -> ast::Expr * { return codegen_->MakeExpr(p->Name()); }); + return parameters; +} + void FunctionBuilder::Append(ast::Stmt *stmt) { // Append the statement to the block. statements_->AppendStatement(stmt); diff --git a/src/execution/compiler/operator/hash_join_translator.cpp b/src/execution/compiler/operator/hash_join_translator.cpp index 99a8942386..d2efb9d2ae 100644 --- a/src/execution/compiler/operator/hash_join_translator.cpp +++ b/src/execution/compiler/operator/hash_join_translator.cpp @@ -390,7 +390,7 @@ void HashJoinTranslator::CheckJoinPredicate(WorkContext *ctx, FunctionBuilder *f FillProbeRow(ctx, function, codegen->MakeExpr(probe_row_var_)); // joinConsumer(queryState, pipelineState, buildRow, probeRow); std::initializer_list args{GetQueryStatePtr(), - codegen->MakeExpr(GetPipeline()->GetPipelineStateVar()), + codegen->MakeExpr(GetPipeline()->GetPipelineStateName()), codegen->MakeExpr(build_row_var_), codegen->AddressOf(probe_row)}; function->Append(codegen->Call(join_consumer_, args)); } else { @@ -462,7 +462,7 @@ void HashJoinTranslator::CollectUnmatchedLeftRows(FunctionBuilder *function) con } // joinConsumer(queryState, pipelineState, buildRow, probeRow); std::initializer_list args{GetQueryStatePtr(), - codegen->MakeExpr(GetPipeline()->GetPipelineStateVar()), + codegen->MakeExpr(GetPipeline()->GetPipelineStateName()), codegen->MakeExpr(build_row_var_), codegen->AddressOf(probe_row)}; function->Append(codegen->Call(join_consumer_, args)); } diff --git a/src/execution/compiler/operator/output_translator.cpp b/src/execution/compiler/operator/output_translator.cpp index c02fcf52bf..4f5b362655 100644 --- a/src/execution/compiler/operator/output_translator.cpp +++ b/src/execution/compiler/operator/output_translator.cpp @@ -26,10 +26,16 @@ OutputTranslator::OutputTranslator(const planner::AbstractPlanNode &plan, Compil output_buffer_ = pipeline->DeclarePipelineStateEntry( "output_buffer", GetCodeGen()->PointerType(GetCodeGen()->BuiltinType(ast::BuiltinType::OutputBuffer))); num_output_ = CounterDeclare("num_output", pipeline); + + // If the compilation context contains an output callback, + // the output translator injects the callback into its pipeline + if (compilation_context->HasOutputCallback()) { + pipeline->SetOutputCallback(compilation_context->GetOutputCallback()); + } } void OutputTranslator::InitializePipelineState(const Pipeline &pipeline, FunctionBuilder *function) const { - if (GetCompilationContext()->GetOutputCallback() != nullptr) { + if (HasOutputCallback()) { return; } @@ -41,7 +47,7 @@ void OutputTranslator::InitializePipelineState(const Pipeline &pipeline, Functio } void OutputTranslator::TearDownPipelineState(const Pipeline &pipeline, FunctionBuilder *function) const { - if (GetCompilationContext()->GetOutputCallback() != nullptr) { + if (HasOutputCallback()) { return; } @@ -52,15 +58,14 @@ void OutputTranslator::TearDownPipelineState(const Pipeline &pipeline, FunctionB void OutputTranslator::PerformPipelineWork(noisepage::execution::compiler::WorkContext *context, noisepage::execution::compiler::FunctionBuilder *function) const { - auto out_buffer = output_buffer_.Get(GetCodeGen()); ast::Expr *cast_call; - auto callback = GetCompilationContext()->GetOutputCallback(); - if (callback != nullptr) { + if (HasOutputCallback()) { auto output = GetCodeGen()->MakeFreshIdentifier("output_row"); auto *row_alloc = GetCodeGen()->DeclareVarNoInit(output, GetCodeGen()->MakeExpr(output_struct_)); function->Append(row_alloc); cast_call = GetCodeGen()->AddressOf(GetCodeGen()->MakeExpr(output)); } else { + auto out_buffer = output_buffer_.Get(GetCodeGen()); ast::Expr *alloc_call = GetCodeGen()->CallBuiltin(ast::Builtin::ResultBufferAllocOutRow, {out_buffer}); cast_call = GetCodeGen()->PtrCast(output_struct_, alloc_call); } @@ -76,12 +81,14 @@ void OutputTranslator::PerformPipelineWork(noisepage::execution::compiler::WorkC ast::Expr *lhs = GetCodeGen()->AccessStructMember(GetCodeGen()->MakeExpr(output_var_), attr_name); ast::Expr *rhs = child_translator->GetOutput(context, attr_idx); function->Append(GetCodeGen()->Assign(lhs, rhs)); - if (callback != nullptr) { + if (HasOutputCallback()) { callback_args.push_back(lhs); } } - if (callback != nullptr) { + // If an output callback is present, append the callback invocation + if (HasOutputCallback()) { + auto *callback = GetCompilationContext()->GetOutputCallback(); function->Append(GetCodeGen()->Call(callback->As()->GetName(), callback_args)); } @@ -135,4 +142,6 @@ void OutputTranslator::DefineHelperStructs(util::RegionVector decls->push_back(codegen->DeclareStruct(output_struct_, std::move(fields))); } +bool OutputTranslator::HasOutputCallback() const { return GetCompilationContext()->HasOutputCallback(); } + } // namespace noisepage::execution::compiler diff --git a/src/execution/compiler/pipeline.cpp b/src/execution/compiler/pipeline.cpp index 16b707515c..cf7dd65fda 100644 --- a/src/execution/compiler/pipeline.cpp +++ b/src/execution/compiler/pipeline.cpp @@ -26,9 +26,8 @@ Pipeline::Pipeline(CompilationContext *ctx) : id_(ctx->RegisterPipeline(this)), compilation_context_(ctx), codegen_(compilation_context_->GetCodeGen()), - state_var_(codegen_->MakeIdentifier("pipelineState")), state_(codegen_->MakeIdentifier(fmt::format("{}_Pipeline{}_State", ctx->GetFunctionPrefix(), id_)), - [this](CodeGen *codegen) { return codegen_->MakeExpr(state_var_); }), + [this](CodeGen *codegen) { return codegen_->MakeExpr(GetPipelineStateName()); }), driver_(nullptr), parallelism_(Parallelism::Parallel), check_parallelism_(true), @@ -76,26 +75,6 @@ StateDescriptor::Entry Pipeline::DeclarePipelineStateEntry(const std::string &na return state.DeclareStateEntry(codegen_, name, type_repr); } -std::string Pipeline::CreatePipelineFunctionName(const std::string &func_name) const { - auto result = fmt::format("{}_Pipeline{}", compilation_context_->GetFunctionPrefix(), id_); - if (!func_name.empty()) { - result += "_" + func_name; - } - return result; -} - -ast::Identifier Pipeline::GetSetupPipelineStateFunctionName() const { - return codegen_->MakeIdentifier(CreatePipelineFunctionName("InitPipelineState")); -} - -ast::Identifier Pipeline::GetTearDownPipelineStateFunctionName() const { - return codegen_->MakeIdentifier(CreatePipelineFunctionName("TearDownPipelineState")); -} - -ast::Identifier Pipeline::GetWorkFunctionName() const { - return codegen_->MakeIdentifier(CreatePipelineFunctionName(IsParallel() ? "ParallelWork" : "SerialWork")); -} - void Pipeline::InjectStartResourceTracker(FunctionBuilder *builder, bool is_hook) const { if (compilation_context_->IsPipelineMetricsEnabled()) { auto *exec_ctx = compilation_context_->GetExecutionContextPtrFromQueryState(); @@ -143,16 +122,6 @@ void Pipeline::InjectEndResourceTracker(FunctionBuilder *builder, bool is_hook) } } -util::RegionVector Pipeline::PipelineParams() const { - // The main query parameters. - util::RegionVector query_params = compilation_context_->QueryParams(); - // Tag on the pipeline state. - auto &state = GetPipelineStateDescriptor(); - ast::Expr *pipeline_state = codegen_->PointerType(codegen_->MakeExpr(state.GetTypeName())); - query_params.push_back(codegen_->MakeField(state_var_, pipeline_state)); - return query_params; -} - void Pipeline::LinkSourcePipeline(Pipeline *dependency) { NOISEPAGE_ASSERT(dependency != nullptr, "Source cannot be null"); // Add pipeline `dependency` as a nested pipeline @@ -173,8 +142,8 @@ void Pipeline::LinkNestedPipeline(Pipeline *pipeline, const OperatorTranslator * if (std::find(dependencies_.begin(), dependencies_.end(), pipeline) == dependencies_.end()) { pipeline->nested_pipelines_.push_back(this); } - if (!pipeline->nested_) { - pipeline->nested_ = true; + if (!pipeline->IsNestedPipeline()) { + pipeline->MarkNested(); // add to pipeline params std::size_t i = 0; for (auto &col : op->GetPlan().GetOutputSchema()->GetColumns()) { @@ -254,12 +223,52 @@ void Pipeline::Prepare(const exec::ExecutionSettings &exec_settings) { prepared_ = true; } -ast::FunctionDecl *Pipeline::GenerateSetupPipelineStateFunction() const { - auto name = GetSetupPipelineStateFunctionName(); - FunctionBuilder builder(codegen_, name, PipelineParams(), codegen_->Nil()); +/* ---------------------------------------------------------------------------- + Pipeline Generation: Top-Level +----------------------------------------------------------------------------- */ + +void Pipeline::GeneratePipeline(ExecutableQueryFragmentBuilder *builder) const { + NOISEPAGE_ASSERT(!(IsNestedPipeline() && HasOutputCallback()), + "Single pipeline cannot both be nested and have an output callback"); + + // Declare the pipeline state. + builder->DeclareStruct(state_.GetType()); + // Generate pipeline state initialization and tear-down functions. + builder->DeclareFunction(GenerateInitPipelineStateFunction()); + builder->DeclareFunction(GenerateTearDownPipelineStateFunction()); + + auto teardown = GenerateTearDownPipelineFunction(); + + // Declare top-level functions + builder->DeclareFunction(GenerateInitPipelineFunction()); + builder->DeclareFunction(GeneratePipelineWorkFunction()); + builder->DeclareFunction(GenerateRunPipelineFunction()); + builder->DeclareFunction(teardown); + + if (HasOutputCallback()) { + auto run_all = GeneratePipelineRunAllOutputCallbackFunction(); + builder->DeclareFunction(run_all); + builder->RegisterStep(run_all); + } else if (!IsNestedPipeline()) { + // Register the main init, run, tear-down functions as steps, in that order. + builder->RegisterStep(GenerateInitPipelineFunction()); + builder->RegisterStep(GenerateRunPipelineFunction()); + builder->RegisterStep(teardown); + } + + builder->AddTeardownFn(teardown); +} + +/* ---------------------------------------------------------------------------- + Pipeline Generation: State Setup + Teardown +----------------------------------------------------------------------------- */ + +ast::FunctionDecl *Pipeline::GenerateInitPipelineStateFunction() const { + auto name = GetInitPipelineStateFunctionName(); + FunctionBuilder builder{codegen_, name, GetInitPipelineStateParams(), codegen_->Nil()}; { // Request new scope for the function. - CodeGen::CodeScope code_scope(codegen_); + CodeGen::CodeScope code_scope{codegen_}; for (auto *op : steps_) { op->InitializePipelineState(*this, &builder); } @@ -269,10 +278,10 @@ ast::FunctionDecl *Pipeline::GenerateSetupPipelineStateFunction() const { ast::FunctionDecl *Pipeline::GenerateTearDownPipelineStateFunction() const { auto name = GetTearDownPipelineStateFunctionName(); - FunctionBuilder builder(codegen_, name, PipelineParams(), codegen_->Nil()); + FunctionBuilder builder{codegen_, name, GetTeardownPipelineStateParams(), codegen_->Nil()}; { // Request new scope for the function. - CodeGen::CodeScope code_scope(codegen_); + CodeGen::CodeScope code_scope{codegen_}; for (auto *op : steps_) { op->TearDownPipelineState(*this, &builder); } @@ -287,86 +296,161 @@ ast::FunctionDecl *Pipeline::GenerateTearDownPipelineStateFunction() const { return builder.Finish(); } -ast::FunctionDecl *Pipeline::GeneratePipelineWrapperFunction(ast::LambdaExpr *output_callback) const { - auto name = codegen_->MakeIdentifier(CreatePipelineFunctionName("RunAll")); - auto params = compilation_context_->QueryParams(); - auto run_params = params; - if (output_callback != nullptr) { - run_params.push_back(codegen_->MakeField( - output_callback->GetName(), codegen_->LambdaType(output_callback->GetFunctionLiteralExpr()->TypeRepr()))); +/* ---------------------------------------------------------------------------- + Pipeline Generation: RunAll +----------------------------------------------------------------------------- */ + +ast::FunctionDecl *Pipeline::GeneratePipelineRunAllNestedFunction() const { + NOISEPAGE_ASSERT(IsNestedPipeline(), "Should only generate a RunAllNested function in nested pipeline"); + + const ast::Identifier name = GetRunAllNestedPipelineFunctionName(); + FunctionBuilder builder{codegen_, name, GetRunAllNestedPipelineParams(), codegen_->Nil()}; + { + CodeGen::CodeScope code_scope{codegen_}; + + ast::Identifier pipeline_state = codegen_->MakeFreshIdentifier("pipeline_state"); + builder.Append(codegen_->DeclareVarNoInit(pipeline_state, state_.GetType()->TypeRepr())); + auto pipeline_state_ptr = codegen_->AddressOf(pipeline_state); + + NOISEPAGE_ASSERT(builder.GetParameterCount() == 1, "Unexpected parameter count for RunAllNested function"); + auto *query_state_param = builder.GetParameterByPosition(0); + + builder.Append(codegen_->Call(GetInitPipelineFunctionName(), {query_state_param, pipeline_state_ptr})); + builder.Append(codegen_->Call(GetRunPipelineFunctionName(), {query_state_param, pipeline_state_ptr})); + builder.Append(codegen_->Call(GetTeardownPipelineFunctionName(), {query_state_param, pipeline_state_ptr})); } - FunctionBuilder builder(codegen_, name, std::move(run_params), codegen_->Nil()); + + return builder.Finish(); +} + +ast::FunctionDecl *Pipeline::GeneratePipelineRunAllOutputCallbackFunction() const { + NOISEPAGE_ASSERT(HasOutputCallback(), + "Should only generate RunAllOutputCallback function for pipeline with output callback"); + + const ast::Identifier name = GetRunAllOutputCallbackPipelineFunctionName(); + FunctionBuilder builder{codegen_, name, GetRunAllOutputCallbackPipelineParams(), codegen_->Nil()}; + { - CodeGen::CodeScope code_scope(codegen_); - ast::Identifier p_state = codegen_->MakeFreshIdentifier("pipeline_state"); - builder.Append(codegen_->DeclareVarNoInit(p_state, state_.GetType()->TypeRepr())); - auto query_state_param = builder.GetParameterByPosition(0); - auto p_state_ptr = codegen_->AddressOf(p_state); - auto lambda_call = builder.GetParameterByPosition(1); - builder.Append(codegen_->Call(GetInitPipelineFunctionName(), {query_state_param, p_state_ptr})); - builder.Append(codegen_->Call(GetRunPipelineFunctionName(), {query_state_param, p_state_ptr, lambda_call})); - builder.Append(codegen_->Call(GetTeardownPipelineFunctionName(), {query_state_param, p_state_ptr})); + CodeGen::CodeScope code_scope{codegen_}; + + ast::Identifier pipeline_state_id = codegen_->MakeFreshIdentifier("pipelineState"); + builder.Append(codegen_->DeclareVarNoInit(pipeline_state_id, state_.GetType()->TypeRepr())); + + NOISEPAGE_ASSERT(builder.GetParameterCount() == 2, "Unexpected parameter count for RunAllOutputCallback function"); + auto *query_state = builder.GetParameterByPosition(0); + auto *pipeline_state = codegen_->AddressOf(pipeline_state_id); + auto *callback = builder.GetParameterByPosition(1); + + builder.Append(codegen_->Call(GetInitPipelineFunctionName(), {query_state, pipeline_state})); + builder.Append(codegen_->Call(GetRunPipelineFunctionName(), {query_state, pipeline_state, callback})); + builder.Append(codegen_->Call(GetTeardownPipelineFunctionName(), {query_state, pipeline_state})); } return builder.Finish(); } -ast::FunctionDecl *Pipeline::GenerateInitPipelineFunction(ast::LambdaExpr *output_callback) const { +/* ---------------------------------------------------------------------------- + Pipeline Generation: Steps +----------------------------------------------------------------------------- */ + +ast::FunctionDecl *Pipeline::GenerateInitPipelineFunction() const { auto query_state = compilation_context_->GetQueryState(); - auto name = codegen_->MakeIdentifier(CreatePipelineFunctionName("Init")); - auto params = compilation_context_->QueryParams(); - ast::FieldDecl *p_state_ptr = nullptr; - auto &state = GetPipelineStateDescriptor(); - uint32_t p_state_ind = 0; - if (nested_ || output_callback != nullptr) { - p_state_ptr = codegen_->MakeField(codegen_->MakeFreshIdentifier("pipeline_state"), - codegen_->PointerType(codegen_->MakeExpr(state.GetTypeName()))); - params.push_back(p_state_ptr); - p_state_ind = params.size() - 1; - } - FunctionBuilder builder(codegen_, name, std::move(params), codegen_->Nil()); - { - CodeGen::CodeScope code_scope(codegen_); - // var tls = @execCtxGetTLS(exec_ctx) - ast::Expr *exec_ctx = compilation_context_->GetExecutionContextPtrFromQueryState(); - ast::Identifier tls = codegen_->MakeFreshIdentifier("threadStateContainer"); - builder.Append(codegen_->DeclareVarWithInit(tls, codegen_->ExecCtxGetTLS(exec_ctx))); + const ast::Identifier name = GetInitPipelineFunctionName(); - // @tlsReset(tls, @sizeOf(ThreadState), init, tearDown, queryState) + auto parameters = GetInitPipelineParams(); + FunctionBuilder builder{codegen_, name, std::move(parameters), codegen_->Nil()}; + { + CodeGen::CodeScope code_scope{codegen_}; ast::Expr *state_ptr = query_state->GetStatePointer(codegen_); - if (!nested_ && output_callback == nullptr) { + + if (IsNestedPipeline() || HasOutputCallback()) { + // No TLS reset in nested pipelines + // NOTE: Assumes the pipeline state is always the final parameter + const auto pipeline_state_index = builder.GetParameterCount() - 1; + auto *pipeline_state = builder.GetParameterByPosition(pipeline_state_index); + builder.Append(codegen_->Call(GetInitPipelineStateFunctionName(), {state_ptr, pipeline_state})); + } else { + auto &state = GetPipelineStateDescriptor(); + ast::Expr *exec_ctx = compilation_context_->GetExecutionContextPtrFromQueryState(); + ast::Identifier tls = codegen_->MakeFreshIdentifier("threadStateContainer"); + // var tls = @execCtxGetTLS(exec_ctx) + builder.Append(codegen_->DeclareVarWithInit(tls, codegen_->ExecCtxGetTLS(exec_ctx))); + // @tlsReset(tls, @sizeOf(ThreadState), init, tearDown, queryState) builder.Append(codegen_->TLSReset(codegen_->MakeExpr(tls), state.GetTypeName(), - GetSetupPipelineStateFunctionName(), GetTearDownPipelineStateFunctionName(), + GetInitPipelineStateFunctionName(), GetTearDownPipelineStateFunctionName(), state_ptr)); - } else { - // no TLS reset if pipeline is nested - auto pipeline_state = builder.GetParameterByPosition(p_state_ind); - builder.Append(codegen_->Call(GetSetupPipelineStateFunctionName(), {state_ptr, pipeline_state})); } } return builder.Finish(); } -ast::FunctionDecl *Pipeline::GeneratePipelineWorkFunction(ast::LambdaExpr *output_callback) const { - auto params = PipelineParams(); - for (auto field : extra_pipeline_params_) { - params.push_back(field); - } +ast::FunctionDecl *Pipeline::GenerateRunPipelineFunction() const { + bool started_tracker = false; + const ast::Identifier name = GetRunPipelineFunctionName(); + FunctionBuilder builder{codegen_, name, GetRunPipelineParams(), codegen_->Nil()}; + { + // Begin a new code scope for fresh variables. + CodeGen::CodeScope code_scope{codegen_}; - if (IsParallel()) { - auto additional_params = driver_->GetWorkerParams(); - params.insert(params.end(), additional_params.begin(), additional_params.end()); - } + // TODO(abalakum): This shouldn't actually be dependent on order and the loop can be simplified + // after issue #1154 is fixed + // Let the operators perform some initialization work in this pipeline. + for (auto iter = Begin(), end = End(); iter != end; ++iter) { + (*iter)->BeginPipelineWork(*this, &builder); + } + + // TODO(Kyle): I think this is wrong for nested pipelines / output callbacks + // var pipelineState = @tlsGetCurrentThreadState(...) + auto exec_ctx = compilation_context_->GetExecutionContextPtrFromQueryState(); + auto tls = codegen_->ExecCtxGetTLS(exec_ctx); + auto state_type = GetPipelineStateDescriptor().GetTypeName(); + auto state = codegen_->TLSAccessCurrentThreadState(tls, state_type); + builder.Append(codegen_->DeclareVarWithInit(GetPipelineStateName(), state)); + + // Launch pipeline work. + if (IsParallel()) { + driver_->LaunchWork(&builder, GetPipelineWorkFunctionName()); + } else { + // Serial pipeline + InjectStartResourceTracker(&builder, false); + started_tracker = true; + + std::vector params{codegen_->MakeExpr(GetQueryStateName()), + codegen_->MakeExpr(GetPipelineStateName())}; + if (IsNestedPipeline()) { + const auto run_params = builder.GetParameters(); + auto begin = run_params.cbegin(); + std::advance(begin, params.size()); + params.insert(params.end(), begin, run_params.cend()); + } + + if (HasOutputCallback()) { + params.push_back(codegen_->MakeExpr(GetOutputCallback()->GetName())); + } + + builder.Append(codegen_->Call(GetPipelineWorkFunctionName(), std::move(params))); + } + + // TODO(abalakum): This shouldn't actually be dependent on order and the loop can be simplified + // after issue #1154 is fixed + // Let the operators perform some completion work in this pipeline. + for (auto iter = Begin(), end = End(); iter != end; ++iter) { + (*iter)->FinishPipelineWork(*this, &builder); + } - if (output_callback != nullptr) { - params.push_back(codegen_->MakeField(output_callback->GetName(), - codegen_->LambdaType(output_callback->GetFunctionLiteralExpr()->TypeRepr()))); + if (started_tracker) { + InjectEndResourceTracker(&builder, false); + } } - FunctionBuilder builder(codegen_, GetWorkFunctionName(), std::move(params), codegen_->Nil()); + return builder.Finish(); +} + +ast::FunctionDecl *Pipeline::GeneratePipelineWorkFunction() const { + FunctionBuilder builder{codegen_, GetPipelineWorkFunctionName(), GetPipelineWorkParams(), codegen_->Nil()}; { // Begin a new code scope for fresh variables. - CodeGen::CodeScope code_scope(codegen_); + CodeGen::CodeScope code_scope{codegen_}; if (IsParallel()) { for (auto *op : steps_) { op->BeginParallelPipelineWork(*this, &builder); @@ -390,8 +474,173 @@ ast::FunctionDecl *Pipeline::GeneratePipelineWorkFunction(ast::LambdaExpr *outpu return builder.Finish(); } +ast::FunctionDecl *Pipeline::GenerateTearDownPipelineFunction() const { + const ast::Identifier name = GetTeardownPipelineFunctionName(); + FunctionBuilder builder{codegen_, name, GetTeardownPipelineParams(), codegen_->Nil()}; + { + // Begin a new code scope for fresh variables. + CodeGen::CodeScope code_scope{codegen_}; + if (IsNestedPipeline() || HasOutputCallback()) { + // NOTE: Assumes pipeline state is always final parameter to call + const auto pipeline_state_index = builder.GetParameterCount() - 1; + auto query_state = compilation_context_->GetQueryState()->GetStatePointer(codegen_); + auto pipeline_state = builder.GetParameterByPosition(pipeline_state_index); + auto call = codegen_->Call(GetTearDownPipelineStateFunctionName(), {query_state, pipeline_state}); + builder.Append(codegen_->MakeStmt(call)); + } else { + // Tear down thread local state if parallel pipeline. + ast::Expr *exec_ctx = compilation_context_->GetExecutionContextPtrFromQueryState(); + builder.Append(codegen_->TLSClear(codegen_->ExecCtxGetTLS(exec_ctx))); + auto call = codegen_->CallBuiltin(ast::Builtin::EnsureTrackersStopped, {exec_ctx}); + builder.Append(codegen_->MakeStmt(call)); + } + } + return builder.Finish(); +} + +/* ---------------------------------------------------------------------------- + Pipeline Function Parameter Definition +----------------------------------------------------------------------------- */ + +util::RegionVector Pipeline::GetInitPipelineStateParams() const { + // fun QueryX_PipelineY_InitPipelineState(queryState, pipelineState) + return PipelineParams(); +} + +util::RegionVector Pipeline::GetTeardownPipelineStateParams() const { + // fun QueryX_PipelineY_TeardownPipelineState(queryState, pipelineState) + return PipelineParams(); +} + +util::RegionVector Pipeline::GetRunAllNestedPipelineParams() const { + NOISEPAGE_ASSERT(IsNestedPipeline(), "RunAllNested should only be generated for nested pipelines"); + // fun QueryX_PipelineY_RunAll(queryState) + return QueryParams(); +} + +util::RegionVector Pipeline::GetRunAllOutputCallbackPipelineParams() const { + NOISEPAGE_ASSERT(HasOutputCallback(), + "RunAllOutputCallback should only be generated for pipeline with output callback"); + // fun QueryX_PipelineY_RunAll(queryState, udfLambda) + util::RegionVector params{QueryParams()}; + params.push_back(codegen_->MakeField( + GetOutputCallback()->GetName(), codegen_->LambdaType(GetOutputCallback()->GetFunctionLiteralExpr()->TypeRepr()))); + return params; +} + +util::RegionVector Pipeline::GetInitPipelineParams() const { + /** + * Common Case: + * fun QueryX_PipelineY_Init(queryState) + * + * Nested Pipeline: + * fun QueryX_PipelineY_InitPipeline(queryState, pipelineState) + * + * Output Callback: + * fun QueryX_PipelineY_InitPipeline(queryState, pipelineState) + */ + + util::RegionVector params{QueryParams()}; + if (IsNestedPipeline() || HasOutputCallback()) { + const auto &state = GetPipelineStateDescriptor(); + ast::FieldDecl *pipeline_state_ptr = + codegen_->MakeField(codegen_->MakeFreshIdentifier("pipeline_state"), + codegen_->PointerType(codegen_->MakeExpr(state.GetTypeName()))); + params.push_back(pipeline_state_ptr); + } + return params; +} + +util::RegionVector Pipeline::GetRunPipelineParams() const { + /** + * Common Case: + * fun QueryX_PipelineY_Run(queryState) + * + * Nested Pipeline: + * fun QueryX_PipelineY_Run(queryState, pipelineState) + * + * Output Callback: + * fun QueryX_PipelineY_Run(queryState, outputCallback) + */ + + util::RegionVector params{QueryParams()}; + + if (IsNestedPipeline() || HasOutputCallback()) { + params.push_back(codegen_->MakeField(GetPipelineStateName(), codegen_->PointerType(state_.GetTypeName()))); + } + + for (auto *field : extra_pipeline_params_) { + params.push_back(field); + } + + if (HasOutputCallback()) { + params.push_back( + codegen_->MakeField(GetOutputCallback()->GetName(), + codegen_->LambdaType(GetOutputCallback()->GetFunctionLiteralExpr()->TypeRepr()))); + } + + return params; +} + +util::RegionVector Pipeline::GetPipelineWorkParams() const { + util::RegionVector params{PipelineParams()}; + for (auto *field : extra_pipeline_params_) { + params.push_back(field); + } + + if (IsParallel()) { + auto additional_params = driver_->GetWorkerParams(); + params.insert(params.end(), additional_params.cbegin(), additional_params.cend()); + } + + if (HasOutputCallback()) { + params.push_back( + codegen_->MakeField(GetOutputCallback()->GetName(), + codegen_->LambdaType(GetOutputCallback()->GetFunctionLiteralExpr()->TypeRepr()))); + } + + return params; +} + +util::RegionVector Pipeline::GetTeardownPipelineParams() const { + /** + * Common Case: + * QueryX_PipelineY_Teardown(queryState) + * + * Nested Pipeline: + * QueryX_PipelineY_Teardown(queryState, pipelineState) + * + * Output Callback: + * QueryX_PipelineY_Teardown(queryState, pipelineState) + */ + util::RegionVector params{QueryParams()}; + if (IsNestedPipeline() || HasOutputCallback()) { + ast::FieldDecl *pipeline_state = + codegen_->MakeField(codegen_->MakeFreshIdentifier("pipeline_state"), + codegen_->PointerType(codegen_->MakeExpr(GetPipelineStateDescriptor().GetTypeName()))); + params.push_back(pipeline_state); + } + return params; +} + +util::RegionVector Pipeline::QueryParams() const { return compilation_context_->QueryParams(); } + +util::RegionVector Pipeline::PipelineParams() const { + // The main query parameters + util::RegionVector pipeline_params{QueryParams()}; + // Tag on the pipeline state + auto &state = GetPipelineStateDescriptor(); + ast::Expr *pipeline_state = codegen_->PointerType(codegen_->MakeExpr(state.GetTypeName())); + pipeline_params.push_back(codegen_->MakeField(GetPipelineStateName(), pipeline_state)); + return pipeline_params; +} + +/* ---------------------------------------------------------------------------- + Pipeline Calls +----------------------------------------------------------------------------- */ + std::vector Pipeline::CallSingleRunPipelineFunction() const { - NOISEPAGE_ASSERT(!nested_, "can't call a nested pipeline like this"); + NOISEPAGE_ASSERT(!IsNestedPipeline(), "Can't call a nested pipeline like this"); return { codegen_->Call(GetInitPipelineFunctionName(), {compilation_context_->GetQueryState()->GetStatePointer(codegen_)}), codegen_->Call(GetRunPipelineFunctionName(), {compilation_context_->GetQueryState()->GetStatePointer(codegen_)}), @@ -405,10 +654,10 @@ void Pipeline::CallNestedRunPipelineFunction(WorkContext *ctx, const OperatorTra auto p_state = codegen_->MakeFreshIdentifier("nested_state"); auto p_state_ptr = codegen_->AddressOf(p_state); - std::vector params_vec = {compilation_context_->GetQueryState()->GetStatePointer(codegen_)}; + std::vector params_vec{compilation_context_->GetQueryState()->GetStatePointer(codegen_)}; params_vec.push_back(p_state_ptr); - for (size_t i = 0; i < op->GetPlan().GetOutputSchema()->GetColumns().size(); i++) { + for (std::size_t i = 0; i < op->GetPlan().GetOutputSchema()->GetColumns().size(); i++) { params_vec.push_back(codegen_->AddressOf(op->GetOutput(ctx, i))); } @@ -425,7 +674,7 @@ std::vector Pipeline::CallRunPipelineFunction() const { std::vector pipelines; CollectDependencies(&pipelines); for (auto pipeline : pipelines) { - if (!pipeline->nested_ || (pipeline == this)) { + if (!pipeline->IsNestedPipeline() || (pipeline == this)) { for (auto call : CallSingleRunPipelineFunction()) { calls.push_back(call); } @@ -434,6 +683,30 @@ std::vector Pipeline::CallRunPipelineFunction() const { return calls; } +/* ---------------------------------------------------------------------------- + Variable + Function Identifiers +----------------------------------------------------------------------------- */ + +ast::Identifier Pipeline::GetQueryStateName() const { return compilation_context_->GetQueryStateName(); } + +ast::Identifier Pipeline::GetPipelineStateName() const { return codegen_->MakeIdentifier("pipelineState"); } + +ast::Identifier Pipeline::GetInitPipelineStateFunctionName() const { + return codegen_->MakeIdentifier(CreatePipelineFunctionName("InitPipelineState")); +} + +ast::Identifier Pipeline::GetTearDownPipelineStateFunctionName() const { + return codegen_->MakeIdentifier(CreatePipelineFunctionName("TearDownPipelineState")); +} + +ast::Identifier Pipeline::GetRunAllNestedPipelineFunctionName() const { + return codegen_->MakeIdentifier(CreatePipelineFunctionName("RunAllNested")); +} + +ast::Identifier Pipeline::GetRunAllOutputCallbackPipelineFunctionName() const { + return codegen_->MakeIdentifier(CreatePipelineFunctionName("RunAllOutputCallback")); +} + ast::Identifier Pipeline::GetInitPipelineFunctionName() const { return codegen_->MakeIdentifier(CreatePipelineFunctionName("Init")); } @@ -446,146 +719,26 @@ ast::Identifier Pipeline::GetRunPipelineFunctionName() const { return codegen_->MakeIdentifier(CreatePipelineFunctionName("Run")); } -ast::Expr *Pipeline::GetNestedInputArg(const std::size_t index) const { - NOISEPAGE_ASSERT(nested_, "Requested nested input argument on non-nested pipeline"); - NOISEPAGE_ASSERT(index < extra_pipeline_params_.size(), "Requested nested index argument out of range"); - return codegen_->UnaryOp(parsing::Token::Type::STAR, codegen_->MakeExpr(extra_pipeline_params_[index]->Name())); -} - -ast::FunctionDecl *Pipeline::GenerateRunPipelineFunction(query_id_t query_id, ast::LambdaExpr *output_callback) const { - bool started_tracker = false; - auto name = codegen_->MakeIdentifier(CreatePipelineFunctionName("Run")); - auto params = compilation_context_->QueryParams(); - if (nested_ || output_callback != nullptr) { - params.push_back(codegen_->MakeField(state_var_, codegen_->PointerType(state_.GetTypeName()))); - } - for (auto field : extra_pipeline_params_) { - params.push_back(field); - } - if (output_callback != nullptr) { - params.push_back(codegen_->MakeField(output_callback->GetName(), - codegen_->LambdaType(output_callback->GetFunctionLiteralExpr()->TypeRepr()))); - } - FunctionBuilder builder(codegen_, name, std::move(params), codegen_->Nil()); - { - // Begin a new code scope for fresh variables. - CodeGen::CodeScope code_scope(codegen_); - - // TODO(abalakum): This shouldn't actually be dependent on order and the loop can be simplified - // after issue #1154 is fixed - // Let the operators perform some initialization work in this pipeline. - for (auto iter = Begin(), end = End(); iter != end; ++iter) { - (*iter)->BeginPipelineWork(*this, &builder); - } - - // var pipelineState = @tlsGetCurrentThreadState(...) - auto exec_ctx = compilation_context_->GetExecutionContextPtrFromQueryState(); - auto tls = codegen_->ExecCtxGetTLS(exec_ctx); - auto state_type = GetPipelineStateDescriptor().GetTypeName(); - auto state = codegen_->TLSAccessCurrentThreadState(tls, state_type); - builder.Append(codegen_->DeclareVarWithInit(state_var_, state)); - - // Launch pipeline work. - if (IsParallel()) { - driver_->LaunchWork(&builder, GetWorkFunctionName()); - } else { - // SerialWork(queryState, pipelineState) - // if(!nested_) { - InjectStartResourceTracker(&builder, false); - started_tracker = true; - - std::vector args{builder.GetParameterByPosition(0), codegen_->MakeExpr(state_var_)}; - if (nested_) { - size_t i = args.size(); - ast::Expr *arg = builder.GetParameterByPosition(i++); - while (arg != nullptr) { - args.push_back(arg); - arg = builder.GetParameterByPosition(i++); - } - } - if (output_callback != nullptr && !nested_) { - args.push_back(codegen_->MakeExpr(output_callback->GetName())); - } - builder.Append(codegen_->Call(GetWorkFunctionName(), args)); - } - - // TODO(abalakum): This shouldn't actually be dependent on order and the loop can be simplified - // after issue #1154 is fixed - // Let the operators perform some completion work in this pipeline. - for (auto iter = Begin(), end = End(); iter != end; ++iter) { - (*iter)->FinishPipelineWork(*this, &builder); - } - - if (started_tracker) { - InjectEndResourceTracker(&builder, false); - } - } - - return builder.Finish(); +ast::Identifier Pipeline::GetPipelineWorkFunctionName() const { + return codegen_->MakeIdentifier(CreatePipelineFunctionName(IsParallel() ? "ParallelWork" : "SerialWork")); } -ast::FunctionDecl *Pipeline::GenerateTearDownPipelineFunction(ast::LambdaExpr *output_callback) const { - auto name = codegen_->MakeIdentifier(CreatePipelineFunctionName("TearDown")); - auto params = compilation_context_->QueryParams(); - ast::FieldDecl *p_state_ptr = nullptr; - auto &state = GetPipelineStateDescriptor(); - uint32_t p_state_index = 0; - if (nested_ || output_callback != nullptr) { - p_state_ptr = codegen_->MakeField(codegen_->MakeFreshIdentifier("pipeline_state"), - codegen_->PointerType(codegen_->MakeExpr(state.GetTypeName()))); - params.push_back(p_state_ptr); - p_state_index = params.size() - 1; - } - - FunctionBuilder builder(codegen_, name, std::move(params), codegen_->Nil()); - { - // Begin a new code scope for fresh variables. - CodeGen::CodeScope code_scope(codegen_); - if (!nested_ && output_callback == nullptr) { - // Tear down thread local state if parallel pipeline. - ast::Expr *exec_ctx = compilation_context_->GetExecutionContextPtrFromQueryState(); - builder.Append(codegen_->TLSClear(codegen_->ExecCtxGetTLS(exec_ctx))); - - auto call = codegen_->CallBuiltin(ast::Builtin::EnsureTrackersStopped, {exec_ctx}); - builder.Append(codegen_->MakeStmt(call)); - } else { - auto query_state = compilation_context_->GetQueryState(); - auto state_ptr = query_state->GetStatePointer(codegen_); - - auto pipeline_state = builder.GetParameterByPosition(p_state_index); - auto call = codegen_->Call(GetTearDownPipelineStateFunctionName(), {state_ptr, pipeline_state}); - builder.Append(codegen_->MakeStmt(call)); - } +std::string Pipeline::CreatePipelineFunctionName(const std::string &func_name) const { + auto result = fmt::format("{}_Pipeline{}", compilation_context_->GetFunctionPrefix(), id_); + if (!func_name.empty()) { + result += "_" + func_name; } - return builder.Finish(); + return result; } -void Pipeline::GeneratePipeline(ExecutableQueryFragmentBuilder *builder, query_id_t query_id, - ast::LambdaExpr *output_callback) const { - // Declare the pipeline state. - builder->DeclareStruct(state_.GetType()); - // Generate pipeline state initialization and tear-down functions. - builder->DeclareFunction(GenerateSetupPipelineStateFunction()); - builder->DeclareFunction(GenerateTearDownPipelineStateFunction()); - - // Generate main pipeline logic. - builder->DeclareFunction(GeneratePipelineWorkFunction(output_callback)); - builder->DeclareFunction(GenerateRunPipelineFunction(query_id, output_callback)); - builder->DeclareFunction(GenerateInitPipelineFunction(output_callback)); - auto teardown = GenerateTearDownPipelineFunction(output_callback); - builder->DeclareFunction(teardown); +/* ---------------------------------------------------------------------------- + Additional Helpers +----------------------------------------------------------------------------- */ - // Register the main init, run, tear-down functions as steps, in that order. - if (output_callback != nullptr) { - auto fn = GeneratePipelineWrapperFunction(output_callback); - builder->DeclareFunction(fn); - builder->RegisterStep(fn); - } else if (!nested_) { - builder->RegisterStep(GenerateInitPipelineFunction(output_callback)); - builder->RegisterStep(GenerateRunPipelineFunction(query_id, output_callback)); - builder->RegisterStep(teardown); - } - builder->AddTeardownFn(teardown); +ast::Expr *Pipeline::GetNestedInputArg(const std::size_t index) const { + NOISEPAGE_ASSERT(IsNestedPipeline(), "Requested nested input argument on non-nested pipeline"); + NOISEPAGE_ASSERT(index < extra_pipeline_params_.size(), "Requested nested index argument out of range"); + return codegen_->UnaryOp(parsing::Token::Type::STAR, codegen_->MakeExpr(extra_pipeline_params_[index]->Name())); } } // namespace noisepage::execution::compiler diff --git a/src/execution/compiler/udf/udf_codegen.cpp b/src/execution/compiler/udf/udf_codegen.cpp index ad2a4f7d45..e45539d408 100644 --- a/src/execution/compiler/udf/udf_codegen.cpp +++ b/src/execution/compiler/udf/udf_codegen.cpp @@ -13,6 +13,7 @@ #include "execution/compiler/if.h" #include "execution/compiler/loop.h" #include "execution/exec/execution_settings.h" +#include "execution/vm/bytecode_function_info.h" #include "optimizer/cost_model/trivial_cost_model.h" #include "optimizer/statistics/stats_storage.h" #include "parser/expression/constant_value_expression.h" @@ -56,10 +57,10 @@ void UdfCodegen::GenerateUDF(ast::udf::AbstractAST *ast) { ast->Accept(this); } catalog::type_oid_t UdfCodegen::GetCatalogTypeOidFromSQLType(execution::ast::BuiltinType::Kind type) { switch (type) { - case execution::ast::BuiltinType::Kind::Integer: { + case ast::BuiltinType::Kind::Integer: { return accessor_->GetTypeOidFromTypeId(type::TypeId::INTEGER); } - case execution::ast::BuiltinType::Kind::Boolean: { + case ast::BuiltinType::Kind::Boolean: { return accessor_->GetTypeOidFromTypeId(type::TypeId::BOOLEAN); } default: @@ -68,10 +69,10 @@ catalog::type_oid_t UdfCodegen::GetCatalogTypeOidFromSQLType(execution::ast::Bui } } -execution::ast::File *UdfCodegen::Finish() { - auto fn = fb_->Finish(); - execution::util::RegionVector decls{{fn}, codegen_->GetAstContext()->GetRegion()}; - decls.insert(decls.begin(), aux_decls_.begin(), aux_decls_.end()); +ast::File *UdfCodegen::Finish() { + ast::FunctionDecl *fn = fb_->Finish(); + util::RegionVector decls{{fn}, codegen_->GetAstContext()->GetRegion()}; + decls.insert(decls.begin(), aux_decls_.cbegin(), aux_decls_.cend()); auto file = codegen_->GetAstContext()->GetNodeFactory()->NewFile({0, 0}, std::move(decls)); return file; } @@ -397,9 +398,9 @@ void UdfCodegen::Visit(ast::udf::ForSStmtAST *ast) { // Declare the closure and the query state in the current function auto query_state = codegen_->MakeFreshIdentifier("query_state"); + fb_->Append(codegen_->DeclareVarNoInit(query_state, codegen_->MakeExpr(exec_query->GetQueryStateType()->Name()))); fb_->Append(codegen_->DeclareVar( lambda_identifier, codegen_->LambdaType(lambda_expr->GetFunctionLiteralExpr()->TypeRepr()), lambda_expr)); - fb_->Append(codegen_->DeclareVarNoInit(query_state, codegen_->MakeExpr(exec_query->GetQueryStateType()->Name()))); // Set its execution context to whatever execution context was passed in here fb_->Append(codegen_->CallBuiltin(execution::ast::Builtin::StartNewParams, {exec_ctx})); @@ -409,16 +410,9 @@ void UdfCodegen::Visit(ast::udf::ForSStmtAST *ast) { fb_->Append(codegen_->Assign( codegen_->AccessStructMember(codegen_->MakeExpr(query_state), codegen_->MakeIdentifier("execCtx")), exec_ctx)); - auto function_names = exec_query->GetFunctionNames(); - for (const auto &function_name : function_names) { - if (IsRunFunction(function_name)) { - fb_->Append(codegen_->Call(codegen_->GetAstContext()->GetIdentifier(function_name), - {codegen_->AddressOf(query_state), codegen_->MakeExpr(lambda_identifier)})); - } else { - fb_->Append( - codegen_->Call(codegen_->GetAstContext()->GetIdentifier(function_name), {codegen_->AddressOf(query_state)})); - } - } + // Manually append calls to each function from the compiled + // executable query (implementing the closure) to the builder + CodegenTopLevelCalls(exec_query.get(), query_state, lambda_identifier); fb_->Append(codegen_->CallBuiltin(execution::ast::Builtin::FinishNewParams, {exec_ctx})); } @@ -575,9 +569,9 @@ void UdfCodegen::Visit(ast::udf::SQLStmtAST *ast) { // Declare the closure and the query state in the current function auto query_state = codegen_->MakeFreshIdentifier("query_state"); + fb_->Append(codegen_->DeclareVarNoInit(query_state, codegen_->MakeExpr(exec_query->GetQueryStateType()->Name()))); fb_->Append(codegen_->DeclareVar( lambda_identifier, codegen_->LambdaType(lambda_expr->GetFunctionLiteralExpr()->TypeRepr()), lambda_expr)); - fb_->Append(codegen_->DeclareVarNoInit(query_state, codegen_->MakeExpr(exec_query->GetQueryStateType()->Name()))); // Set its execution context to whatever execution context was passed in here fb_->Append(codegen_->CallBuiltin(ast::Builtin::StartNewParams, {exec_ctx})); @@ -595,16 +589,7 @@ void UdfCodegen::Visit(ast::udf::SQLStmtAST *ast) { // Manually append calls to each function from the compiled // executable query (implementing the closure) to the builder - auto function_names = exec_query->GetFunctionNames(); - for (const auto &function_name : function_names) { - if (IsRunFunction(function_name)) { - fb_->Append(codegen_->Call(codegen_->GetAstContext()->GetIdentifier(function_name), - {codegen_->AddressOf(query_state), codegen_->MakeExpr(lambda_identifier)})); - } else { - fb_->Append( - codegen_->Call(codegen_->GetAstContext()->GetIdentifier(function_name), {codegen_->AddressOf(query_state)})); - } - } + CodegenTopLevelCalls(exec_query.get(), query_state, lambda_identifier); fb_->Append(codegen_->CallBuiltin(ast::Builtin::FinishNewParams, {exec_ctx})); } @@ -812,6 +797,42 @@ void UdfCodegen::CodegenBoundVariableInitForRecord(common::ManagedPointerGetFunctionMetadata()) { + const auto &function_name = metadata->GetName(); + if (IsRunAllFunction(function_name)) { + NOISEPAGE_ASSERT(metadata->GetParamsCount() == 2, "Unexpected arity for RunAll function"); + fb_->Append(codegen_->Call(codegen_->GetAstContext()->GetIdentifier(function_name), + {codegen_->AddressOf(query_state_id), codegen_->MakeExpr(lambda_id)})); + } else { + NOISEPAGE_ASSERT(metadata->GetParamsCount() == 1, "Unexpected arity for top-level pipeline function"); + fb_->Append(codegen_->Call(codegen_->GetAstContext()->GetIdentifier(function_name), + {codegen_->AddressOf(query_state_id)})); + } + } +} + /* ---------------------------------------------------------------------------- General Utilities ---------------------------------------------------------------------------- */ @@ -848,8 +869,8 @@ std::unique_ptr UdfCodegen::OptimizeEmbeddedQuery(par } // Static -bool UdfCodegen::IsRunFunction(const std::string &function_name) { - return function_name.find("Run") != std::string::npos; +bool UdfCodegen::IsRunAllFunction(const std::string &name) { + return name.find("RunAllOutputCallback") != std::string::npos; } // Static diff --git a/src/execution/sema/sema_builtin.cpp b/src/execution/sema/sema_builtin.cpp index 3b9bb6ca7a..c3d7fc47d3 100644 --- a/src/execution/sema/sema_builtin.cpp +++ b/src/execution/sema/sema_builtin.cpp @@ -1381,10 +1381,11 @@ void Sema::CheckBuiltinTableIterParCall(ast::CallExpr *call) { // Check the type of the scanner function parameters. See TableVectorIterator::ScanFn. const auto tvi_kind = ast::BuiltinType::TableVectorIterator; const auto ¶ms = scan_fn_type->GetParams(); - if (params.size() != 3 // Scan function has 3 arguments. - || !params[0].type_->IsPointerType() // QueryState, must contain execCtx. - || !params[1].type_->IsPointerType() // Thread state. - || !IsPointerToSpecificBuiltin(params[2].type_, tvi_kind)) { // TableVectorIterator. + + if (params.size() != 3 // Call has 3 parameters + || !params[0].GetType()->IsPointerType() // QueryState* + || !params[1].GetType()->IsPointerType() // PipelineState* + || !IsPointerToSpecificBuiltin(params[2].GetType(), tvi_kind)) { // TableVectorIterator* GetErrorReporter()->Report(call->Position(), ErrorMessages::kBadParallelScanFunction, call_args[5]->GetType()); return; } diff --git a/src/include/execution/compiler/compilation_context.h b/src/include/execution/compiler/compilation_context.h index d92a51ae90..770db3c8b1 100644 --- a/src/include/execution/compiler/compilation_context.h +++ b/src/include/execution/compiler/compilation_context.h @@ -86,16 +86,15 @@ class CompilationContext { */ void Prepare(const parser::AbstractExpression &expression); - /** - * @return The code generator instance. - */ + /** @return The code generator instance. */ CodeGen *GetCodeGen() { return &codegen_; } - /** - * @return The query state. - */ + /** @return The query state. */ StateDescriptor *GetQueryState() { return &query_state_; } + /** @return The identifier for the query state variable */ + ast::Identifier GetQueryStateName() const { return query_state_var_; } + /** * @return The translator for the given relational plan node; null if the provided plan node does * not have a translator registered in this context. @@ -108,30 +107,23 @@ class CompilationContext { */ ExpressionTranslator *LookupTranslator(const parser::AbstractExpression &expr) const; - /** - * @return A common prefix for all functions generated in this module. - */ + /** @return A common prefix for all functions generated in this module. */ std::string GetFunctionPrefix() const; - /** - * @return The list of parameters common to all query functions. For now, just the query state. - */ + /** @return The list of parameters common to all query functions. For now, just the query state. */ util::RegionVector QueryParams() const; - /** - * @return The slot in the query state where the execution context can be found. - */ + /** @return The slot in the query state where the execution context can be found. */ ast::Expr *GetExecutionContextPtrFromQueryState(); - /** - * @return The compilation mode. - */ + /** @return The compilation mode. */ CompilationMode GetCompilationMode() const { return mode_; } - /** - * @return The output callback. - */ - ast::Expr *GetOutputCallback() const { return output_callback_; } + /** @return The output callback. */ + ast::LambdaExpr *GetOutputCallback() const { return output_callback_; } + + /** @return `true` if the compilation context has an output callback, `false` otherwise */ + bool HasOutputCallback() const { return output_callback_ != nullptr; } /** @return True if we should collect counters in TPL, used for Lin's models. */ bool IsCountersEnabled() const { return counters_enabled_; } diff --git a/src/include/execution/compiler/executable_query.h b/src/include/execution/compiler/executable_query.h index 8acd6858d2..51bfa79a45 100644 --- a/src/include/execution/compiler/executable_query.h +++ b/src/include/execution/compiler/executable_query.h @@ -2,6 +2,7 @@ #include #include +#include #include #include @@ -37,6 +38,7 @@ class Region; namespace vm { class Module; class ModuleMetadata; +class FunctionInfo; } // namespace vm } // namespace execution @@ -106,10 +108,16 @@ class ExecutableQuery { */ bool IsCompiled() const { return module_ != nullptr; } + /** @return The functions in the fragment, in program execution order*/ + const std::vector &GetFunctions() const { return functions_; } + /** - * @return The functions in the fragment. + * Get the metatdata for the bytecode function identified by `name`. + * @param name The name of the function to query. + * @return The function metadata for the specified function, + * or empty optional in the event that the function is not present */ - const std::vector &GetFunctions() const { return functions_; } + std::optional GetFunctionMetadata(const std::string &name) const; /** * @return The file. @@ -210,9 +218,12 @@ class ExecutableQuery { /** @return The SQL query string */ common::ManagedPointer GetQueryText() { return query_text_; } - /** @return All of the function names in the executable query. */ + /** @return All of the function names in the executable query, in program execution order. */ std::vector GetFunctionNames() const; + /** @return The metadata for each TPL function in the executable query, in program execution order. */ + std::vector GetFunctionMetadata() const; + /** @return All of the declarations in the executable query. */ std::vector GetDecls() const; diff --git a/src/include/execution/compiler/function_builder.h b/src/include/execution/compiler/function_builder.h index b3af42f0e2..1107f152d4 100644 --- a/src/include/execution/compiler/function_builder.h +++ b/src/include/execution/compiler/function_builder.h @@ -47,9 +47,15 @@ class FunctionBuilder { /** Destructor; invokes FunctionBuilder::Finish() */ ~FunctionBuilder(); + /** @return The arity of the function */ + std::size_t GetParameterCount() const { return params_.size(); } + /** @return A reference to a function parameter by its ordinal position */ ast::Expr *GetParameterByPosition(std::size_t param_idx); + /** @return The expression representation of the parameters to the function */ + std::vector GetParameters() const; + /** * Append a statement to the list of statements in this function. * @param stmt The statement to append diff --git a/src/include/execution/compiler/operator/output_translator.h b/src/include/execution/compiler/operator/output_translator.h index 358aaa9c8e..f580791e26 100644 --- a/src/include/execution/compiler/operator/output_translator.h +++ b/src/include/execution/compiler/operator/output_translator.h @@ -32,13 +32,13 @@ class OutputTranslator : public OperatorTranslator { */ DISALLOW_COPY_AND_MOVE(OutputTranslator); - /** - * Define the output struct. - */ + /** Define the output struct. */ void DefineHelperStructs(util::RegionVector *decls) override; + /** Initialize pipeline state for the output translator */ void InitializePipelineState(const Pipeline &pipeline, FunctionBuilder *function) const override; + /** Teardown pipeline state for the output translator */ void TearDownPipelineState(const Pipeline &pipeline, FunctionBuilder *function) const override; void InitializeCounters(const Pipeline &pipeline, FunctionBuilder *function) const override; @@ -46,26 +46,27 @@ class OutputTranslator : public OperatorTranslator { void EndParallelPipelineWork(const Pipeline &pipeline, FunctionBuilder *function) const override; void FinishPipelineWork(const Pipeline &pipeline, FunctionBuilder *function) const override; - /** - * Perform the main work of the translator. - */ + /** Perform the main work of the translator. */ void PerformPipelineWork(WorkContext *context, FunctionBuilder *function) const override; - /** - * Does not interact with tables. - */ + /** Does not interact with tables. */ ast::Expr *GetTableColumn(catalog::col_oid_t col_oid) const override { UNREACHABLE("Output does not interact with tables."); } + /** @return `true` if the output translator has an associated output callback, `false` otherwise */ + bool HasOutputCallback() const; + private: + /** The output variable */ ast::Identifier output_var_; + /** The output structure */ ast::Identifier output_struct_; - // The number of rows that are output. + /** The number of rows that are output */ StateDescriptor::Entry num_output_; - // The OutputBuffer to use + /** The OutputBuffer to use */ StateDescriptor::Entry output_buffer_; }; diff --git a/src/include/execution/compiler/pipeline.h b/src/include/execution/compiler/pipeline.h index e8ec5e2c14..5793fc3abf 100644 --- a/src/include/execution/compiler/pipeline.h +++ b/src/include/execution/compiler/pipeline.h @@ -52,9 +52,7 @@ class Pipeline { */ enum class Parallelism : uint8_t { Serial = 0, Parallel = 2 }; - /** - * Enum class representing whether the pipeline is vectorized. - */ + /** Enum class representing whether the pipeline is vectorized. */ enum class Vectorization : uint8_t { Disabled = 0, Enabled = 1 }; /** @@ -148,56 +146,43 @@ class Pipeline { /** * Generate all functions to execute this pipeline in the provided container. * @param builder The builder for the executable query container. - * @param query_id The ID of the query for which this pipeline is generated. - * @param output_callback The lambda expression that represents the - * output callback for the pipeline. */ - void GeneratePipeline(ExecutableQueryFragmentBuilder *builder, query_id_t query_id, - ast::LambdaExpr *output_callback = nullptr) const; + void GeneratePipeline(ExecutableQueryFragmentBuilder *builder) const; - /** - * @return True if the pipeline is parallel; false otherwise. - */ + /** @return `true` if the pipeline is parallel, `false` otherwise. */ bool IsParallel() const { return parallelism_ == Parallelism ::Parallel; } - /** - * @return True if this pipeline is fully vectorized; false otherwise. - */ + /** @return `true` if this pipeline is fully vectorized, `false` otherwise. */ bool IsVectorized() const { return false; } - /** - * Typedef used to specify an iterator over the steps in a pipeline. - */ + /** Typedef used to specify an iterator over the steps in a pipeline. */ using StepIterator = std::vector::const_reverse_iterator; - /** - * @return An iterator over the operators in the pipeline. - */ + /** @return An iterator over the operators in the pipeline. */ StepIterator Begin() const { return steps_.rbegin(); } - /** - * @return An iterator positioned at the end of the operators steps in the pipeline. - */ + /** @return An iterator positioned at the end of the operators steps in the pipeline. */ StepIterator End() const { return steps_.rend(); } - /** - * @return True if the given operator is the driver for this pipeline; false otherwise. - */ + /** @return True if the given operator is the driver for this pipeline; false otherwise. */ bool IsDriver(const PipelineDriver *driver) const { return driver == driver_; } - /** - * @return Arguments common to all pipeline functions. - */ + /** @return The arguments common to all pipeline functions. */ util::RegionVector PipelineParams() const; - /** - * @return A unique name for a function local to this pipeline. - */ + /** @return An identifier for the pipeline state variable */ + ast::Identifier GetPipelineStateName() const; + + /** @return A unique name for a function local to this pipeline. */ std::string CreatePipelineFunctionName(const std::string &func_name) const; /** - * @return A vector of expressions that initialize, run and teardown a nested pipeline. + * @return A vector of expressions that do the work of running + * a pipeline function and its associated dependendent operations. */ + std::vector CallRunPipelineFunction() const; + + /** @return A vector of expressions that initialize, run and teardown a nested pipeline. */ std::vector CallSingleRunPipelineFunction() const; /** @@ -208,16 +193,6 @@ class Pipeline { */ void CallNestedRunPipelineFunction(WorkContext *ctx, const OperatorTranslator *op, FunctionBuilder *function) const; - /** - * @return A vector of expressions that do the work of running a pipeline function and its dependencies - */ - std::vector CallRunPipelineFunction() const; - - /** - * @return Pipeline state variable - */ - ast::Identifier GetPipelineStateVar() { return state_var_; } - /** @return The unique ID of this pipeline. */ pipeline_id_t GetPipelineId() const { return pipeline_id_t{id_}; } @@ -235,20 +210,16 @@ class Pipeline { */ void InjectEndResourceTracker(FunctionBuilder *builder, bool is_hook) const; - /** - * @return Query identifier of the query that we are codegen-ing - */ + /** @return The identifier for the query that we are codegen-ing */ query_id_t GetQueryId() const; - /** - * @return A pointer to the OUFeatureVector in the pipeline state - */ + /** @return A pointer to the OUFeatureVector in the pipeline state */ ast::Expr *OUFeatureVecPtr() const { return oufeatures_.GetPtr(codegen_); } /** * Gets an argument from the set of "extra" pipeline arguments given to the current pipeline's function * Only applicable if this is a nested pipeline. Extra refers to arguments other than the query state and the - * pipeline state + * pipeline state. * @param index The extra argument index * @return An expression representing the requested argument */ @@ -257,38 +228,104 @@ class Pipeline { /** @return `true` if this pipeline is prepared, `false` otherwise */ bool IsPrepared() const { return prepared_; } - private: - // Return the thread-local state initialization and tear-down function names. - // This is needed when we invoke @tlsReset() from the pipeline initialization - // function to setup the thread-local state. - ast::Identifier GetSetupPipelineStateFunctionName() const; - ast::Identifier GetTearDownPipelineStateFunctionName() const; - ast::Identifier GetWorkFunctionName() const; + /** @return The output callback for the pipeline, `nullptr` if not present */ + ast::LambdaExpr *GetOutputCallback() const { return output_callback_; } + + /** + * Set the output callback for the pipeline. + * @param output_callback The lambda expression that implements the output callback + */ + void SetOutputCallback(ast::LambdaExpr *output_callback) { output_callback_ = output_callback; } - // Generate a wrapper function for the current pipeline. - ast::FunctionDecl *GeneratePipelineWrapperFunction(ast::LambdaExpr *output_callback) const; + /** @return `true` if this pipeline has an output callback, `false` otherwise */ + bool HasOutputCallback() const { return output_callback_ != nullptr; } - // Generate the pipeline state initialization logic. - ast::FunctionDecl *GenerateSetupPipelineStateFunction() const; + private: + /* -------------------------------------------------------------------------- + Pipeline Function Generation + -------------------------------------------------------------------------- */ - // Generate the pipeline state cleanup logic. + /** + * Generate code to initialize pipeline state. + * @return The function declaration for the generated function + */ + ast::FunctionDecl *GenerateInitPipelineStateFunction() const; + + /** + * Generate code to teardown pipeline state. + * @return The function declaration for the generated function + */ ast::FunctionDecl *GenerateTearDownPipelineStateFunction() const; - // Generate pipeline initialization logic. - ast::FunctionDecl *GenerateInitPipelineFunction(ast::LambdaExpr *output_callback) const; + /** + * Generate code to wrap top-level pipeline calls for nested pipelines. + * @return The function declaration for the generated function + */ + ast::FunctionDecl *GeneratePipelineRunAllNestedFunction() const; + + /** + * Generate code to wrap top-level pipeline calls for pipeline with output callback. + * @return The function declaration for the generated function + */ + ast::FunctionDecl *GeneratePipelineRunAllOutputCallbackFunction() const; + + /** + * Generate code to initialize the pipeline. + * @return The function declaration for the generated function + */ + ast::FunctionDecl *GenerateInitPipelineFunction() const; + + /** + * Generate code to run primary pipeline logic. + * @return The function declaration for the generated function + */ + ast::FunctionDecl *GenerateRunPipelineFunction() const; + + /** + * Generate code to perform pipeline work. + * @return The function declaration for the generated function + */ + ast::FunctionDecl *GeneratePipelineWorkFunction() const; + + /** + * Generate code to teardown the pipeline. + * @return The function declaration for the generated function + */ + ast::FunctionDecl *GenerateTearDownPipelineFunction() const; - // Generate the main pipeline work function. - ast::FunctionDecl *GeneratePipelineWorkFunction(ast::LambdaExpr *output_callback) const; + /* -------------------------------------------------------------------------- + Pipeline Function Parameter Definition + -------------------------------------------------------------------------- */ - // Generate the main pipeline logic. - ast::FunctionDecl *GenerateRunPipelineFunction(query_id_t query_id, ast::LambdaExpr *output_callback) const; + util::RegionVector GetInitPipelineStateParams() const; - // Generate pipeline tear-down logic. - ast::FunctionDecl *GenerateTearDownPipelineFunction(ast::LambdaExpr *output_callback) const; + util::RegionVector GetTeardownPipelineStateParams() const; + + util::RegionVector GetRunAllNestedPipelineParams() const; + + util::RegionVector GetRunAllOutputCallbackPipelineParams() const; + + util::RegionVector GetInitPipelineParams() const; + + util::RegionVector GetRunPipelineParams() const; + + util::RegionVector GetPipelineWorkParams() const; + + util::RegionVector GetTeardownPipelineParams() const; + + /** @return The arguments common to all query functions */ + util::RegionVector QueryParams() const; + + /* -------------------------------------------------------------------------- + Nested Pipelines + -------------------------------------------------------------------------- */ /** @brief Indicate that this pipeline is nested. */ void MarkNested() { nested_ = true; } + /** @return `true` if this is a nested pipeline, `false` otherwise */ + bool IsNestedPipeline() const { return nested_; } + private: // Internals which are exposed for minirunners. friend class compiler::CompilationContext; @@ -297,6 +334,18 @@ class Pipeline { /** @return The vector of pipeline operators that make up the pipeline. */ const std::vector &GetTranslators() const { return steps_; } + /** @return An identifier for the query state variable */ + ast::Identifier GetQueryStateName() const; + + ast::Identifier GetInitPipelineStateFunctionName() const; + ast::Identifier GetTearDownPipelineStateFunctionName() const; + + /** @return An identifier for the pipeline `RunAllNested` function */ + ast::Identifier GetRunAllNestedPipelineFunctionName() const; + + /** @return An identifier for the pipeline `RunAllOutputCallback` function */ + ast::Identifier GetRunAllOutputCallbackPipelineFunctionName() const; + /** @return An identifier for the pipeline `Init` function */ ast::Identifier GetInitPipelineFunctionName() const; @@ -306,15 +355,13 @@ class Pipeline { /** @return An identifier for the pipeline `Teardown` function */ ast::Identifier GetTeardownPipelineFunctionName() const; + ast::Identifier GetPipelineWorkFunctionName() const; + /** @return An immutable reference to the pipeline state descriptor */ const StateDescriptor &GetPipelineStateDescriptor() const { return state_; } - StateDescriptor &GetPipelineStateDescriptor() { return state_; } - /** @return A mutable reference to the pipeline state descriptor */ - void InjectStartPipelineTracker(FunctionBuilder *builder) const; - - void InjectEndResourceTracker(FunctionBuilder *builder, query_id_t query_id) const; + StateDescriptor &GetPipelineStateDescriptor() { return state_; } private: // A unique pipeline ID. @@ -323,8 +370,6 @@ class Pipeline { CompilationContext *compilation_context_; // The code generation instance. CodeGen *codegen_; - // Cache of common identifiers. - ast::Identifier state_var_; // The pipeline state. StateDescriptor state_; // The pipeline operating unit feature vector state. @@ -350,6 +395,8 @@ class Pipeline { bool check_parallelism_; // Whether or not this is a nested pipeline. bool nested_; + // The output callback for the pipeline (`nullptr` if not present) + ast::LambdaExpr *output_callback_{nullptr}; // Whether or not this pipeline is prepared. bool prepared_{false}; }; diff --git a/src/include/execution/compiler/udf/udf_codegen.h b/src/include/execution/compiler/udf/udf_codegen.h index 136b773c2e..0c6209d88c 100644 --- a/src/include/execution/compiler/udf/udf_codegen.h +++ b/src/include/execution/compiler/udf/udf_codegen.h @@ -28,6 +28,14 @@ class VariableRef; namespace noisepage::execution { +namespace compiler { +class ExecutableQuery; +} // namespace compiler + +namespace vm { +class FunctionInfo; +} // namespace vm + // Forward declarations namespace ast::udf { class AbstractAST; @@ -344,6 +352,16 @@ class UdfCodegen : ast::udf::ASTNodeVisitor { void CodegenBoundVariableInitForRecord(common::ManagedPointer plan, const std::string &record_name); + /** + * Generate code to invoke each top-level function in the executable query. + * @param exec_query The executable query for which calls are generated + * @param query_state_id The identifier for the query state + * @param lambda_id The identifier for the lambda expression that + * is used as an output callback in the query + */ + void CodegenTopLevelCalls(const ExecutableQuery *exec_query, ast::Identifier query_state_id, + ast::Identifier lambda_id); + /** * Translate a SQL type to its corresponding catalog type. * @param type The SQL type of interest @@ -388,12 +406,12 @@ class UdfCodegen : ast::udf::ASTNodeVisitor { std::unique_ptr OptimizeEmbeddedQuery(parser::ParseResult *parsed_query); /** - * Determine the function identified by `name` is a top-level run function. - * @param function_name The name of the function - * @return `true` if the function is a top-level run - * function, `false` otherwise + * Determine if the function described by the given metdata is a + * top-level run function that accepts an output callback argument. + * @param function_metatdata The function metadata + * @return `true` if the function meets the above criteria, `false` otherwise */ - static bool IsRunFunction(const std::string &function_name); + static bool IsRunAllFunction(const std::string &name); /** * Get the builtin parameter-add function for the specified parameter type. diff --git a/src/include/execution/sema/error_message.h b/src/include/execution/sema/error_message.h index 0b7db9fd46..f874ac5fee 100644 --- a/src/include/execution/sema/error_message.h +++ b/src/include/execution/sema/error_message.h @@ -74,7 +74,7 @@ namespace sema { F(MissingArrayLength, "missing array length (either compile-time number or '*')", ()) \ F(NotASQLAggregate, "'%0' is not a SQL aggregator type", (ast::Type *)) \ F(BadParallelScanFunction, \ - "parallel scan function must have type (*ExecutionContext, *TableVectorIterator)->nil, " \ + "parallel scan function must have type (*QueryState, *PipelineState, *TableVectorIterator)->nil, " \ "received '%0'", \ (ast::Type *)) \ F(BadHookFunction, \ diff --git a/src/include/execution/vm/module.h b/src/include/execution/vm/module.h index dd52ea2816..c9e19a79a6 100644 --- a/src/include/execution/vm/module.h +++ b/src/include/execution/vm/module.h @@ -62,7 +62,7 @@ class Module { /** * Look up a TPL function in this module by its name * @param name The name of the function to lookup - * @return A pointer to the function's info if it exists; null otherwise + * @return A pointer to the function's info if it exists; `nullptr` otherwise */ const FunctionInfo *GetFuncInfoByName(const std::string &name) const { return bytecode_module_->LookupFuncInfoByName(name); From 1ed88d3eaa3274658e8da0f63dd79a924de8d7b4 Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Mon, 2 Aug 2021 16:32:05 -0400 Subject: [PATCH 092/139] first draft of codegen design doc --- docs/design_codegen.md | 133 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 133 insertions(+) create mode 100644 docs/design_codegen.md diff --git a/docs/design_codegen.md b/docs/design_codegen.md new file mode 100644 index 0000000000..a511a206f1 --- /dev/null +++ b/docs/design_codegen.md @@ -0,0 +1,133 @@ +# Design Doc: Execution Engine Code Generation + +### Overview + +As described in the _Execution Engine Design Document_, NoisePage utilizes [data-centric code generation](https://15721.courses.cs.cmu.edu/spring2020/papers/14-compilation/p539-neumann.pdf) to compile the query plans produced by the optimizer to a byetcode representation that is then either interpreted or JIT-compiled. This document describes some of the nuances of the code generation process. While it is a strict subset of the process descibed in the _Execution Engine Design Document_, code generation is a complex topic, and giving it its own document allows us to focus in on the details without getting lost in unrelated concerns from the layers of the execution engine above and below it. + +### Data-Centric Code Generation + +Our goal in code generation is to produce a bytecode program that implements a query plan. + +The straightforward and most common way of accomplishing this is to have each operator in the query plan tree assume responsibility for generating the code that it requires to execute. The complete byetcode program might then be realized by having each operator generate code into a distinct bytecode function and then chaining these functions together via calls from the functions produced by parent operators to those produced by child operators. + +As mentioned above, this approach is straightforward to reason about and to implement. The code generated for each operator is nicely self-contained in a single bytecode function, allowing developers to verify the correctness of the generated code and debug code generation issues. However, the simplicity of this approach comes at the cost of query runtime performance. We now incur function-call overhead in the transition between each operator. More importantly, we leave ourselves open to the same performance issues present in any operator-centric execution model: poor code and data locality resulting from tuple-at-a-time processing among each operator. + +Data-centric code generation is a solution to these performance issues. TODO + +### Pipelines + +TODO + +**Complications** + +Since the original implementation of code generation, we have introduced several features that have required updates to the pipeline interface. Namely: +- Inductive Common Table Expressions (`WITH RECURSIVE` and `WITH ITERATIVE`) which introduce the concept of _nested pipelines_ +- User-Defined Functions which introduce the concept of an _output callback_ + +Both of these aditions, nested pipelines and output callbacks, slightly complicate the code generation process, and this additional complexity is reflected in the pipeline interface, to which we now turn our attention. + +### The Pipeline Interface + +The are several flavors of pipelines within NoisePage that differ slightly in the signature of their top-level bytecode functions, as well as their semantics. In this section, we explain each of these distinct flavors and provide the signatures of each of these top-level functions. + +#### Serial Pipelines + +We begin the discussion with serial pipelines because they are slightly less complicated. + +**State Initialization** + +The interface for the pipeline state initialization functions is the same across all pipeline variants. + +To initialize the pipeline state, we generate: + +``` +fun Query0_Pipeline1_InitPipelineState(*QueryState, *PipelineState) +``` + +and the teardown the pipeline state, we generate: + +``` +fun Query0_Pipeline1_TeardownPipelineState(*QueryState, *PipelineState) +``` + +**Pipline Initialization** + +The interface for pipeline initialization varies depending on the pipeline variant. + +In the common case, we generate: + +``` +fun Query0_Pipeline1_Init(*QueryState) +``` + +In the case of a nested pipeline, we generate: + +``` +fun Query0_Pipeline1_Init(*QueryState, *PipelineState) +``` + +In the case of a pipeline with an output callback, we generate: + +``` +fun Query0_Pipeline1_Init(*QueryState, *PipelineState) +``` + +A pointer to the pipeline state (`*PipelineState`) is provided to the call for nested pipelines and pipelines with output callbacks because in both of these cases the pipeline state associated with the thread running the pipeline is not owned by the pipeline in question. Instead, this pipeline state structure is allocated on the stack at runtime and passed through the bytecode function invocations. + +**Pipeline Run** + +The interface for the pipeline _Run_ function varies depending on the pipeline variant. + +In the common case, we generate: + +``` +fun Query0_Pipeline1_Run(*QueryState) +``` + +In the case of nested pipelines, we generate: + +``` +fun Query0_Pipeline1_Run(*QueryState, *PipelineState) +``` + +In the case of pipelines with an output callback, we generate: + +``` +fun Query0_Pipeline1_Run(*QueryState, *PipelineState, Closure) +``` + +The distinction between pipelines with output callbacks and nested pipelines manifests here. The output callback (in the form of a TPL closure) is provided as a third parameter to the _Run_ function such that it can be invoked by the operators that utilize it (for now, just the `OutputTranslator`) in the body of the pipeline _Work_ function. + +**Pipeline Teardown** + +The interface for pipeline teardown varies depending on the pipeline variant. + +In the common case, we generate: + +``` +fun Query0_Pipeline1_Teardown(*QueryState) +``` + +In the case of a nested pipeline, we generate: + +``` +fun Query0_Pipeline1_Teardown(*QueryState, *PipelineState) +``` + +In the case of a pipeline with an output callback, we generate: + +``` +fun Query0_Pipeline1_Teardown(*QueryState, *PipelineState) +``` + +The reason that a pointer to the pipeline state is provided to the call in the latter two cases is the same as in the case of pipeline initialization. + +#### Parallel Pipelines + +Parallel pipelines require different semantics from serial pipelines. Despite these differences, only the _Work_ function is affected by the change from a serial to a parallel pipeline. + +TODO + +### References + +- [Efficiently Compiling Efficient Query Plans for Modern Hardware](https://15721.courses.cs.cmu.edu/spring2020/papers/14-compilation/p539-neumann.pdf) by Thomas Neumann. The paper that introduced the concept of data-centric code generation, among other techniques now considered standard best-practice in compiling query engines. From bb5b250cfe3156a453f6c947e2e88df18daf8657 Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Mon, 2 Aug 2021 17:57:04 -0400 Subject: [PATCH 093/139] remove unnecessary code in pipeline code generation, lint and tidy, all integration tests passing --- src/execution/compiler/pipeline.cpp | 337 +++++++----------- src/execution/compiler/udf/udf_codegen.cpp | 5 +- .../execution/compiler/function_builder.h | 1 + src/include/execution/compiler/pipeline.h | 103 +++--- 4 files changed, 180 insertions(+), 266 deletions(-) diff --git a/src/execution/compiler/pipeline.cpp b/src/execution/compiler/pipeline.cpp index cf7dd65fda..06b39d09a2 100644 --- a/src/execution/compiler/pipeline.cpp +++ b/src/execution/compiler/pipeline.cpp @@ -39,9 +39,9 @@ Pipeline::Pipeline(OperatorTranslator *op, Pipeline::Parallelism parallelism) : } void Pipeline::RegisterStep(OperatorTranslator *op) { - NOISEPAGE_ASSERT(std::count(steps_.begin(), steps_.end(), op) == 0, + NOISEPAGE_ASSERT(std::count(steps_.cbegin(), steps_.cend(), op) == 0, "Duplicate registration of operator in pipeline."); - auto num_steps = steps_.size(); + const auto num_steps = steps_.size(); if (num_steps > 0) { auto last_step = common::ManagedPointer(steps_[num_steps - 1]); // TODO(WAN): MAYDAY CHECK WITH LIN AND PRASHANTH, did ordering of these change? @@ -246,7 +246,7 @@ void Pipeline::GeneratePipeline(ExecutableQueryFragmentBuilder *builder) const { builder->DeclareFunction(teardown); if (HasOutputCallback()) { - auto run_all = GeneratePipelineRunAllOutputCallbackFunction(); + auto run_all = GeneratePipelineRunAllFunction(); builder->DeclareFunction(run_all); builder->RegisterStep(run_all); } else if (!IsNestedPipeline()) { @@ -256,6 +256,10 @@ void Pipeline::GeneratePipeline(ExecutableQueryFragmentBuilder *builder) const { builder->RegisterStep(teardown); } + // For nested pipelines, do not register any of the top-level + // pipeline functions as steps in the Fragment builder, and + // instead rely on a call to Pipeline::CallNestedRunPipeline + builder->AddTeardownFn(teardown); } @@ -264,8 +268,12 @@ void Pipeline::GeneratePipeline(ExecutableQueryFragmentBuilder *builder) const { ----------------------------------------------------------------------------- */ ast::FunctionDecl *Pipeline::GenerateInitPipelineStateFunction() const { + /** + * fun QueryX_PipelineY_InitPipelineState(*QueryState, *PipelineState) + */ + auto name = GetInitPipelineStateFunctionName(); - FunctionBuilder builder{codegen_, name, GetInitPipelineStateParams(), codegen_->Nil()}; + FunctionBuilder builder{codegen_, name, PipelineParams(), codegen_->Nil()}; { // Request new scope for the function. CodeGen::CodeScope code_scope{codegen_}; @@ -277,8 +285,12 @@ ast::FunctionDecl *Pipeline::GenerateInitPipelineStateFunction() const { } ast::FunctionDecl *Pipeline::GenerateTearDownPipelineStateFunction() const { + /** + * fun QueryX_PipelineY_TeardownPipelineState(*QueryState, *PipelineState) + */ + auto name = GetTearDownPipelineStateFunctionName(); - FunctionBuilder builder{codegen_, name, GetTeardownPipelineStateParams(), codegen_->Nil()}; + FunctionBuilder builder{codegen_, name, PipelineParams(), codegen_->Nil()}; { // Request new scope for the function. CodeGen::CodeScope code_scope{codegen_}; @@ -300,35 +312,18 @@ ast::FunctionDecl *Pipeline::GenerateTearDownPipelineStateFunction() const { Pipeline Generation: RunAll ----------------------------------------------------------------------------- */ -ast::FunctionDecl *Pipeline::GeneratePipelineRunAllNestedFunction() const { - NOISEPAGE_ASSERT(IsNestedPipeline(), "Should only generate a RunAllNested function in nested pipeline"); - - const ast::Identifier name = GetRunAllNestedPipelineFunctionName(); - FunctionBuilder builder{codegen_, name, GetRunAllNestedPipelineParams(), codegen_->Nil()}; - { - CodeGen::CodeScope code_scope{codegen_}; - - ast::Identifier pipeline_state = codegen_->MakeFreshIdentifier("pipeline_state"); - builder.Append(codegen_->DeclareVarNoInit(pipeline_state, state_.GetType()->TypeRepr())); - auto pipeline_state_ptr = codegen_->AddressOf(pipeline_state); - - NOISEPAGE_ASSERT(builder.GetParameterCount() == 1, "Unexpected parameter count for RunAllNested function"); - auto *query_state_param = builder.GetParameterByPosition(0); - - builder.Append(codegen_->Call(GetInitPipelineFunctionName(), {query_state_param, pipeline_state_ptr})); - builder.Append(codegen_->Call(GetRunPipelineFunctionName(), {query_state_param, pipeline_state_ptr})); - builder.Append(codegen_->Call(GetTeardownPipelineFunctionName(), {query_state_param, pipeline_state_ptr})); - } - - return builder.Finish(); -} +ast::FunctionDecl *Pipeline::GeneratePipelineRunAllFunction() const { + NOISEPAGE_ASSERT(HasOutputCallback(), "Should only generate RunAll function for pipeline with output callback"); + /** + * fun QueryX_PipelineY_RunAll(*QueryState, Closure) + */ -ast::FunctionDecl *Pipeline::GeneratePipelineRunAllOutputCallbackFunction() const { - NOISEPAGE_ASSERT(HasOutputCallback(), - "Should only generate RunAllOutputCallback function for pipeline with output callback"); + const ast::Identifier name = GetRunAllPipelineFunctionName(); + util::RegionVector params{QueryParams()}; + params.push_back(codegen_->MakeField( + GetOutputCallback()->GetName(), codegen_->LambdaType(GetOutputCallback()->GetFunctionLiteralExpr()->TypeRepr()))); - const ast::Identifier name = GetRunAllOutputCallbackPipelineFunctionName(); - FunctionBuilder builder{codegen_, name, GetRunAllOutputCallbackPipelineParams(), codegen_->Nil()}; + FunctionBuilder builder{codegen_, name, std::move(params), codegen_->Nil()}; { CodeGen::CodeScope code_scope{codegen_}; @@ -336,7 +331,7 @@ ast::FunctionDecl *Pipeline::GeneratePipelineRunAllOutputCallbackFunction() cons ast::Identifier pipeline_state_id = codegen_->MakeFreshIdentifier("pipelineState"); builder.Append(codegen_->DeclareVarNoInit(pipeline_state_id, state_.GetType()->TypeRepr())); - NOISEPAGE_ASSERT(builder.GetParameterCount() == 2, "Unexpected parameter count for RunAllOutputCallback function"); + NOISEPAGE_ASSERT(builder.GetParameterCount() == 2, "Unexpected parameter count for RunAll function"); auto *query_state = builder.GetParameterByPosition(0); auto *pipeline_state = codegen_->AddressOf(pipeline_state_id); auto *callback = builder.GetParameterByPosition(1); @@ -354,11 +349,30 @@ ast::FunctionDecl *Pipeline::GeneratePipelineRunAllOutputCallbackFunction() cons ----------------------------------------------------------------------------- */ ast::FunctionDecl *Pipeline::GenerateInitPipelineFunction() const { + /** + * Common Case: + * fun QueryX_PipelineY_Init(queryState) + * + * Nested Pipeline: + * fun QueryX_PipelineY_InitPipeline(queryState, pipelineState) + * + * Output Callback: + * fun QueryX_PipelineY_InitPipeline(queryState, pipelineState) + */ + auto query_state = compilation_context_->GetQueryState(); const ast::Identifier name = GetInitPipelineFunctionName(); - auto parameters = GetInitPipelineParams(); - FunctionBuilder builder{codegen_, name, std::move(parameters), codegen_->Nil()}; + util::RegionVector params{QueryParams()}; + if (IsNestedPipeline() || HasOutputCallback()) { + const auto &state = GetPipelineStateDescriptor(); + ast::FieldDecl *pipeline_state_ptr = + codegen_->MakeField(codegen_->MakeFreshIdentifier("pipeline_state"), + codegen_->PointerType(codegen_->MakeExpr(state.GetTypeName()))); + params.push_back(pipeline_state_ptr); + } + + FunctionBuilder builder{codegen_, name, std::move(params), codegen_->Nil()}; { CodeGen::CodeScope code_scope{codegen_}; ast::Expr *state_ptr = query_state->GetStatePointer(codegen_); @@ -385,9 +399,34 @@ ast::FunctionDecl *Pipeline::GenerateInitPipelineFunction() const { } ast::FunctionDecl *Pipeline::GenerateRunPipelineFunction() const { + /** + * Common Case: + * fun QueryX_PipelineY_Run(queryState) + * + * Nested Pipeline: + * fun QueryX_PipelineY_Run(queryState, pipelineState) + * + * Output Callback: + * fun QueryX_PipelineY_Run(queryState, outputCallback) + */ + bool started_tracker = false; const ast::Identifier name = GetRunPipelineFunctionName(); - FunctionBuilder builder{codegen_, name, GetRunPipelineParams(), codegen_->Nil()}; + + util::RegionVector params{QueryParams()}; + if (IsNestedPipeline() || HasOutputCallback()) { + params.push_back(codegen_->MakeField(GetPipelineStateName(), codegen_->PointerType(state_.GetTypeName()))); + } + for (auto *field : extra_pipeline_params_) { + params.push_back(field); + } + if (HasOutputCallback()) { + params.push_back( + codegen_->MakeField(GetOutputCallback()->GetName(), + codegen_->LambdaType(GetOutputCallback()->GetFunctionLiteralExpr()->TypeRepr()))); + } + + FunctionBuilder builder{codegen_, name, std::move(params), codegen_->Nil()}; { // Begin a new code scope for fresh variables. CodeGen::CodeScope code_scope{codegen_}; @@ -428,7 +467,7 @@ ast::FunctionDecl *Pipeline::GenerateRunPipelineFunction() const { params.push_back(codegen_->MakeExpr(GetOutputCallback()->GetName())); } - builder.Append(codegen_->Call(GetPipelineWorkFunctionName(), std::move(params))); + builder.Append(codegen_->Call(GetPipelineWorkFunctionName(), params)); } // TODO(abalakum): This shouldn't actually be dependent on order and the loop can be simplified @@ -447,7 +486,21 @@ ast::FunctionDecl *Pipeline::GenerateRunPipelineFunction() const { } ast::FunctionDecl *Pipeline::GeneratePipelineWorkFunction() const { - FunctionBuilder builder{codegen_, GetPipelineWorkFunctionName(), GetPipelineWorkParams(), codegen_->Nil()}; + util::RegionVector params{PipelineParams()}; + for (auto *field : extra_pipeline_params_) { + params.push_back(field); + } + if (IsParallel()) { + auto additional_params = driver_->GetWorkerParams(); + params.insert(params.end(), additional_params.cbegin(), additional_params.cend()); + } + if (HasOutputCallback()) { + params.push_back( + codegen_->MakeField(GetOutputCallback()->GetName(), + codegen_->LambdaType(GetOutputCallback()->GetFunctionLiteralExpr()->TypeRepr()))); + } + + FunctionBuilder builder{codegen_, GetPipelineWorkFunctionName(), std::move(params), codegen_->Nil()}; { // Begin a new code scope for fresh variables. CodeGen::CodeScope code_scope{codegen_}; @@ -475,8 +528,28 @@ ast::FunctionDecl *Pipeline::GeneratePipelineWorkFunction() const { } ast::FunctionDecl *Pipeline::GenerateTearDownPipelineFunction() const { + /** + * Common Case: + * QueryX_PipelineY_Teardown(queryState) + * + * Nested Pipeline: + * QueryX_PipelineY_Teardown(queryState, pipelineState) + * + * Output Callback: + * QueryX_PipelineY_Teardown(queryState, pipelineState) + */ + const ast::Identifier name = GetTeardownPipelineFunctionName(); - FunctionBuilder builder{codegen_, name, GetTeardownPipelineParams(), codegen_->Nil()}; + + util::RegionVector params{QueryParams()}; + if (IsNestedPipeline() || HasOutputCallback()) { + ast::FieldDecl *pipeline_state = + codegen_->MakeField(codegen_->MakeFreshIdentifier("pipeline_state"), + codegen_->PointerType(codegen_->MakeExpr(GetPipelineStateDescriptor().GetTypeName()))); + params.push_back(pipeline_state); + } + + FunctionBuilder builder{codegen_, name, std::move(params), codegen_->Nil()}; { // Begin a new code scope for fresh variables. CodeGen::CodeScope code_scope{codegen_}; @@ -502,127 +575,6 @@ ast::FunctionDecl *Pipeline::GenerateTearDownPipelineFunction() const { Pipeline Function Parameter Definition ----------------------------------------------------------------------------- */ -util::RegionVector Pipeline::GetInitPipelineStateParams() const { - // fun QueryX_PipelineY_InitPipelineState(queryState, pipelineState) - return PipelineParams(); -} - -util::RegionVector Pipeline::GetTeardownPipelineStateParams() const { - // fun QueryX_PipelineY_TeardownPipelineState(queryState, pipelineState) - return PipelineParams(); -} - -util::RegionVector Pipeline::GetRunAllNestedPipelineParams() const { - NOISEPAGE_ASSERT(IsNestedPipeline(), "RunAllNested should only be generated for nested pipelines"); - // fun QueryX_PipelineY_RunAll(queryState) - return QueryParams(); -} - -util::RegionVector Pipeline::GetRunAllOutputCallbackPipelineParams() const { - NOISEPAGE_ASSERT(HasOutputCallback(), - "RunAllOutputCallback should only be generated for pipeline with output callback"); - // fun QueryX_PipelineY_RunAll(queryState, udfLambda) - util::RegionVector params{QueryParams()}; - params.push_back(codegen_->MakeField( - GetOutputCallback()->GetName(), codegen_->LambdaType(GetOutputCallback()->GetFunctionLiteralExpr()->TypeRepr()))); - return params; -} - -util::RegionVector Pipeline::GetInitPipelineParams() const { - /** - * Common Case: - * fun QueryX_PipelineY_Init(queryState) - * - * Nested Pipeline: - * fun QueryX_PipelineY_InitPipeline(queryState, pipelineState) - * - * Output Callback: - * fun QueryX_PipelineY_InitPipeline(queryState, pipelineState) - */ - - util::RegionVector params{QueryParams()}; - if (IsNestedPipeline() || HasOutputCallback()) { - const auto &state = GetPipelineStateDescriptor(); - ast::FieldDecl *pipeline_state_ptr = - codegen_->MakeField(codegen_->MakeFreshIdentifier("pipeline_state"), - codegen_->PointerType(codegen_->MakeExpr(state.GetTypeName()))); - params.push_back(pipeline_state_ptr); - } - return params; -} - -util::RegionVector Pipeline::GetRunPipelineParams() const { - /** - * Common Case: - * fun QueryX_PipelineY_Run(queryState) - * - * Nested Pipeline: - * fun QueryX_PipelineY_Run(queryState, pipelineState) - * - * Output Callback: - * fun QueryX_PipelineY_Run(queryState, outputCallback) - */ - - util::RegionVector params{QueryParams()}; - - if (IsNestedPipeline() || HasOutputCallback()) { - params.push_back(codegen_->MakeField(GetPipelineStateName(), codegen_->PointerType(state_.GetTypeName()))); - } - - for (auto *field : extra_pipeline_params_) { - params.push_back(field); - } - - if (HasOutputCallback()) { - params.push_back( - codegen_->MakeField(GetOutputCallback()->GetName(), - codegen_->LambdaType(GetOutputCallback()->GetFunctionLiteralExpr()->TypeRepr()))); - } - - return params; -} - -util::RegionVector Pipeline::GetPipelineWorkParams() const { - util::RegionVector params{PipelineParams()}; - for (auto *field : extra_pipeline_params_) { - params.push_back(field); - } - - if (IsParallel()) { - auto additional_params = driver_->GetWorkerParams(); - params.insert(params.end(), additional_params.cbegin(), additional_params.cend()); - } - - if (HasOutputCallback()) { - params.push_back( - codegen_->MakeField(GetOutputCallback()->GetName(), - codegen_->LambdaType(GetOutputCallback()->GetFunctionLiteralExpr()->TypeRepr()))); - } - - return params; -} - -util::RegionVector Pipeline::GetTeardownPipelineParams() const { - /** - * Common Case: - * QueryX_PipelineY_Teardown(queryState) - * - * Nested Pipeline: - * QueryX_PipelineY_Teardown(queryState, pipelineState) - * - * Output Callback: - * QueryX_PipelineY_Teardown(queryState, pipelineState) - */ - util::RegionVector params{QueryParams()}; - if (IsNestedPipeline() || HasOutputCallback()) { - ast::FieldDecl *pipeline_state = - codegen_->MakeField(codegen_->MakeFreshIdentifier("pipeline_state"), - codegen_->PointerType(codegen_->MakeExpr(GetPipelineStateDescriptor().GetTypeName()))); - params.push_back(pipeline_state); - } - return params; -} - util::RegionVector Pipeline::QueryParams() const { return compilation_context_->QueryParams(); } util::RegionVector Pipeline::PipelineParams() const { @@ -636,51 +588,36 @@ util::RegionVector Pipeline::PipelineParams() const { } /* ---------------------------------------------------------------------------- - Pipeline Calls + Nested Pipelines ----------------------------------------------------------------------------- */ -std::vector Pipeline::CallSingleRunPipelineFunction() const { - NOISEPAGE_ASSERT(!IsNestedPipeline(), "Can't call a nested pipeline like this"); - return { - codegen_->Call(GetInitPipelineFunctionName(), {compilation_context_->GetQueryState()->GetStatePointer(codegen_)}), - codegen_->Call(GetRunPipelineFunctionName(), {compilation_context_->GetQueryState()->GetStatePointer(codegen_)}), - codegen_->Call(GetTeardownPipelineFunctionName(), - {compilation_context_->GetQueryState()->GetStatePointer(codegen_)})}; -} - void Pipeline::CallNestedRunPipelineFunction(WorkContext *ctx, const OperatorTranslator *op, FunctionBuilder *function) const { - std::vector stmts; - auto p_state = codegen_->MakeFreshIdentifier("nested_state"); - auto p_state_ptr = codegen_->AddressOf(p_state); - - std::vector params_vec{compilation_context_->GetQueryState()->GetStatePointer(codegen_)}; - params_vec.push_back(p_state_ptr); + std::vector stmts{}; + auto pipeline_state = codegen_->MakeFreshIdentifier("nested_state"); + auto pipeline_state_ptr = codegen_->AddressOf(pipeline_state); + // Populate the parameters passed to the Run function for the nested pipeline + std::vector run_parameters{compilation_context_->GetQueryState()->GetStatePointer(codegen_), + pipeline_state_ptr}; for (std::size_t i = 0; i < op->GetPlan().GetOutputSchema()->GetColumns().size(); i++) { - params_vec.push_back(codegen_->AddressOf(op->GetOutput(ctx, i))); + run_parameters.push_back(codegen_->AddressOf(op->GetOutput(ctx, i))); } - function->Append(codegen_->DeclareVarNoInit(p_state, codegen_->MakeExpr(GetPipelineStateDescriptor().GetTypeName()))); - function->Append(codegen_->Call(GetInitPipelineFunctionName(), - {compilation_context_->GetQueryState()->GetStatePointer(codegen_), p_state_ptr})); - function->Append(codegen_->Call(GetRunPipelineFunctionName(), params_vec)); - function->Append(codegen_->Call(GetTeardownPipelineFunctionName(), - {compilation_context_->GetQueryState()->GetStatePointer(codegen_), p_state_ptr})); -} + // Declare a local pipeline state variable + function->Append( + codegen_->DeclareVarNoInit(pipeline_state, codegen_->MakeExpr(GetPipelineStateDescriptor().GetTypeName()))); -std::vector Pipeline::CallRunPipelineFunction() const { - std::vector calls; - std::vector pipelines; - CollectDependencies(&pipelines); - for (auto pipeline : pipelines) { - if (!pipeline->IsNestedPipeline() || (pipeline == this)) { - for (auto call : CallSingleRunPipelineFunction()) { - calls.push_back(call); - } - } - } - return calls; + // call QueryX_PipelineY_Init(*QueryState, *PipelineState) + function->Append( + codegen_->Call(GetInitPipelineFunctionName(), + {compilation_context_->GetQueryState()->GetStatePointer(codegen_), pipeline_state_ptr})); + // call QueryX_PipelineY_Run(*QueryState, *PipelineState, ...) + function->Append(codegen_->Call(GetRunPipelineFunctionName(), run_parameters)); + // call QueryX_PipelineY_Teardown(*QueryState, *PipelineState) + function->Append( + codegen_->Call(GetTeardownPipelineFunctionName(), + {compilation_context_->GetQueryState()->GetStatePointer(codegen_), pipeline_state_ptr})); } /* ---------------------------------------------------------------------------- @@ -699,12 +636,8 @@ ast::Identifier Pipeline::GetTearDownPipelineStateFunctionName() const { return codegen_->MakeIdentifier(CreatePipelineFunctionName("TearDownPipelineState")); } -ast::Identifier Pipeline::GetRunAllNestedPipelineFunctionName() const { - return codegen_->MakeIdentifier(CreatePipelineFunctionName("RunAllNested")); -} - -ast::Identifier Pipeline::GetRunAllOutputCallbackPipelineFunctionName() const { - return codegen_->MakeIdentifier(CreatePipelineFunctionName("RunAllOutputCallback")); +ast::Identifier Pipeline::GetRunAllPipelineFunctionName() const { + return codegen_->MakeIdentifier(CreatePipelineFunctionName("RunAll")); } ast::Identifier Pipeline::GetInitPipelineFunctionName() const { diff --git a/src/execution/compiler/udf/udf_codegen.cpp b/src/execution/compiler/udf/udf_codegen.cpp index e45539d408..add41f2e68 100644 --- a/src/execution/compiler/udf/udf_codegen.cpp +++ b/src/execution/compiler/udf/udf_codegen.cpp @@ -25,6 +25,9 @@ namespace noisepage::execution::compiler::udf { +/** The identifier for the pipeline `RunAll` function */ +constexpr static const char RUN_ALL_IDENTIFIER[] = "RunAll"; + UdfCodegen::UdfCodegen(catalog::CatalogAccessor *accessor, FunctionBuilder *fb, ast::udf::UdfAstContext *udf_ast_context, CodeGen *codegen, catalog::db_oid_t db_oid) : accessor_{accessor}, @@ -870,7 +873,7 @@ std::unique_ptr UdfCodegen::OptimizeEmbeddedQuery(par // Static bool UdfCodegen::IsRunAllFunction(const std::string &name) { - return name.find("RunAllOutputCallback") != std::string::npos; + return name.find(RUN_ALL_IDENTIFIER) != std::string::npos; } // Static diff --git a/src/include/execution/compiler/function_builder.h b/src/include/execution/compiler/function_builder.h index 1107f152d4..50269ffce6 100644 --- a/src/include/execution/compiler/function_builder.h +++ b/src/include/execution/compiler/function_builder.h @@ -3,6 +3,7 @@ #include #include #include +#include #include "common/macros.h" #include "execution/ast/identifier.h" diff --git a/src/include/execution/compiler/pipeline.h b/src/include/execution/compiler/pipeline.h index 5793fc3abf..646e3360b1 100644 --- a/src/include/execution/compiler/pipeline.h +++ b/src/include/execution/compiler/pipeline.h @@ -176,15 +176,6 @@ class Pipeline { /** @return A unique name for a function local to this pipeline. */ std::string CreatePipelineFunctionName(const std::string &func_name) const; - /** - * @return A vector of expressions that do the work of running - * a pipeline function and its associated dependendent operations. - */ - std::vector CallRunPipelineFunction() const; - - /** @return A vector of expressions that initialize, run and teardown a nested pipeline. */ - std::vector CallSingleRunPipelineFunction() const; - /** * Calls a nested pipeline's execution functions * @param ctx Workcontext that we are using to run on @@ -241,6 +232,10 @@ class Pipeline { bool HasOutputCallback() const { return output_callback_ != nullptr; } private: + // Internals which are exposed for minirunners. + friend class compiler::CompilationContext; + friend class selfdriving::OperatingUnitRecorder; + /* -------------------------------------------------------------------------- Pipeline Function Generation -------------------------------------------------------------------------- */ @@ -258,16 +253,11 @@ class Pipeline { ast::FunctionDecl *GenerateTearDownPipelineStateFunction() const; /** - * Generate code to wrap top-level pipeline calls for nested pipelines. + * Generate code to wrap top-level pipeline calls. + * NOTE: Currently only used for pipelines with output callback. * @return The function declaration for the generated function */ - ast::FunctionDecl *GeneratePipelineRunAllNestedFunction() const; - - /** - * Generate code to wrap top-level pipeline calls for pipeline with output callback. - * @return The function declaration for the generated function - */ - ast::FunctionDecl *GeneratePipelineRunAllOutputCallbackFunction() const; + ast::FunctionDecl *GeneratePipelineRunAllFunction() const; /** * Generate code to initialize the pipeline. @@ -297,22 +287,6 @@ class Pipeline { Pipeline Function Parameter Definition -------------------------------------------------------------------------- */ - util::RegionVector GetInitPipelineStateParams() const; - - util::RegionVector GetTeardownPipelineStateParams() const; - - util::RegionVector GetRunAllNestedPipelineParams() const; - - util::RegionVector GetRunAllOutputCallbackPipelineParams() const; - - util::RegionVector GetInitPipelineParams() const; - - util::RegionVector GetRunPipelineParams() const; - - util::RegionVector GetPipelineWorkParams() const; - - util::RegionVector GetTeardownPipelineParams() const; - /** @return The arguments common to all query functions */ util::RegionVector QueryParams() const; @@ -326,25 +300,21 @@ class Pipeline { /** @return `true` if this is a nested pipeline, `false` otherwise */ bool IsNestedPipeline() const { return nested_; } - private: - // Internals which are exposed for minirunners. - friend class compiler::CompilationContext; - friend class selfdriving::OperatingUnitRecorder; - - /** @return The vector of pipeline operators that make up the pipeline. */ - const std::vector &GetTranslators() const { return steps_; } + /* -------------------------------------------------------------------------- + Pipeline Variable and Function Identifiers + -------------------------------------------------------------------------- */ /** @return An identifier for the query state variable */ ast::Identifier GetQueryStateName() const; + /** @return An identifier for the `InitPipelineState` function */ ast::Identifier GetInitPipelineStateFunctionName() const; - ast::Identifier GetTearDownPipelineStateFunctionName() const; - /** @return An identifier for the pipeline `RunAllNested` function */ - ast::Identifier GetRunAllNestedPipelineFunctionName() const; + /** @return An identifier for the `TeardownPipelineState` function */ + ast::Identifier GetTearDownPipelineStateFunctionName() const; - /** @return An identifier for the pipeline `RunAllOutputCallback` function */ - ast::Identifier GetRunAllOutputCallbackPipelineFunctionName() const; + /** @return An identifier for the pipeline `RunAll` function */ + ast::Identifier GetRunAllPipelineFunctionName() const; /** @return An identifier for the pipeline `Init` function */ ast::Identifier GetInitPipelineFunctionName() const; @@ -355,6 +325,7 @@ class Pipeline { /** @return An identifier for the pipeline `Teardown` function */ ast::Identifier GetTeardownPipelineFunctionName() const; + /** @return An identifier for the pipeline `Work` function (serial or parallel) */ ast::Identifier GetPipelineWorkFunctionName() const; /** @return An immutable reference to the pipeline state descriptor */ @@ -363,41 +334,47 @@ class Pipeline { /** @return A mutable reference to the pipeline state descriptor */ StateDescriptor &GetPipelineStateDescriptor() { return state_; } + /* -------------------------------------------------------------------------- + Additional Helpers + -------------------------------------------------------------------------- */ + + /** @return The vector of pipeline operators that make up the pipeline. */ + const std::vector &GetTranslators() const { return steps_; } + private: - // A unique pipeline ID. + /** A unique pipeline ID. */ uint32_t id_; - // The compilation context this pipeline is part of. + /** The compilation context this pipeline is part of. */ CompilationContext *compilation_context_; - // The code generation instance. + /** The code generation instance. */ CodeGen *codegen_; - // The pipeline state. + /** The pipeline state. */ StateDescriptor state_; - // The pipeline operating unit feature vector state. + /** The pipeline operating unit feature vector state. */ StateDescriptor::Entry oufeatures_; - // Operators making up the pipeline. + /** Operators making up the pipeline. */ std::vector steps_; - // The driver. + /** The driver. */ PipelineDriver *driver_; - // pointer to parent pipeline (only applicable if this is a nested pipeline) + /** pointer to parent pipeline (only applicable if this is a nested pipeline) */ Pipeline *parent_; - // Expressions participating in the pipeline. + /** Expressions participating in the pipeline. */ std::vector expressions_; - // All unnested pipelines this one depends on completion of. + /** All unnested pipelines this one depends on completion of. */ std::vector dependencies_; - // Vector of pipelines that are nested under this pipeline + /** Vector of pipelines that are nested under this pipeline. */ std::vector nested_pipelines_; - // Extra parameters to pass into pipeline; - // currently used for nested consumer pipeline work functions + /** Extra parameters to passed into pipeline functions; used for nested consumer pipeline work. */ std::vector extra_pipeline_params_; - // Configured parallelism. + /** Configured parallelism. */ Parallelism parallelism_; - // Whether to check for parallelism in new pipeline elements. + /** Whether to check for parallelism in new pipeline elements. */ bool check_parallelism_; - // Whether or not this is a nested pipeline. + /** Whether or not this is a nested pipeline. */ bool nested_; - // The output callback for the pipeline (`nullptr` if not present) + /** The output callback for the pipeline (`nullptr` if not present) */ ast::LambdaExpr *output_callback_{nullptr}; - // Whether or not this pipeline is prepared. + /** Whether or not this pipeline is prepared. */ bool prepared_{false}; }; From bc15102592c12df19515667b286a6705e5657584 Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Mon, 2 Aug 2021 23:10:25 -0400 Subject: [PATCH 094/139] some basic integration tests for embedded sql queries, need to cover more complex cases, also currently works but with parallelism limited for UDF pipelines --- script/testing/junit/sql/udf.sql | 43 +++++++++++ script/testing/junit/traces/udf.test | 75 +++++++++++++++++++ .../compiler/operator/output_translator.cpp | 6 -- src/execution/compiler/pipeline.cpp | 48 ++++++++---- src/include/execution/compiler/pipeline.h | 14 +--- src/parser/udf/plpgsql_parser.cpp | 1 + 6 files changed, 154 insertions(+), 33 deletions(-) diff --git a/script/testing/junit/sql/udf.sql b/script/testing/junit/sql/udf.sql index 261f14cd40..09731255d5 100644 --- a/script/testing/junit/sql/udf.sql +++ b/script/testing/junit/sql/udf.sql @@ -151,6 +151,49 @@ $$ LANGUAGE PLPGSQL; SELECT sql_select_mutliple_constants(); + +-- ---------------------------------------------------------------------------- +-- sql_embedded_agg_count() + +CREATE FUNCTION sql_embedded_agg_count() RETURNS INT AS $$ \ +DECLARE \ + v INT; \ +BEGIN \ + SELECT COUNT(*) FROM integers INTO v; \ + RETURN v; \ +END \ +$$ LANGUAGE PLPGSQL; + +SELECT sql_embedded_agg_count(); + +-- ---------------------------------------------------------------------------- +-- sql_embedded_agg_min() + +CREATE FUNCTION sql_embedded_agg_min() RETURNS INT AS $$ \ +DECLARE \ + v INT; \ +BEGIN \ + SELECT MIN(x) FROM integers INTO v; \ + RETURN v; \ +END \ +$$ LANGUAGE PLPGSQL; + +SELECT sql_embedded_agg_min(); + +-- ---------------------------------------------------------------------------- +-- sql_embedded_agg_max() + +CREATE FUNCTION sql_embedded_agg_max() RETURNS INT AS $$ \ +DECLARE \ + v INT; \ +BEGIN \ + SELECT MAX(x) FROM integers INTO v; \ + RETURN v; \ +END \ +$$ LANGUAGE PLPGSQL; + +SELECT sql_embedded_agg_max(); + -- ---------------------------------------------------------------------------- -- proc_fors() diff --git a/script/testing/junit/traces/udf.test b/script/testing/junit/traces/udf.test index afcca9e69d..681a412345 100644 --- a/script/testing/junit/traces/udf.test +++ b/script/testing/junit/traces/udf.test @@ -318,6 +318,81 @@ SELECT sql_select_mutliple_constants(); statement ok +statement ok + + +statement ok +-- ---------------------------------------------------------------------------- + +statement ok +-- sql_embedded_agg_count() + +statement ok + + +statement ok +CREATE FUNCTION sql_embedded_agg_count() RETURNS INT AS $$ DECLARE v INT; BEGIN SELECT COUNT(*) FROM integers INTO v; RETURN v; END $$ LANGUAGE PLPGSQL; + +statement ok + + +query I rowsort +SELECT sql_embedded_agg_count(); +---- +3 + + +statement ok + + +statement ok +-- ---------------------------------------------------------------------------- + +statement ok +-- sql_embedded_agg_min() + +statement ok + + +statement ok +CREATE FUNCTION sql_embedded_agg_min() RETURNS INT AS $$ DECLARE v INT; BEGIN SELECT MIN(x) FROM integers INTO v; RETURN v; END $$ LANGUAGE PLPGSQL; + +statement ok + + +query I rowsort +SELECT sql_embedded_agg_min(); +---- +1 + + +statement ok + + +statement ok +-- ---------------------------------------------------------------------------- + +statement ok +-- sql_embedded_agg_max() + +statement ok + + +statement ok +CREATE FUNCTION sql_embedded_agg_max() RETURNS INT AS $$ DECLARE v INT; BEGIN SELECT MAX(x) FROM integers INTO v; RETURN v; END $$ LANGUAGE PLPGSQL; + +statement ok + + +query I rowsort +SELECT sql_embedded_agg_max(); +---- +3 + + +statement ok + + statement ok -- ---------------------------------------------------------------------------- diff --git a/src/execution/compiler/operator/output_translator.cpp b/src/execution/compiler/operator/output_translator.cpp index 4f5b362655..990b341755 100644 --- a/src/execution/compiler/operator/output_translator.cpp +++ b/src/execution/compiler/operator/output_translator.cpp @@ -26,12 +26,6 @@ OutputTranslator::OutputTranslator(const planner::AbstractPlanNode &plan, Compil output_buffer_ = pipeline->DeclarePipelineStateEntry( "output_buffer", GetCodeGen()->PointerType(GetCodeGen()->BuiltinType(ast::BuiltinType::OutputBuffer))); num_output_ = CounterDeclare("num_output", pipeline); - - // If the compilation context contains an output callback, - // the output translator injects the callback into its pipeline - if (compilation_context->HasOutputCallback()) { - pipeline->SetOutputCallback(compilation_context->GetOutputCallback()); - } } void OutputTranslator::InitializePipelineState(const Pipeline &pipeline, FunctionBuilder *function) const { diff --git a/src/execution/compiler/pipeline.cpp b/src/execution/compiler/pipeline.cpp index 06b39d09a2..e20d21afad 100644 --- a/src/execution/compiler/pipeline.cpp +++ b/src/execution/compiler/pipeline.cpp @@ -31,10 +31,16 @@ Pipeline::Pipeline(CompilationContext *ctx) driver_(nullptr), parallelism_(Parallelism::Parallel), check_parallelism_(true), - nested_(false) {} + nested_(false) { + if (HasOutputCallback()) { + UpdateParallelism(Parallelism::Serial); + } +} Pipeline::Pipeline(OperatorTranslator *op, Pipeline::Parallelism parallelism) : Pipeline(op->GetCompilationContext()) { - UpdateParallelism(parallelism); + if (!HasOutputCallback()) { + UpdateParallelism(parallelism); + } RegisterStep(op); } @@ -366,9 +372,8 @@ ast::FunctionDecl *Pipeline::GenerateInitPipelineFunction() const { util::RegionVector params{QueryParams()}; if (IsNestedPipeline() || HasOutputCallback()) { const auto &state = GetPipelineStateDescriptor(); - ast::FieldDecl *pipeline_state_ptr = - codegen_->MakeField(codegen_->MakeFreshIdentifier("pipeline_state"), - codegen_->PointerType(codegen_->MakeExpr(state.GetTypeName()))); + ast::FieldDecl *pipeline_state_ptr = codegen_->MakeField( + codegen_->MakeFreshIdentifier("pipelineState"), codegen_->PointerType(codegen_->MakeExpr(state.GetTypeName()))); params.push_back(pipeline_state_ptr); } @@ -438,13 +443,16 @@ ast::FunctionDecl *Pipeline::GenerateRunPipelineFunction() const { (*iter)->BeginPipelineWork(*this, &builder); } - // TODO(Kyle): I think this is wrong for nested pipelines / output callbacks - // var pipelineState = @tlsGetCurrentThreadState(...) - auto exec_ctx = compilation_context_->GetExecutionContextPtrFromQueryState(); - auto tls = codegen_->ExecCtxGetTLS(exec_ctx); - auto state_type = GetPipelineStateDescriptor().GetTypeName(); - auto state = codegen_->TLSAccessCurrentThreadState(tls, state_type); - builder.Append(codegen_->DeclareVarWithInit(GetPipelineStateName(), state)); + // Nested pipelines and pipelines with callbacks have their + // pipeline state passed as an argument to this function + if (!IsNestedPipeline() && !HasOutputCallback()) { + // var pipelineState = @tlsGetCurrentThreadState(...) + auto exec_ctx = compilation_context_->GetExecutionContextPtrFromQueryState(); + auto tls = codegen_->ExecCtxGetTLS(exec_ctx); + auto state_type = GetPipelineStateDescriptor().GetTypeName(); + auto state = codegen_->TLSAccessCurrentThreadState(tls, state_type); + builder.Append(codegen_->DeclareVarWithInit(GetPipelineStateName(), state)); + } // Launch pipeline work. if (IsParallel()) { @@ -490,11 +498,12 @@ ast::FunctionDecl *Pipeline::GeneratePipelineWorkFunction() const { for (auto *field : extra_pipeline_params_) { params.push_back(field); } + + // NOTE(Kyle): This is hacky... if (IsParallel()) { auto additional_params = driver_->GetWorkerParams(); params.insert(params.end(), additional_params.cbegin(), additional_params.cend()); - } - if (HasOutputCallback()) { + } else if (HasOutputCallback()) { params.push_back( codegen_->MakeField(GetOutputCallback()->GetName(), codegen_->LambdaType(GetOutputCallback()->GetFunctionLiteralExpr()->TypeRepr()))); @@ -544,7 +553,7 @@ ast::FunctionDecl *Pipeline::GenerateTearDownPipelineFunction() const { util::RegionVector params{QueryParams()}; if (IsNestedPipeline() || HasOutputCallback()) { ast::FieldDecl *pipeline_state = - codegen_->MakeField(codegen_->MakeFreshIdentifier("pipeline_state"), + codegen_->MakeField(codegen_->MakeFreshIdentifier("pipelineState"), codegen_->PointerType(codegen_->MakeExpr(GetPipelineStateDescriptor().GetTypeName()))); params.push_back(pipeline_state); } @@ -594,7 +603,7 @@ util::RegionVector Pipeline::PipelineParams() const { void Pipeline::CallNestedRunPipelineFunction(WorkContext *ctx, const OperatorTranslator *op, FunctionBuilder *function) const { std::vector stmts{}; - auto pipeline_state = codegen_->MakeFreshIdentifier("nested_state"); + auto pipeline_state = codegen_->MakeFreshIdentifier("nestedPipelineState"); auto pipeline_state_ptr = codegen_->AddressOf(pipeline_state); // Populate the parameters passed to the Run function for the nested pipeline @@ -674,4 +683,11 @@ ast::Expr *Pipeline::GetNestedInputArg(const std::size_t index) const { return codegen_->UnaryOp(parsing::Token::Type::STAR, codegen_->MakeExpr(extra_pipeline_params_[index]->Name())); } +ast::LambdaExpr *Pipeline::GetOutputCallback() const { + NOISEPAGE_ASSERT(HasOutputCallback(), "Attempt to get nonexistent output callback"); + return compilation_context_->GetOutputCallback(); +} + +bool Pipeline::HasOutputCallback() const { return compilation_context_->HasOutputCallback(); } + } // namespace noisepage::execution::compiler diff --git a/src/include/execution/compiler/pipeline.h b/src/include/execution/compiler/pipeline.h index 646e3360b1..24c45e689b 100644 --- a/src/include/execution/compiler/pipeline.h +++ b/src/include/execution/compiler/pipeline.h @@ -219,17 +219,11 @@ class Pipeline { /** @return `true` if this pipeline is prepared, `false` otherwise */ bool IsPrepared() const { return prepared_; } - /** @return The output callback for the pipeline, `nullptr` if not present */ - ast::LambdaExpr *GetOutputCallback() const { return output_callback_; } - - /** - * Set the output callback for the pipeline. - * @param output_callback The lambda expression that implements the output callback - */ - void SetOutputCallback(ast::LambdaExpr *output_callback) { output_callback_ = output_callback; } + /** @return The output callback for the pipeline */ + ast::LambdaExpr *GetOutputCallback() const; /** @return `true` if this pipeline has an output callback, `false` otherwise */ - bool HasOutputCallback() const { return output_callback_ != nullptr; } + bool HasOutputCallback() const; private: // Internals which are exposed for minirunners. @@ -372,8 +366,6 @@ class Pipeline { bool check_parallelism_; /** Whether or not this is a nested pipeline. */ bool nested_; - /** The output callback for the pipeline (`nullptr` if not present) */ - ast::LambdaExpr *output_callback_{nullptr}; /** Whether or not this pipeline is prepared. */ bool prepared_{false}; }; diff --git a/src/parser/udf/plpgsql_parser.cpp b/src/parser/udf/plpgsql_parser.cpp index 484b0bbe47..5a3303c11a 100644 --- a/src/parser/udf/plpgsql_parser.cpp +++ b/src/parser/udf/plpgsql_parser.cpp @@ -185,6 +185,7 @@ std::unique_ptr PLpgSQLParser::ParseDecl(const nlo } if ((type == DECL_TYPE_ID_DOUBLE) || (type == DECL_TYPE_ID_NUMERIC)) { // TODO(Kyle): type.rfind("numeric") + // TODO(Kyle): Should this support FLOAT and DECMIAL as well?? udf_ast_context_->SetVariableType(var_name, type::TypeId::DECIMAL); return std::make_unique(var_name, type::TypeId::DECIMAL, std::move(initial)); } From b6df7216737430505c4f515d49a5fab2f69cc319 Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Thu, 5 Aug 2021 15:21:02 -0400 Subject: [PATCH 095/139] add support for direct assignment of query results to scalar values, found a bug in implementation that results from a failure to disambiguate global structures --- script/testing/junit/sql/udf.sql | 31 ++ script/testing/junit/traces/udf.test | 36 ++ .../parser/expression/subquery_expression.h | 5 +- src/include/parser/select_statement.h | 2 +- src/include/parser/udf/plpgsql_parser.h | 88 +++- src/parser/udf/plpgsql_parser.cpp | 381 +++++++++++++----- 6 files changed, 444 insertions(+), 99 deletions(-) diff --git a/script/testing/junit/sql/udf.sql b/script/testing/junit/sql/udf.sql index 09731255d5..b09990b0c5 100644 --- a/script/testing/junit/sql/udf.sql +++ b/script/testing/junit/sql/udf.sql @@ -151,6 +151,21 @@ $$ LANGUAGE PLPGSQL; SELECT sql_select_mutliple_constants(); +-- ---------------------------------------------------------------------------- +-- sql_select_constant_assignment() + +CREATE FUNCTION sql_select_constant_assignment() RETURNS INT AS $$ \ +DECLARE \ + x INT; \ + y INT; \ +BEGIN \ + x = (SELECT 1); \ + y = (SELECT 2); \ + RETURN x + y; \ +END \ +$$ LANGUAGE PLPGSQL; + +SELECT sql_select_constant_assignment(); -- ---------------------------------------------------------------------------- -- sql_embedded_agg_count() @@ -194,6 +209,22 @@ $$ LANGUAGE PLPGSQL; SELECT sql_embedded_agg_max(); +-- ---------------------------------------------------------------------------- +-- sql_embedded_agg_multi() + +-- CREATE FUNCTION sql_embedded_agg_multi() RETURNS INT AS $$ \ +-- DECLARE \ +-- s INT; \ +-- minimum INT; \ +-- maximum INT; \ +-- BEGIN \ +-- minimum = (SELECT MIN(x) FROM integers); \ +-- maximum = (SELECT MAX(x) FROM integers); \ +-- s = minumum + maximum; \ +-- RETURN s; \ +-- END; \ +-- $$ LANGUAGE PLPGSQL; + -- ---------------------------------------------------------------------------- -- proc_fors() diff --git a/script/testing/junit/traces/udf.test b/script/testing/junit/traces/udf.test index 681a412345..c97aad3d27 100644 --- a/script/testing/junit/traces/udf.test +++ b/script/testing/junit/traces/udf.test @@ -318,6 +318,27 @@ SELECT sql_select_mutliple_constants(); statement ok +statement ok +-- ---------------------------------------------------------------------------- + +statement ok +-- sql_select_constant_assignment() + +statement ok + + +statement ok +CREATE FUNCTION sql_select_constant_assignment() RETURNS INT AS $$ DECLARE x INT; y INT; BEGIN x = (SELECT 1); y = (SELECT 2); RETURN x + y; END $$ LANGUAGE PLPGSQL; + +statement ok + + +query I rowsort +SELECT sql_select_constant_assignment(); +---- +3 + + statement ok @@ -393,6 +414,21 @@ SELECT sql_embedded_agg_max(); statement ok +statement ok +-- ---------------------------------------------------------------------------- + +statement ok +-- sql_embedded_agg_multi() + +statement ok + + +statement ok +-- CREATE FUNCTION sql_embedded_agg_multi() RETURNS INT AS $$ -- DECLARE -- s INT; -- minimum INT; -- maximum INT; -- BEGIN -- minimum = (SELECT MIN(x) FROM integers); -- maximum = (SELECT MAX(x) FROM integers); -- s = minumum + maximum; -- RETURN s; -- END; -- $$ LANGUAGE PLPGSQL; + +statement ok + + statement ok -- ---------------------------------------------------------------------------- diff --git a/src/include/parser/expression/subquery_expression.h b/src/include/parser/expression/subquery_expression.h index 295ab123c1..ee1b4b685b 100644 --- a/src/include/parser/expression/subquery_expression.h +++ b/src/include/parser/expression/subquery_expression.h @@ -42,9 +42,12 @@ class SubqueryExpression : public AbstractExpression { return Copy(); } - /** @return managed pointer to the sub-select */ + /** @return A non-owning pointer to the sub-select */ common::ManagedPointer GetSubselect() { return common::ManagedPointer(subselect_); } + /** @return An owning pointer to the sub-select */ + std::unique_ptr ReleaseSubselect() { return std::move(subselect_); } + void Accept(common::ManagedPointer v) override { v->Visit(common::ManagedPointer(this)); } /** diff --git a/src/include/parser/select_statement.h b/src/include/parser/select_statement.h index 9bf1f165e7..9a92885f2a 100644 --- a/src/include/parser/select_statement.h +++ b/src/include/parser/select_statement.h @@ -462,7 +462,7 @@ class SelectStatement : public SQLStatement { // The depth of the SELECT statement int depth_{-1}; - // A colletion of the temporary tables (CTEs) available to this SELECT + // A collection of the temporary tables (CTEs) available to this SELECT std::vector> with_table_; /** @param select List of SELECT columns */ diff --git a/src/include/parser/udf/plpgsql_parser.h b/src/include/parser/udf/plpgsql_parser.h index 4f53072c92..e034394cfd 100644 --- a/src/include/parser/udf/plpgsql_parser.h +++ b/src/include/parser/udf/plpgsql_parser.h @@ -1,6 +1,7 @@ #pragma once #include +#include #include #include #include @@ -12,6 +13,11 @@ #include "parser/expression_util.h" #include "parser/postgresparser.h" +#include "parser/sql_statement.h" + +namespace noisepage::parser { +class SQLStatement; +} // namespace noisepage::parser namespace noisepage::execution::ast::udf { class FunctionAST; @@ -19,6 +25,9 @@ class FunctionAST; namespace noisepage::parser::udf { +/** An enumeration over the supported PL/pgSQL statement types */ +enum class StatementType { UNKNOWN, RETURN, IF, ASSIGN, WHILE, FORI, FORS, EXECSQL, DYNEXECUTE }; + /** * The PLpgSQLParser class parses source PL/pgSQL to an abstract syntax tree. * @@ -69,6 +78,13 @@ class PLpgSQLParser { */ std::unique_ptr ParseDecl(const nlohmann::json &json); + /** + * Parse an assignment statement. + * @param json The input JSON object + * @return The AST for the assignment + */ + std::unique_ptr ParseAssign(const nlohmann::json &json); + /** * Parse an if-statement. * @param json The input JSON object @@ -102,7 +118,16 @@ class PLpgSQLParser { * @param json The input JSON object * @return The AST for the SQL statement */ - std::unique_ptr ParseSQL(const nlohmann::json &json); + std::unique_ptr ParseExecSQL(const nlohmann::json &json); + + /** + * Parse a SQL statement. + * @param sql The input SQL query text + * @param variables The collection of variables to which results are bound + * @return The AST for the SQL statement + */ + std::unique_ptr ParseExecSQL(const std::string &sql, + std::vector &&variables); /** * Parse a dynamic SQL statement. @@ -112,11 +137,37 @@ class PLpgSQLParser { std::unique_ptr ParseDynamicSQL(const nlohmann::json &json); /** - * Parse a SQL expression to an expression AST. + * Parse a SQL expression from a query string. * @param sql The SQL expression string * @return The AST for the SQL expression */ - std::unique_ptr ParseExprFromSQL(const std::string &sql); + std::unique_ptr ParseExprFromSQLString(const std::string &sql); + + /** + * Try to parse a SQL expression from a query string. If the expression + * type is not supported, indicate failure with an empty std::optional. + * @param sql The SQL expression string + * @return The AST for the SQL expression on success, empty std::optional on failure + */ + std::optional> TryParseExprFromSQLString( + const std::string &sql) noexcept; + + /** + * Parse a SQL expression from a SQL statement. + * @param statement The SQL statement + * @return The AST for the SQL statement + */ + std::unique_ptr ParseExprFromSQLStatement( + common::ManagedPointer statement); + + /** + * Try to parse an abstract expression from a SQL statement. If the statement + * type is not supported, indicate failure with an empty std::optional. + * @param statement The input SQL statement + * @return The AST for the statement on success, empty std::optional on failure + */ + std::optional> TryParseExprFromSQLStatement( + common::ManagedPointer statement) noexcept; /** * Parse an abstract expression to an expression AST. @@ -126,6 +177,15 @@ class PLpgSQLParser { std::unique_ptr ParseExprFromAbstract( common::ManagedPointer expr); + /** + * Try to parse an abstract expression to an expression AST. If the expression + * type is not supported, indicate failure with an empty std::optional. + * @param expr The input expression + * @return The AST for the expression on success, empty std::optional on failure + */ + std::optional> TryParseExprFromAbstract( + common::ManagedPointer expr) noexcept; + private: /** * Determine if all variables in `names` are declared in the function. @@ -149,6 +209,28 @@ class PLpgSQLParser { */ std::vector> ResolveRecordType(const ParseResult *parse_result); + /** + * Get the StatementType for the provided statement type identifier. + * @param type The identifier for the statement type + * @return The corresponding StatementType + */ + static StatementType GetStatementType(const std::string &type); + + /** + * Strip an enclosing SELECT query from an existing ParseResult. + * @param input The existing ParseResult + * @return A new ParseResult with the enclosing query stripped + */ + static std::unique_ptr StripEnclosingQuery(std::unique_ptr &&input); + + /** + * Determine if the parsed query has an enclosing "wrapper" query + * introduced by the PL/pgSQL parser. + * @param parse_result The parsed query + * @return `true` if the query has an enclosing query, `false` otherwise + */ + static bool HasEnclosingQuery(ParseResult *parse_result); + private: /** The UDF AST context */ common::ManagedPointer udf_ast_context_; diff --git a/src/parser/udf/plpgsql_parser.cpp b/src/parser/udf/plpgsql_parser.cpp index 5a3303c11a..e54315ce4e 100644 --- a/src/parser/udf/plpgsql_parser.cpp +++ b/src/parser/udf/plpgsql_parser.cpp @@ -2,6 +2,7 @@ #include "binder/bind_node_visitor.h" #include "execution/ast/udf/udf_ast_nodes.h" +#include "parser/expression/subquery_expression.h" #include "parser/udf/plpgsql_parse_result.h" #include "parser/udf/plpgsql_parser.h" #include "parser/udf/string_utils.h" @@ -63,14 +64,13 @@ std::unique_ptr PLpgSQLParser::Parse(const std const std::string &func_body) { auto result = PLpgSQLParseResult{pg_query_parse_plpgsql(func_body.c_str())}; if ((*result).error != nullptr) { - throw PARSER_EXCEPTION("PL/pgSQL parsing error"); + throw PARSER_EXCEPTION(fmt::format("PL/pgSQL parser : {}", (*result).error->message)); } + // The result is a list, we need to wrap it const auto ast_json_str = fmt::format("{{ \"{}\" : {} }}", K_FUNCTION_LIST, (*result).plpgsql_funcs); - std::istringstream ss{ast_json_str}; - nlohmann::json ast_json{}; - ss >> ast_json; + const nlohmann::json ast_json = nlohmann::json::parse(ast_json_str); const auto function_list = ast_json[K_FUNCTION_LIST]; NOISEPAGE_ASSERT(function_list.is_array(), "Function list is not an array"); @@ -89,23 +89,22 @@ std::unique_ptr PLpgSQLParser::Parse(const std } std::unique_ptr PLpgSQLParser::ParseFunction(const nlohmann::json &json) { - const auto decl_list = json[K_DATUMS]; - NOISEPAGE_ASSERT(decl_list.is_array(), "Declaration list is not an array"); + const auto declarations = json[K_DATUMS]; + NOISEPAGE_ASSERT(declarations.is_array(), "Declaration list is not an array"); const auto function_body = json[K_ACTION][K_PLPGSQL_STMT_BLOCK][K_BODY]; - std::vector> stmts{}; - - for (std::size_t i = 1UL; i < decl_list.size(); i++) { - stmts.push_back(ParseDecl(decl_list[i])); - } - - stmts.push_back(ParseBlock(function_body)); - return std::make_unique(std::move(stmts)); + std::vector> statements{}; + // Skip the first declaration in the datums list + std::transform(declarations.cbegin() + 1, declarations.cend(), std::back_inserter(statements), + [this](const nlohmann::json &declaration) -> std::unique_ptr { + return ParseDecl(declaration); + }); + statements.push_back(ParseBlock(function_body)); + return std::make_unique(std::move(statements)); } std::unique_ptr PLpgSQLParser::ParseBlock(const nlohmann::json &json) { - // TODO(boweic): Support statements size other than 1 NOISEPAGE_ASSERT(json.is_array(), "Block isn't array"); if (json.empty()) { throw PARSER_EXCEPTION("PL/pgSQL parser : Empty block is not supported"); @@ -113,38 +112,51 @@ std::unique_ptr PLpgSQLParser::ParseBlock(const nl std::vector> statements{}; for (const auto &statement : json) { - const std::string &statement_type = statement.items().begin().key(); - if (statement_type == K_PLPGSQL_STMT_RETURN) { - // TODO(Kyle): Handle RETURN without expression - if (statement[K_PLPGSQL_STMT_RETURN].empty()) { - throw NOT_IMPLEMENTED_EXCEPTION("PL/pgSQL Parser : RETURN without expression not implemented."); + const StatementType statement_type = GetStatementType(statement.items().begin().key()); + switch (statement_type) { + case StatementType::RETURN: { + // TODO(Kyle): Handle RETURN without expression + if (statement[K_PLPGSQL_STMT_RETURN].empty()) { + throw NOT_IMPLEMENTED_EXCEPTION("PL/pgSQL Parser : RETURN without expression not implemented."); + } + auto expr = ParseExprFromSQLString( + statement[K_PLPGSQL_STMT_RETURN][K_EXPR][K_PLPGSQL_EXPR][K_QUERY].get()); + statements.push_back(std::make_unique(std::move(expr))); + break; + } + case StatementType::IF: { + statements.push_back(ParseIf(statement[K_PLPGSQL_STMT_IF])); + break; + } + case StatementType::ASSIGN: { + // TODO(Kyle): Need to fix Assignment expression / statement + statements.push_back(ParseAssign(statement[K_PLPGSQL_STMT_ASSIGN])); + break; + } + case StatementType::WHILE: { + statements.push_back(ParseWhile(statement[K_PLPGSQL_STMT_WHILE])); + break; + } + case StatementType::FORI: { + statements.push_back(ParseForI(statement[K_PLPGSQL_STMT_FORI])); + break; + } + case StatementType::FORS: { + statements.push_back(ParseForS(statement[K_PLPGSQL_STMT_FORS])); + break; + } + case StatementType::EXECSQL: { + statements.push_back(ParseExecSQL(statement[K_PLGPSQL_STMT_EXECSQL])); + break; + } + case StatementType::DYNEXECUTE: { + statements.push_back(ParseDynamicSQL(statement[K_PLPGSQL_STMT_DYNEXECUTE])); + break; + } + case StatementType::UNKNOWN: { + throw PARSER_EXCEPTION( + fmt::format("PL/pgSQL Parser : statement type '{}' not supported", statement.items().begin().key())); } - auto expr = - ParseExprFromSQL(statement[K_PLPGSQL_STMT_RETURN][K_EXPR][K_PLPGSQL_EXPR][K_QUERY].get()); - statements.push_back(std::make_unique(std::move(expr))); - } else if (statement_type == K_PLPGSQL_STMT_IF) { - statements.push_back(ParseIf(statement[K_PLPGSQL_STMT_IF])); - } else if (statement_type == K_PLPGSQL_STMT_ASSIGN) { - // TODO(Kyle): Need to fix Assignment expression / statement - // NOTE(Kyle): We subtract 1 here because variable numbers from - // the Postres parser index from 1 rather than 0 (?) - const auto &var_name = - udf_ast_context_->GetLocalAtIndex(statement[K_PLPGSQL_STMT_ASSIGN][K_VARNO].get() - 1); - auto lhs = std::make_unique(var_name); - auto rhs = ParseExprFromSQL(statement[K_PLPGSQL_STMT_ASSIGN][K_EXPR][K_PLPGSQL_EXPR][K_QUERY].get()); - statements.push_back(std::make_unique(std::move(lhs), std::move(rhs))); - } else if (statement_type == K_PLPGSQL_STMT_WHILE) { - statements.push_back(ParseWhile(statement[K_PLPGSQL_STMT_WHILE])); - } else if (statement_type == K_PLPGSQL_STMT_FORI) { - statements.push_back(ParseForI(statement[K_PLPGSQL_STMT_FORI])); - } else if (statement_type == K_PLPGSQL_STMT_FORS) { - statements.push_back(ParseForS(statement[K_PLPGSQL_STMT_FORS])); - } else if (statement_type == K_PLGPSQL_STMT_EXECSQL) { - statements.push_back(ParseSQL(statement[K_PLGPSQL_STMT_EXECSQL])); - } else if (statement_type == K_PLPGSQL_STMT_DYNEXECUTE) { - statements.push_back(ParseDynamicSQL(statement[K_PLPGSQL_STMT_DYNEXECUTE])); - } else { - throw PARSER_EXCEPTION(fmt::format("PL/pgSQL Parser : statement type '{}' not supported", statement_type)); } } @@ -166,7 +178,7 @@ std::unique_ptr PLpgSQLParser::ParseDecl(const nlo // Parse the initializer, if present std::unique_ptr initial{nullptr}; if (json[K_PLPGSQL_VAR].find(K_DEFAULT_VAL) != json[K_PLPGSQL_VAR].end()) { - initial = ParseExprFromSQL(json[K_PLPGSQL_VAR][K_DEFAULT_VAL][K_PLPGSQL_EXPR][K_QUERY].get()); + initial = ParseExprFromSQLString(json[K_PLPGSQL_VAR][K_DEFAULT_VAL][K_PLPGSQL_EXPR][K_QUERY].get()); } // Detemine if the variable has already been declared; @@ -217,8 +229,33 @@ std::unique_ptr PLpgSQLParser::ParseDecl(const nlo throw PARSER_EXCEPTION(fmt::format("PL/pgSQL Parser : declaration type '{}' not supported", declaration_type)); } +std::unique_ptr PLpgSQLParser::ParseAssign(const nlohmann::json &json) { + // Extract the destination of the assignment + const auto var_index = json[K_VARNO].get() - 1; + const auto &var_name = udf_ast_context_->GetLocalAtIndex(var_index); + + // Attempt to parse the SQL expression to an AST expression + const auto &sql = json[K_EXPR][K_PLPGSQL_EXPR][K_QUERY].get(); + auto rhs = TryParseExprFromSQLString(sql); + if (rhs.has_value()) { + auto lhs = std::make_unique(var_name); + return std::make_unique(std::move(lhs), std::move(*rhs)); + } + + // Failed to parse the SQL expression to an AST expression; + // this could be the result of malformed SQL, OR it could + // be that the SQL is sufficiently complex that we need to + // generate code to execute the query. In this latter case, + // we use the existing infrastructure for executing SQL in + // the UDF body, and "desugar" the assignment to a SELECT INTO. + + // TODO(Kyle): Is this semantically correct? We are hacking + // an assignment expression into a SQL execution statement + return ParseExecSQL(sql, std::vector{var_name}); +} + std::unique_ptr PLpgSQLParser::ParseIf(const nlohmann::json &json) { - auto cond_expr = ParseExprFromSQL(json[K_COND][K_PLPGSQL_EXPR][K_QUERY].get()); + auto cond_expr = ParseExprFromSQLString(json[K_COND][K_PLPGSQL_EXPR][K_QUERY].get()); auto then_stmt = ParseBlock(json[K_THEN_BODY]); std::unique_ptr else_stmt = json.contains(K_ELSE_BODY) ? ParseBlock(json[K_ELSE_BODY]) : nullptr; @@ -227,17 +264,17 @@ std::unique_ptr PLpgSQLParser::ParseIf(const nlohm } std::unique_ptr PLpgSQLParser::ParseWhile(const nlohmann::json &json) { - auto cond_expr = ParseExprFromSQL(json[K_COND][K_PLPGSQL_EXPR][K_QUERY].get()); + auto cond_expr = ParseExprFromSQLString(json[K_COND][K_PLPGSQL_EXPR][K_QUERY].get()); auto body_stmt = ParseBlock(json[K_BODY]); return std::make_unique(std::move(cond_expr), std::move(body_stmt)); } std::unique_ptr PLpgSQLParser::ParseForI(const nlohmann::json &json) { const auto name = json[K_VAR][K_PLPGSQL_VAR][K_REFNAME].get(); - auto lower = ParseExprFromSQL(json[K_LOWER][K_PLPGSQL_EXPR][K_QUERY]); - auto upper = ParseExprFromSQL(json[K_UPPER][K_PLPGSQL_EXPR][K_QUERY]); - auto step = json.contains(K_STEP) ? ParseExprFromSQL(json[K_STEP][K_PLPGSQL_EXPR][K_QUERY]) - : ParseExprFromSQL(execution::ast::udf::ForIStmtAST::DEFAULT_STEP_EXPR); + auto lower = ParseExprFromSQLString(json[K_LOWER][K_PLPGSQL_EXPR][K_QUERY]); + auto upper = ParseExprFromSQLString(json[K_UPPER][K_PLPGSQL_EXPR][K_QUERY]); + auto step = json.contains(K_STEP) ? ParseExprFromSQLString(json[K_STEP][K_PLPGSQL_EXPR][K_QUERY]) + : ParseExprFromSQLString(execution::ast::udf::ForIStmtAST::DEFAULT_STEP_EXPR); auto body = ParseBlock(json[K_BODY]); return std::make_unique(name, std::move(lower), std::move(upper), std::move(step), std::move(body)); @@ -264,28 +301,36 @@ std::unique_ptr PLpgSQLParser::ParseForS(const nlo std::move(body_stmt)); } -std::unique_ptr PLpgSQLParser::ParseSQL(const nlohmann::json &json) { +std::unique_ptr PLpgSQLParser::ParseExecSQL(const nlohmann::json &json) { // The query text - const auto sql_query = json[K_SQLSTMT][K_PLPGSQL_EXPR][K_QUERY].get(); - auto parse_result = PostgresParser::BuildParseTree(sql_query); - if (parse_result == nullptr) { - return nullptr; - } - - auto variable_array = json[K_ROW][K_PLPGSQL_ROW][K_FIELDS]; + const auto sql = json[K_SQLSTMT][K_PLPGSQL_EXPR][K_QUERY].get(); + // The variable(s) to which results are bound + const auto variable_array = json[K_ROW][K_PLPGSQL_ROW][K_FIELDS]; std::vector variables{}; variables.reserve(variable_array.size()); std::transform(variable_array.cbegin(), variable_array.cend(), std::back_inserter(variables), [](const nlohmann::json &var) -> std::string { return var[K_NAME].get(); }); + return ParseExecSQL(sql, std::move(variables)); +} + +std::unique_ptr PLpgSQLParser::ParseExecSQL(const std::string &sql, + std::vector &&variables) { + auto parse_result = StripEnclosingQuery(PostgresParser::BuildParseTree(sql)); + if (!parse_result) { + throw PARSER_EXCEPTION(fmt::format("PL/pgSQL parser : failed to parse query '{}'", sql)); + } + // Ensure all variables to which results are bound are declared if (!AllVariablesDeclared(variables)) { throw PARSER_EXCEPTION("PL/pgSQL parser : variable was not declared"); } - // Two possibilities for binding of results: - // - Exactly one RECORD variable - // - One or more non-RECORD variables + /** + * Two possibilities for binding of results: + * - Exactly one RECORD variable + * - One or more non-RECORD variables + */ if (ContainsRecordType(variables)) { if (variables.size() > 1) { @@ -301,64 +346,128 @@ std::unique_ptr PLpgSQLParser::ParseSQL(const nloh } std::unique_ptr PLpgSQLParser::ParseDynamicSQL(const nlohmann::json &json) { - auto sql_expr = ParseExprFromSQL(json[K_QUERY][K_PLPGSQL_EXPR][K_QUERY].get()); + auto sql_expr = ParseExprFromSQLString(json[K_QUERY][K_PLPGSQL_EXPR][K_QUERY].get()); auto var_name = json[K_ROW][K_PLPGSQL_ROW][K_FIELDS][0][K_NAME].get(); return std::make_unique(std::move(sql_expr), std::move(var_name)); } -std::unique_ptr PLpgSQLParser::ParseExprFromSQL(const std::string &sql) { - auto stmt_list = PostgresParser::BuildParseTree(sql); - if (stmt_list == nullptr) { - return nullptr; +std::unique_ptr PLpgSQLParser::ParseExprFromSQLString(const std::string &sql) { + auto expr = TryParseExprFromSQLString(sql); + if (!expr.has_value()) { + throw PARSER_EXCEPTION("PL/pgSQL parser : failed to parse SQL query"); } - NOISEPAGE_ASSERT(stmt_list->GetStatements().size() == 1, "Bad number of statements"); - auto stmt = stmt_list->GetStatement(0); - NOISEPAGE_ASSERT(stmt->GetType() == parser::StatementType::SELECT, "Unsupported statement type"); - NOISEPAGE_ASSERT(stmt.CastManagedPointerTo()->GetSelectTable() == nullptr, - "Unsupported SQL Expr in UDF"); - auto &select_list = stmt.CastManagedPointerTo()->GetSelectColumns(); - NOISEPAGE_ASSERT(select_list.size() == 1, "Unsupported number of select columns in UDF"); - return PLpgSQLParser::ParseExprFromAbstract(select_list[0]); + return std::move(*expr); +} + +std::optional> PLpgSQLParser::TryParseExprFromSQLString( + const std::string &sql) noexcept { + auto statements = PostgresParser::BuildParseTree(sql); + if (!statements) { + return std::nullopt; + } + + if (statements->GetStatements().size() != 1) { + return std::nullopt; + } + return TryParseExprFromSQLStatement(statements->GetStatement(0)); +} + +std::unique_ptr PLpgSQLParser::ParseExprFromSQLStatement( + common::ManagedPointer statement) { + auto expr = TryParseExprFromSQLStatement(statement); + if (!expr.has_value()) { + throw PARSER_EXCEPTION("PL/pgSQL parser : failed to parse SQL statement"); + } + return std::move(*expr); +} + +std::optional> PLpgSQLParser::TryParseExprFromSQLStatement( + common::ManagedPointer statement) noexcept { + if (statement->GetType() != parser::StatementType::SELECT) { + return std::nullopt; + } + + auto select = statement.CastManagedPointerTo(); + if (select->GetSelectTable() != nullptr || select->GetSelectColumns().size() != 1) { + return std::nullopt; + } + return TryParseExprFromAbstract(select->GetSelectColumns().at(0)); } std::unique_ptr PLpgSQLParser::ParseExprFromAbstract( common::ManagedPointer expr) { - if (expr->GetExpressionType() == parser::ExpressionType::COLUMN_VALUE) { - auto cve = expr.CastManagedPointerTo(); - if (cve->GetTableName().empty()) { - return std::make_unique(cve->GetColumnName()); - } - auto vexpr = std::make_unique(cve->GetTableName()); - return std::make_unique(std::move(vexpr), cve->GetColumnName()); + auto result = TryParseExprFromAbstract(expr); + if (!result.has_value()) { + throw PARSER_EXCEPTION(fmt::format("PL/pgSQL parser : expression type '{}' not supported", + parser::ExpressionTypeToShortString(expr->GetExpressionType()))); } + return std::move(*result); +} +std::optional> PLpgSQLParser::TryParseExprFromAbstract( + common::ManagedPointer expr) noexcept { if ((parser::ExpressionUtil::IsOperatorExpression(expr->GetExpressionType()) && expr->GetChildrenSize() == 2) || (parser::ExpressionUtil::IsComparisonExpression(expr->GetExpressionType()))) { - return std::make_unique( - expr->GetExpressionType(), ParseExprFromAbstract(expr->GetChild(0)), ParseExprFromAbstract(expr->GetChild(1))); + auto lhs = TryParseExprFromAbstract(expr->GetChild(0)); + auto rhs = TryParseExprFromAbstract(expr->GetChild(1)); + if (!lhs.has_value() || !rhs.has_value()) { + return std::nullopt; + } + return std::make_optional(std::make_unique(expr->GetExpressionType(), + std::move(*lhs), std::move(*rhs))); } // TODO(Kyle): I am not a fan of non-exhaustive switch statements; // is there a way that we can refactor this logic to make it better? switch (expr->GetExpressionType()) { + case parser::ExpressionType::COLUMN_VALUE: { + auto cve = expr.CastManagedPointerTo(); + if (cve->GetTableName().empty()) { + return std::make_optional(std::make_unique(cve->GetColumnName())); + } + auto vexpr = std::make_unique(cve->GetTableName()); + return std::make_optional( + std::make_unique(std::move(vexpr), cve->GetColumnName())); + } case parser::ExpressionType::FUNCTION: { auto func_expr = expr.CastManagedPointerTo(); std::vector> args{}; auto num_args = func_expr->GetChildrenSize(); for (std::size_t idx = 0; idx < num_args; ++idx) { - args.push_back(ParseExprFromAbstract(func_expr->GetChild(idx))); + auto arg = TryParseExprFromAbstract(func_expr->GetChild(idx)); + if (!arg.has_value()) { + return std::nullopt; + } + args.push_back(std::move(*arg)); } - return std::make_unique(func_expr->GetFuncName(), std::move(args)); + return std::make_optional( + std::make_unique(func_expr->GetFuncName(), std::move(args))); } case parser::ExpressionType::VALUE_CONSTANT: - return std::make_unique(expr->Copy()); - case parser::ExpressionType::OPERATOR_IS_NOT_NULL: - return std::make_unique(false, ParseExprFromAbstract(expr->GetChild(0))); - case parser::ExpressionType::OPERATOR_IS_NULL: - return std::make_unique(true, ParseExprFromAbstract(expr->GetChild(0))); + return std::make_optional(std::make_unique(expr->Copy())); + case parser::ExpressionType::OPERATOR_IS_NOT_NULL: { + auto target = TryParseExprFromAbstract(expr->GetChild(0)); + if (!target.has_value()) { + return std::nullopt; + } + return std::make_optional(std::make_unique(false, std::move(*target))); + } + case parser::ExpressionType::OPERATOR_IS_NULL: { + auto target = TryParseExprFromAbstract(expr->GetChild(0)); + if (!target.has_value()) { + return std::nullopt; + } + return std::make_optional(std::make_unique(true, std::move(*target))); + } + case parser::ExpressionType::ROW_SUBQUERY: { + // We can handle subqeries, but only in the event + // that they are shallow "wrappers" around simple queries + auto subquery_expr = expr.CastManagedPointerTo(); + return TryParseExprFromSQLStatement(subquery_expr->GetSubselect().CastManagedPointerTo()); + } default: - throw PARSER_EXCEPTION("PL/pgSQL parser : expression type not supported"); + return std::nullopt; } } @@ -385,4 +494,88 @@ std::vector> PLpgSQLParser::ResolveRecordTy return fields; } +StatementType PLpgSQLParser::GetStatementType(const std::string &type) { + if (type == K_PLPGSQL_STMT_RETURN) { + return StatementType::RETURN; + } else if (type == K_PLPGSQL_STMT_IF) { + return StatementType::IF; + } else if (type == K_PLPGSQL_STMT_ASSIGN) { + return StatementType::ASSIGN; + } else if (type == K_PLPGSQL_STMT_WHILE) { + return StatementType::WHILE; + } else if (type == K_PLPGSQL_STMT_FORI) { + return StatementType::FORI; + } else if (type == K_PLPGSQL_STMT_FORS) { + return StatementType::FORS; + } else if (type == K_PLGPSQL_STMT_EXECSQL) { + return StatementType::EXECSQL; + } else if (type == K_PLPGSQL_STMT_DYNEXECUTE) { + return StatementType::DYNEXECUTE; + } else { + return StatementType::UNKNOWN; + } +} + +// Static +std::unique_ptr PLpgSQLParser::StripEnclosingQuery(std::unique_ptr &&input) { + NOISEPAGE_ASSERT(input->GetStatements().size() == 1, "Must have a single SQL statement"); + + // If the query does not match the target pattern, return unmodified + if (!HasEnclosingQuery(input.get())) { + return std::move(input); + } + + // The query consists of enclosing SELECT around a + // subquery that implements the actual logic we want; + // now we perform some surgery on the ParseResult + + // Grab the SELECT from the subquery expression + auto statement = input->GetStatement(0); + auto select = statement.CastManagedPointerTo(); + auto subquery = select->GetSelectColumns().at(0).CastManagedPointerTo(); + + // Here, we take ownership of the new top-level statement for the query; + // the SELECT does not own its own target expressions, however, so we + // need to ensure that we manually copy these over into the new ParseResult + // such that their data is still alive after the transformation + auto subselect = subquery->ReleaseSubselect(); + + // Take ownership of the expressions we want; it is important + // that we actually take ownership of the existing expressions + // rather than making a copy of the collection because the + // remainder of the statements in the query hold non-owning + // pointers to these existing expressions, copies will result + // in dangling pointers in any number of the query statements + auto expressions = input->TakeExpressionsOwnership(); + expressions.erase(std::remove_if(expressions.begin(), expressions.end(), + [](const std::unique_ptr &expr) { + return expr->GetExpressionType() == parser::ExpressionType::ROW_SUBQUERY; + })); + + auto output = std::make_unique(); + output->AddStatement(std::move(subselect)); + for (auto &expression : expressions) { + output->AddExpression(std::move(expression)); + } + + // The input ParseResult is dropped here, so we need to be sure + // that we have extracted all of the data that we want out of it + + return output; +} + +bool PLpgSQLParser::HasEnclosingQuery(ParseResult *parse_result) { + NOISEPAGE_ASSERT(parse_result->GetStatements().size() == 1, "Must have a single SQL statement"); + auto statement = parse_result->GetStatement(0); + if (statement->GetType() != parser::StatementType::SELECT) { + return false; + } + auto select = statement.CastManagedPointerTo(); + if (select->GetSelectColumns().size() > 1) { + return false; + } + auto target = select->GetSelectColumns().at(0); + return (target->GetExpressionType() == parser::ExpressionType::ROW_SUBQUERY); +} + } // namespace noisepage::parser::udf From 964448ddea62fa383c41f900793513422154de7c Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Thu, 5 Aug 2021 22:44:22 -0400 Subject: [PATCH 096/139] add support for multiple embedded queries within UDF body, not particularly happy with the solution but it is certainly the most expedient --- script/testing/junit/sql/udf.sql | 22 ++++++++---------- script/testing/junit/traces/udf.test | 2 +- .../compiler/compilation_context.cpp | 2 +- .../compiler/operator/operator_translator.cpp | 12 ++++++++++ .../static_aggregation_translator.cpp | 11 +++++---- .../compiler/operator/operator_translator.h | 14 +++++++++++ src/parser/udf/plpgsql_parser.cpp | 23 +++++++++++-------- 7 files changed, 58 insertions(+), 28 deletions(-) diff --git a/script/testing/junit/sql/udf.sql b/script/testing/junit/sql/udf.sql index b09990b0c5..7a94829502 100644 --- a/script/testing/junit/sql/udf.sql +++ b/script/testing/junit/sql/udf.sql @@ -212,18 +212,16 @@ SELECT sql_embedded_agg_max(); -- ---------------------------------------------------------------------------- -- sql_embedded_agg_multi() --- CREATE FUNCTION sql_embedded_agg_multi() RETURNS INT AS $$ \ --- DECLARE \ --- s INT; \ --- minimum INT; \ --- maximum INT; \ --- BEGIN \ --- minimum = (SELECT MIN(x) FROM integers); \ --- maximum = (SELECT MAX(x) FROM integers); \ --- s = minumum + maximum; \ --- RETURN s; \ --- END; \ --- $$ LANGUAGE PLPGSQL; +CREATE FUNCTION sql_embedded_agg_multi() RETURNS INT AS $$ \ +DECLARE \ + minimum INT; \ + maximum INT; \ +BEGIN \ + minimum = (SELECT MIN(x) FROM integers); \ + maximum = (SELECT MAX(x) FROM integers); \ + RETURN minimum + maximum; \ +END; \ +$$ LANGUAGE PLPGSQL; -- ---------------------------------------------------------------------------- -- proc_fors() diff --git a/script/testing/junit/traces/udf.test b/script/testing/junit/traces/udf.test index c97aad3d27..ca286fea25 100644 --- a/script/testing/junit/traces/udf.test +++ b/script/testing/junit/traces/udf.test @@ -424,7 +424,7 @@ statement ok statement ok --- CREATE FUNCTION sql_embedded_agg_multi() RETURNS INT AS $$ -- DECLARE -- s INT; -- minimum INT; -- maximum INT; -- BEGIN -- minimum = (SELECT MIN(x) FROM integers); -- maximum = (SELECT MAX(x) FROM integers); -- s = minumum + maximum; -- RETURN s; -- END; -- $$ LANGUAGE PLPGSQL; +CREATE FUNCTION sql_embedded_agg_multi() RETURNS INT AS $$ DECLARE minimum INT; maximum INT; BEGIN minimum = (SELECT MIN(x) FROM integers); maximum = (SELECT MAX(x) FROM integers); RETURN minimum + maximum; END; $$ LANGUAGE PLPGSQL; statement ok diff --git a/src/execution/compiler/compilation_context.cpp b/src/execution/compiler/compilation_context.cpp index f6c5f857be..af9a3e345b 100644 --- a/src/execution/compiler/compilation_context.cpp +++ b/src/execution/compiler/compilation_context.cpp @@ -444,7 +444,7 @@ ExpressionTranslator *CompilationContext::LookupTranslator(const parser::Abstrac std::string CompilationContext::GetFunctionPrefix() const { // If an output callback is present, we prefix // each function with the callback name - if (output_callback_ != nullptr) { + if (HasOutputCallback()) { return fmt::format("{}Query{}", output_callback_->GetName().GetString(), std::to_string(unique_id_)); } return fmt::format("Query{}", std::to_string(unique_id_)); diff --git a/src/execution/compiler/operator/operator_translator.cpp b/src/execution/compiler/operator/operator_translator.cpp index 071ce93fd8..e4d36d4776 100644 --- a/src/execution/compiler/operator/operator_translator.cpp +++ b/src/execution/compiler/operator/operator_translator.cpp @@ -88,6 +88,18 @@ ast::Expr *OperatorTranslator::GetMemoryPool() const { return GetCodeGen()->ExecCtxGetMemoryPool(GetExecutionContext()); } +ast::Identifier OperatorTranslator::MakeLocalIdentifier(std::string_view name) const { + const auto identifier = fmt::format("{}", name); + return GetCodeGen()->MakeFreshIdentifier(identifier); +} + +ast::Identifier OperatorTranslator::MakeGlobalIdentifier(std::string_view name) const { + const auto identifier = GetCompilationContext()->HasOutputCallback() + ? fmt::format("{}{}", GetCompilationContext()->GetFunctionPrefix(), name) + : fmt::format("{}", name); + return GetCodeGen()->MakeFreshIdentifier(identifier); +} + void OperatorTranslator::GetAllChildOutputFields(const uint32_t child_index, const std::string &field_name_prefix, util::RegionVector *fields) const { auto *codegen = GetCodeGen(); diff --git a/src/execution/compiler/operator/static_aggregation_translator.cpp b/src/execution/compiler/operator/static_aggregation_translator.cpp index 77fd01787f..43a81266b0 100644 --- a/src/execution/compiler/operator/static_aggregation_translator.cpp +++ b/src/execution/compiler/operator/static_aggregation_translator.cpp @@ -12,15 +12,18 @@ namespace noisepage::execution::compiler { namespace { constexpr char AGG_ATTR_PREFIX[] = "agg_term_attr"; +constexpr char AGG_ROW_VAR[] = "aggRow"; +constexpr char AGG_PAYLOAD_TYPE[] = "AggPayload"; +constexpr char AGG_VALUES_TYPE[] = "AggValues"; +constexpr char AGG_MERGE_FUNC[] = "MergeAggregates"; } // namespace StaticAggregationTranslator::StaticAggregationTranslator(const planner::AggregatePlanNode &plan, CompilationContext *compilation_context, Pipeline *pipeline) : OperatorTranslator(plan, compilation_context, pipeline, selfdriving::ExecutionOperatingUnitType::DUMMY), - agg_row_var_(GetCodeGen()->MakeFreshIdentifier("aggRow")), - agg_payload_type_(GetCodeGen()->MakeFreshIdentifier("AggPayload")), - agg_values_type_(GetCodeGen()->MakeFreshIdentifier("AggValues")), - merge_func_(GetCodeGen()->MakeFreshIdentifier("MergeAggregates")), + agg_row_var_(GetCodeGen()->MakeFreshIdentifier(AGG_ROW_VAR)), + agg_payload_type_(MakeGlobalIdentifier(AGG_PAYLOAD_TYPE)), + agg_values_type_(MakeGlobalIdentifier(AGG_VALUES_TYPE)), build_pipeline_(this, Pipeline::Parallelism::Parallel) { NOISEPAGE_ASSERT(plan.GetGroupByTerms().empty(), "Global aggregations shouldn't have grouping keys"); NOISEPAGE_ASSERT(plan.GetChildrenSize() == 1, "Global aggregations should only have one child"); diff --git a/src/include/execution/compiler/operator/operator_translator.h b/src/include/execution/compiler/operator/operator_translator.h index f35ce91468..7309cf52b5 100644 --- a/src/include/execution/compiler/operator/operator_translator.h +++ b/src/include/execution/compiler/operator/operator_translator.h @@ -273,6 +273,20 @@ class OperatorTranslator : public ColumnValueProvider { /** @return The pipeline this translator is a part of. */ Pipeline *GetPipeline() const { return pipeline_; } + /** + * Make a local identifier from `name`. + * @param name The base name for the identifier + * @return The identifier + */ + ast::Identifier MakeLocalIdentifier(std::string_view name) const; + + /** + * Make a global identifier from `name`. + * @param name The base name for the identifier + * @return The identifier + */ + ast::Identifier MakeGlobalIdentifier(std::string_view name) const; + /** The plan node for this translator as its concrete type. */ template const T &GetPlanAs() const { diff --git a/src/parser/udf/plpgsql_parser.cpp b/src/parser/udf/plpgsql_parser.cpp index e54315ce4e..c535970f65 100644 --- a/src/parser/udf/plpgsql_parser.cpp +++ b/src/parser/udf/plpgsql_parser.cpp @@ -495,25 +495,27 @@ std::vector> PLpgSQLParser::ResolveRecordTy } StatementType PLpgSQLParser::GetStatementType(const std::string &type) { + StatementType stmt_type; if (type == K_PLPGSQL_STMT_RETURN) { - return StatementType::RETURN; + stmt_type = StatementType::RETURN; } else if (type == K_PLPGSQL_STMT_IF) { - return StatementType::IF; + stmt_type = StatementType::IF; } else if (type == K_PLPGSQL_STMT_ASSIGN) { - return StatementType::ASSIGN; + stmt_type = StatementType::ASSIGN; } else if (type == K_PLPGSQL_STMT_WHILE) { - return StatementType::WHILE; + stmt_type = StatementType::WHILE; } else if (type == K_PLPGSQL_STMT_FORI) { - return StatementType::FORI; + stmt_type = StatementType::FORI; } else if (type == K_PLPGSQL_STMT_FORS) { - return StatementType::FORS; + stmt_type = StatementType::FORS; } else if (type == K_PLGPSQL_STMT_EXECSQL) { - return StatementType::EXECSQL; + stmt_type = StatementType::EXECSQL; } else if (type == K_PLPGSQL_STMT_DYNEXECUTE) { - return StatementType::DYNEXECUTE; + stmt_type = StatementType::DYNEXECUTE; } else { - return StatementType::UNKNOWN; + stmt_type = StatementType::UNKNOWN; } + return stmt_type; } // Static @@ -550,7 +552,8 @@ std::unique_ptr PLpgSQLParser::StripEnclosingQuery(std::uni expressions.erase(std::remove_if(expressions.begin(), expressions.end(), [](const std::unique_ptr &expr) { return expr->GetExpressionType() == parser::ExpressionType::ROW_SUBQUERY; - })); + }), + expressions.end()); auto output = std::make_unique(); output->AddStatement(std::move(subselect)); From bc795eaf55839694e4bebcbd00e4ae58bfeb2513 Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Fri, 6 Aug 2021 15:50:29 -0400 Subject: [PATCH 097/139] integration tests for query-fed for loop construct, passing --- script/testing/junit/sql/udf.sql | 42 +++--- script/testing/junit/traces/udf.test | 53 +++++-- src/execution/compiler/udf/udf_codegen.cpp | 166 +++++++++++---------- 3 files changed, 150 insertions(+), 111 deletions(-) diff --git a/script/testing/junit/sql/udf.sql b/script/testing/junit/sql/udf.sql index 7a94829502..bfd4f1b813 100644 --- a/script/testing/junit/sql/udf.sql +++ b/script/testing/junit/sql/udf.sql @@ -224,10 +224,7 @@ END; \ $$ LANGUAGE PLPGSQL; -- ---------------------------------------------------------------------------- --- proc_fors() - --- CREATE TABLE tmp(z INT); --- INSERT INTO tmp(z) VALUES (0), (1); +-- proc_fors_constant_var() -- Select constant into a scalar variable CREATE FUNCTION proc_fors_constant_var() RETURNS INT AS $$ \ @@ -244,6 +241,9 @@ $$ LANGUAGE PLPGSQL; SELECT proc_fors_constant_var(); +-- ---------------------------------------------------------------------------- +-- proc_fors_constant_vars() + -- Select multiple constants in scalar variables CREATE FUNCTION proc_fors_constant_vars() RETURNS INT AS $$ \ DECLARE \ @@ -260,6 +260,11 @@ $$ LANGUAGE PLPGSQL; SELECT proc_fors_constant_vars(); +-- ---------------------------------------------------------------------------- +-- proc_fors_rec() +-- +-- TODO(Kyle): RECORD types not supported + -- -- Bind query result to a RECORD type -- CREATE FUNCTION proc_fors_rec() RETURNS INT AS $$ \ -- DECLARE \ @@ -275,17 +280,20 @@ SELECT proc_fors_constant_vars(); -- SELECT proc_fors_rec() FROM integers; --- -- Bind query result directly to INT type --- CREATE FUNCTION proc_fors_var() RETURNS INT AS $$ \ --- DECLARE \ --- x INT := 0; \ --- v INT; \ --- BEGIN \ --- FOR v IN (SELECT z FROM tmp) LOOP \ --- x = x + 1; \ --- END LOOP; \ --- RETURN x; \ --- END \ --- $$ LANGUAGE PLPGSQL; +-- ---------------------------------------------------------------------------- +-- proc_fors_var() + +-- Bind query result directly to INT type +CREATE FUNCTION proc_fors_var() RETURNS INT AS $$ \ +DECLARE \ + c INT := 0; \ + v INT; \ +BEGIN \ + FOR v IN (SELECT x FROM integers) LOOP \ + c = c + 1; \ + END LOOP; \ + RETURN c; \ +END \ +$$ LANGUAGE PLPGSQL; --- SELECT proc_fors_var() FROM integers; +SELECT proc_fors_var(); diff --git a/script/testing/junit/traces/udf.test b/script/testing/junit/traces/udf.test index ca286fea25..96de75061a 100644 --- a/script/testing/junit/traces/udf.test +++ b/script/testing/junit/traces/udf.test @@ -433,16 +433,7 @@ statement ok -- ---------------------------------------------------------------------------- statement ok --- proc_fors() - -statement ok - - -statement ok --- CREATE TABLE tmp(z INT); - -statement ok --- INSERT INTO tmp(z) VALUES (0), (1); +-- proc_fors_constant_var() statement ok @@ -465,6 +456,15 @@ SELECT proc_fors_constant_var(); statement ok +statement ok +-- ---------------------------------------------------------------------------- + +statement ok +-- proc_fors_constant_vars() + +statement ok + + statement ok -- Select multiple constants in scalar variables @@ -483,6 +483,21 @@ SELECT proc_fors_constant_vars(); statement ok +statement ok +-- ---------------------------------------------------------------------------- + +statement ok +-- proc_fors_rec() + +statement ok +-- + +statement ok +-- TODO(Kyle): RECORD types not supported + +statement ok + + statement ok -- -- Bind query result to a RECORD type @@ -502,14 +517,26 @@ statement ok statement ok --- -- Bind query result directly to INT type +-- ---------------------------------------------------------------------------- statement ok --- CREATE FUNCTION proc_fors_var() RETURNS INT AS $$ -- DECLARE -- x INT := 0; -- v INT; -- BEGIN -- FOR v IN (SELECT z FROM tmp) LOOP -- x = x + 1; -- END LOOP; -- RETURN x; -- END -- $$ LANGUAGE PLPGSQL; +-- proc_fors_var() statement ok statement ok --- SELECT proc_fors_var() FROM integers; +-- Bind query result directly to INT type + +statement ok +CREATE FUNCTION proc_fors_var() RETURNS INT AS $$ DECLARE c INT := 0; v INT; BEGIN FOR v IN (SELECT x FROM integers) LOOP c = c + 1; END LOOP; RETURN c; END $$ LANGUAGE PLPGSQL; + +statement ok + + +query I rowsort +SELECT proc_fors_var(); +---- +3 + diff --git a/src/execution/compiler/udf/udf_codegen.cpp b/src/execution/compiler/udf/udf_codegen.cpp index add41f2e68..352122d21e 100644 --- a/src/execution/compiler/udf/udf_codegen.cpp +++ b/src/execution/compiler/udf/udf_codegen.cpp @@ -39,15 +39,15 @@ UdfCodegen::UdfCodegen(catalog::CatalogAccessor *accessor, FunctionBuilder *fb, needs_exec_ctx_{false} { for (auto i = 0UL; fb->GetParameterByPosition(i) != nullptr; ++i) { auto param = fb->GetParameterByPosition(i); - const auto &name = param->As()->Name(); + const auto &name = param->As()->Name(); SymbolTable()[name.GetString()] = name; } } // Static -execution::ast::File *UdfCodegen::Run(catalog::CatalogAccessor *accessor, FunctionBuilder *function_builder, - ast::udf::UdfAstContext *ast_context, CodeGen *codegen, catalog::db_oid_t db_oid, - ast::udf::FunctionAST *root) { +ast::File *UdfCodegen::Run(catalog::CatalogAccessor *accessor, FunctionBuilder *function_builder, + ast::udf::UdfAstContext *ast_context, CodeGen *codegen, catalog::db_oid_t db_oid, + ast::udf::FunctionAST *root) { UdfCodegen generator{accessor, function_builder, ast_context, codegen, db_oid}; generator.GenerateUDF(root->Body()); return generator.Finish(); @@ -58,7 +58,7 @@ const char *UdfCodegen::GetReturnParamString() { return "return_val"; } void UdfCodegen::GenerateUDF(ast::udf::AbstractAST *ast) { ast->Accept(this); } -catalog::type_oid_t UdfCodegen::GetCatalogTypeOidFromSQLType(execution::ast::BuiltinType::Kind type) { +catalog::type_oid_t UdfCodegen::GetCatalogTypeOidFromSQLType(ast::BuiltinType::Kind type) { switch (type) { case ast::BuiltinType::Kind::Integer: { return accessor_->GetTypeOidFromTypeId(type::TypeId::INTEGER); @@ -93,15 +93,15 @@ void UdfCodegen::Visit(ast::udf::DynamicSQLStmtAST *ast) { } void UdfCodegen::Visit(ast::udf::CallExprAST *ast) { - std::vector args_ast{}; - std::vector args_ast_region_vec{}; + std::vector args_ast{}; + std::vector args_ast_region_vec{}; std::vector arg_types{}; for (auto &arg : ast->Args()) { arg->Accept(this); args_ast.push_back(dst_); args_ast_region_vec.push_back(dst_); - auto *builtin = dst_->GetType()->SafeAs(); + auto *builtin = dst_->GetType()->SafeAs(); NOISEPAGE_ASSERT(builtin != nullptr, "Parameter must be a built-in type"); NOISEPAGE_ASSERT(builtin->IsSqlValueType(), "Parameter must be a SQL value type"); arg_types.push_back(GetCatalogTypeOidFromSQLType(builtin->GetKind())); @@ -114,13 +114,13 @@ void UdfCodegen::Visit(ast::udf::CallExprAST *ast) { fb_->Append(codegen_->MakeStmt(codegen_->CallBuiltin(context->GetBuiltin(), args_ast))); } else { auto it = SymbolTable().find(ast->Callee()); - execution::ast::Identifier ident_expr; + ast::Identifier ident_expr; if (it != SymbolTable().end()) { ident_expr = it->second; } else { - auto file = reinterpret_cast( - execution::ast::AstClone::Clone(context->GetFile(), codegen_->GetAstContext()->GetNodeFactory(), - context->GetASTContext(), codegen_->GetAstContext().Get())); + auto file = reinterpret_cast( + ast::AstClone::Clone(context->GetFile(), codegen_->GetAstContext()->GetNodeFactory(), + context->GetASTContext(), codegen_->GetAstContext().Get())); for (auto decl : file->Declarations()) { aux_decls_.push_back(decl); } @@ -140,14 +140,14 @@ void UdfCodegen::Visit(ast::udf::DeclStmtAST *ast) { return; } - const execution::ast::Identifier identifier = codegen_->MakeFreshIdentifier(ast->Name()); + const ast::Identifier identifier = codegen_->MakeFreshIdentifier(ast->Name()); SymbolTable()[ast->Name()] = identifier; auto prev_type = current_type_; - execution::ast::Expr *tpl_type = nullptr; + ast::Expr *tpl_type = nullptr; if (ast->Type() == type::TypeId::INVALID) { // Record type - execution::util::RegionVector fields{codegen_->GetAstContext()->GetRegion()}; + util::RegionVector fields{codegen_->GetAstContext()->GetRegion()}; // TODO(Kyle): Handle unbound record types const auto record_type = udf_ast_context_->GetRecordType(ast->Name()); @@ -157,14 +157,14 @@ void UdfCodegen::Visit(ast::udf::DeclStmtAST *ast) { } for (const auto &p : record_type.value()) { - fields.push_back(codegen_->MakeField(codegen_->MakeIdentifier(p.first), - codegen_->TplType(execution::sql::GetTypeId(p.second)))); + fields.push_back( + codegen_->MakeField(codegen_->MakeIdentifier(p.first), codegen_->TplType(sql::GetTypeId(p.second)))); } auto record_decl = codegen_->DeclareStruct(codegen_->MakeFreshIdentifier("rectype"), std::move(fields)); aux_decls_.push_back(record_decl); tpl_type = record_decl->TypeRepr(); } else { - tpl_type = codegen_->TplType(execution::sql::GetTypeId(ast->Type())); + tpl_type = codegen_->TplType(sql::GetTypeId(ast->Type())); } current_type_ = ast->Type(); if (ast->Initial() != nullptr) { @@ -195,27 +195,27 @@ void UdfCodegen::Visit(ast::udf::ValueExprAST *ast) { dst_ = codegen_->ConstNull(current_type_); return; } - auto type_id = execution::sql::GetTypeId(val->GetReturnValueType()); + auto type_id = sql::GetTypeId(val->GetReturnValueType()); switch (type_id) { - case execution::sql::TypeId::Boolean: + case sql::TypeId::Boolean: dst_ = codegen_->BoolToSql(val->GetBoolVal().val_); break; - case execution::sql::TypeId::TinyInt: - case execution::sql::TypeId::SmallInt: - case execution::sql::TypeId::Integer: - case execution::sql::TypeId::BigInt: + case sql::TypeId::TinyInt: + case sql::TypeId::SmallInt: + case sql::TypeId::Integer: + case sql::TypeId::BigInt: dst_ = codegen_->IntToSql(val->GetInteger().val_); break; - case execution::sql::TypeId::Float: - case execution::sql::TypeId::Double: + case sql::TypeId::Float: + case sql::TypeId::Double: dst_ = codegen_->FloatToSql(val->GetReal().val_); - case execution::sql::TypeId::Date: + case sql::TypeId::Date: dst_ = codegen_->DateToSql(val->GetDateVal().val_); break; - case execution::sql::TypeId::Timestamp: + case sql::TypeId::Timestamp: dst_ = codegen_->TimestampToSql(val->GetTimestampVal().val_); break; - case execution::sql::TypeId::Varchar: + case sql::TypeId::Varchar: dst_ = codegen_->StringToSql(val->GetStringVal().StringView()); break; default: @@ -239,49 +239,49 @@ void UdfCodegen::Visit(ast::udf::AssignStmtAST *ast) { } void UdfCodegen::Visit(ast::udf::BinaryExprAST *ast) { - execution::parsing::Token::Type op_token; + parsing::Token::Type op_token; bool compare = false; switch (ast->Op()) { - case noisepage::parser::ExpressionType::OPERATOR_DIVIDE: - op_token = execution::parsing::Token::Type::SLASH; + case parser::ExpressionType::OPERATOR_DIVIDE: + op_token = parsing::Token::Type::SLASH; break; - case noisepage::parser::ExpressionType::OPERATOR_PLUS: - op_token = execution::parsing::Token::Type::PLUS; + case parser::ExpressionType::OPERATOR_PLUS: + op_token = parsing::Token::Type::PLUS; break; - case noisepage::parser::ExpressionType::OPERATOR_MINUS: - op_token = execution::parsing::Token::Type::MINUS; + case parser::ExpressionType::OPERATOR_MINUS: + op_token = parsing::Token::Type::MINUS; break; - case noisepage::parser::ExpressionType::OPERATOR_MULTIPLY: - op_token = execution::parsing::Token::Type::STAR; + case parser::ExpressionType::OPERATOR_MULTIPLY: + op_token = parsing::Token::Type::STAR; break; - case noisepage::parser::ExpressionType::OPERATOR_MOD: - op_token = execution::parsing::Token::Type::PERCENT; + case parser::ExpressionType::OPERATOR_MOD: + op_token = parsing::Token::Type::PERCENT; break; - case noisepage::parser::ExpressionType::CONJUNCTION_OR: - op_token = execution::parsing::Token::Type::OR; + case parser::ExpressionType::CONJUNCTION_OR: + op_token = parsing::Token::Type::OR; break; - case noisepage::parser::ExpressionType::CONJUNCTION_AND: - op_token = execution::parsing::Token::Type::AND; + case parser::ExpressionType::CONJUNCTION_AND: + op_token = parsing::Token::Type::AND; break; - case noisepage::parser::ExpressionType::COMPARE_GREATER_THAN: + case parser::ExpressionType::COMPARE_GREATER_THAN: compare = true; - op_token = execution::parsing::Token::Type::GREATER; + op_token = parsing::Token::Type::GREATER; break; - case noisepage::parser::ExpressionType::COMPARE_GREATER_THAN_OR_EQUAL_TO: + case parser::ExpressionType::COMPARE_GREATER_THAN_OR_EQUAL_TO: compare = true; - op_token = execution::parsing::Token::Type::GREATER_EQUAL; + op_token = parsing::Token::Type::GREATER_EQUAL; break; - case noisepage::parser::ExpressionType::COMPARE_LESS_THAN_OR_EQUAL_TO: + case parser::ExpressionType::COMPARE_LESS_THAN_OR_EQUAL_TO: compare = true; - op_token = execution::parsing::Token::Type::LESS_EQUAL; + op_token = parsing::Token::Type::LESS_EQUAL; break; - case noisepage::parser::ExpressionType::COMPARE_LESS_THAN: + case parser::ExpressionType::COMPARE_LESS_THAN: compare = true; - op_token = execution::parsing::Token::Type::LESS; + op_token = parsing::Token::Type::LESS; break; - case noisepage::parser::ExpressionType::COMPARE_EQUAL: + case parser::ExpressionType::COMPARE_EQUAL: compare = true; - op_token = execution::parsing::Token::Type::EQUAL_EQUAL; + op_token = parsing::Token::Type::EQUAL_EQUAL; break; default: // TODO(Kyle): Figure out concatenation operation from expressions? @@ -315,9 +315,9 @@ void UdfCodegen::Visit(ast::udf::IfStmtAST *ast) { void UdfCodegen::Visit(ast::udf::IsNullExprAST *ast) { ast->Child()->Accept(this); auto chld = dst_; - dst_ = codegen_->CallBuiltin(execution::ast::Builtin::IsValNull, {chld}); + dst_ = codegen_->CallBuiltin(ast::Builtin::IsValNull, {chld}); if (!ast->IsNullCheck()) { - dst_ = codegen_->UnaryOp(execution::parsing::Token::Type::BANG, dst_); + dst_ = codegen_->UnaryOp(parsing::Token::Type::BANG, dst_); } } @@ -360,7 +360,7 @@ void UdfCodegen::Visit(ast::udf::ForIStmtAST *ast) { throw NOT_IMPLEMENTED_EXCEP void UdfCodegen::Visit(ast::udf::ForSStmtAST *ast) { // Executing a SQL query requires an execution context needs_exec_ctx_ = true; - execution::ast::Expr *exec_ctx = fb_->GetParameterByPosition(0); + ast::Expr *exec_ctx = fb_->GetParameterByPosition(0); // Bind the embedded query; must do this prior to attempting // to optimize to ensure correctness @@ -389,10 +389,10 @@ void UdfCodegen::Visit(ast::udf::ForSStmtAST *ast) { lambda_expr->SetName(lambda_identifier); // Materialize the lambda into the lambda expression - execution::exec::ExecutionSettings exec_settings{}; + exec::ExecutionSettings exec_settings{}; const std::string dummy_query{}; - auto exec_query = execution::compiler::CompilationContext::Compile( - *plan, exec_settings, accessor_, execution::compiler::CompilationMode::OneShot, std::nullopt, + auto exec_query = compiler::CompilationContext::Compile( + *plan, exec_settings, accessor_, compiler::CompilationMode::OneShot, std::nullopt, common::ManagedPointer{}, lambda_expr, codegen_->GetAstContext()); // Append all of the declarations from the compiled query @@ -406,7 +406,7 @@ void UdfCodegen::Visit(ast::udf::ForSStmtAST *ast) { lambda_identifier, codegen_->LambdaType(lambda_expr->GetFunctionLiteralExpr()->TypeRepr()), lambda_expr)); // Set its execution context to whatever execution context was passed in here - fb_->Append(codegen_->CallBuiltin(execution::ast::Builtin::StartNewParams, {exec_ctx})); + fb_->Append(codegen_->CallBuiltin(ast::Builtin::StartNewParams, {exec_ctx})); CodegenAddParameters(exec_ctx, variable_refs); @@ -417,7 +417,7 @@ void UdfCodegen::Visit(ast::udf::ForSStmtAST *ast) { // executable query (implementing the closure) to the builder CodegenTopLevelCalls(exec_query.get(), query_state, lambda_identifier); - fb_->Append(codegen_->CallBuiltin(execution::ast::Builtin::FinishNewParams, {exec_ctx})); + fb_->Append(codegen_->CallBuiltin(ast::Builtin::FinishNewParams, {exec_ctx})); } std::unique_ptr UdfCodegen::StartLambda(common::ManagedPointer plan, @@ -507,9 +507,9 @@ std::unique_ptr UdfCodegen::StartLambdaBindingToScalars( // The first parameter is always the execution context ast::Expr *exec_ctx = fb_->GetParameterByPosition(0); - parameters.push_back(codegen_->MakeField( - exec_ctx->As()->Name(), - codegen_->PointerType(codegen_->BuiltinType(execution::ast::BuiltinType::Kind::ExecutionContext)))); + parameters.push_back( + codegen_->MakeField(exec_ctx->As()->Name(), + codegen_->PointerType(codegen_->BuiltinType(ast::BuiltinType::Kind::ExecutionContext)))); // Assignees are those captures that are written in the closure std::vector assignees{}; @@ -526,7 +526,7 @@ std::unique_ptr UdfCodegen::StartLambdaBindingToScalars( // Begin construction of the function that implements the closure auto builder = std::make_unique(codegen_, std::move(parameters), std::move(captures), - codegen_->BuiltinType(execution::ast::BuiltinType::Nil)); + codegen_->BuiltinType(ast::BuiltinType::Nil)); // Generate an assignment from each input parameter to the associated capture for (std::size_t i = 0UL; i < assignees.size(); ++i) { @@ -672,9 +672,9 @@ ast::LambdaExpr *UdfCodegen::MakeLambdaBindingToScalars(common::ManagedPointer

GetParameterByPosition(0); - parameters.push_back(codegen_->MakeField( - exec_ctx->As()->Name(), - codegen_->PointerType(codegen_->BuiltinType(execution::ast::BuiltinType::Kind::ExecutionContext)))); + parameters.push_back( + codegen_->MakeField(exec_ctx->As()->Name(), + codegen_->PointerType(codegen_->BuiltinType(ast::BuiltinType::Kind::ExecutionContext)))); // Populate the remainder of the parameters and captures for (std::size_t i = 0; i < n_columns; ++i) { @@ -690,7 +690,7 @@ ast::LambdaExpr *UdfCodegen::MakeLambdaBindingToScalars(common::ManagedPointer

BuiltinType(execution::ast::BuiltinType::Nil)}; + codegen_->BuiltinType(ast::BuiltinType::Nil)}; // Generate an assignment from each input parameter to the associated capture for (std::size_t i = 0UL; i < assignees.size(); ++i) { @@ -702,7 +702,7 @@ ast::LambdaExpr *UdfCodegen::MakeLambdaBindingToScalars(common::ManagedPointer

&variable_refs) { @@ -744,6 +744,10 @@ void UdfCodegen::CodegenAddTableParameter(ast::Expr *exec_ctx, const parser::udf fb_->Append(codegen_->CallBuiltin(AddParamBuiltinForParameterType(type), {exec_ctx, expr})); } +/* ---------------------------------------------------------------------------- + Code Gneration Helpers: Bound Variable Initialization +---------------------------------------------------------------------------- */ + void UdfCodegen::CodegenBoundVariableInit(common::ManagedPointer plan, const std::vector &bound_variables) { if (bound_variables.empty()) { @@ -771,7 +775,7 @@ void UdfCodegen::CodegenBoundVariableInitForScalars(common::ManagedPointerGetOutputSchema()->GetColumn(i); const auto &variable = bound_variables.at(i); - execution::ast::Expr *capture = codegen_->MakeExpr(SymbolTable().find(variable)->second); + ast::Expr *capture = codegen_->MakeExpr(SymbolTable().find(variable)->second); fb_->Append(codegen_->Assign(capture, codegen_->ConstNull(column.GetType()))); } } @@ -882,23 +886,23 @@ ast::Builtin UdfCodegen::AddParamBuiltinForParameterType(type::TypeId parameter_ // dispatch table, but honestly that would be overkill at this point switch (parameter_type) { case type::TypeId::BOOLEAN: - return execution::ast::Builtin::AddParamBool; + return ast::Builtin::AddParamBool; case type::TypeId::TINYINT: - return execution::ast::Builtin::AddParamTinyInt; + return ast::Builtin::AddParamTinyInt; case type::TypeId::SMALLINT: - return execution::ast::Builtin::AddParamSmallInt; + return ast::Builtin::AddParamSmallInt; case type::TypeId::INTEGER: - return execution::ast::Builtin::AddParamInt; + return ast::Builtin::AddParamInt; case type::TypeId::BIGINT: - return execution::ast::Builtin::AddParamBigInt; + return ast::Builtin::AddParamBigInt; case type::TypeId::DECIMAL: - return execution::ast::Builtin::AddParamDouble; + return ast::Builtin::AddParamDouble; case type::TypeId::DATE: - return execution::ast::Builtin::AddParamDate; + return ast::Builtin::AddParamDate; case type::TypeId::TIMESTAMP: - return execution::ast::Builtin::AddParamTimestamp; + return ast::Builtin::AddParamTimestamp; case type::TypeId::VARCHAR: - return execution::ast::Builtin::AddParamString; + return ast::Builtin::AddParamString; default: UNREACHABLE("Unsupported parameter type"); } From 41bd5a21685f8b09c6712a91569fd36f968ed0da Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Fri, 6 Aug 2021 22:28:27 -0400 Subject: [PATCH 098/139] fig bug in modified translator identified by clang tidy --- docs/design_codegen.md | 10 ++++------ .../operator/static_aggregation_translator.cpp | 1 + 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/docs/design_codegen.md b/docs/design_codegen.md index a511a206f1..9898adfca1 100644 --- a/docs/design_codegen.md +++ b/docs/design_codegen.md @@ -8,15 +8,15 @@ As described in the _Execution Engine Design Document_, NoisePage utilizes [data Our goal in code generation is to produce a bytecode program that implements a query plan. -The straightforward and most common way of accomplishing this is to have each operator in the query plan tree assume responsibility for generating the code that it requires to execute. The complete byetcode program might then be realized by having each operator generate code into a distinct bytecode function and then chaining these functions together via calls from the functions produced by parent operators to those produced by child operators. +The straightforward and most common way of accomplishingz this is to have each operator in the query plan tree assume responsibility for generating the code that it requires to execute. The complete byetcode program might then be realized by having each operator generate code into a distinct bytecode function and then chaining these functions together via calls from the functions produced by parent operators to those produced by child operators. As mentioned above, this approach is straightforward to reason about and to implement. The code generated for each operator is nicely self-contained in a single bytecode function, allowing developers to verify the correctness of the generated code and debug code generation issues. However, the simplicity of this approach comes at the cost of query runtime performance. We now incur function-call overhead in the transition between each operator. More importantly, we leave ourselves open to the same performance issues present in any operator-centric execution model: poor code and data locality resulting from tuple-at-a-time processing among each operator. -Data-centric code generation is a solution to these performance issues. TODO +Data-centric code generation is a solution to these performance issues. In this paradigm, code is generated according to the data dependencies between individual operators, rather than along operator boundaries themselves. In practice, this has the effect of _fusing_ multiple operators together into larger units called _pipelines_. When multiple operators are fused into a pipeline, all of the operations required to implement the logic of each operator may be performed in sequence, without incurring function call overhead or even spilling tuple attributes to memory - it is often possible to keep tuple attributes in registers for the duration of a pipeline, dramatically improving CPU efficiency. ### Pipelines -TODO +Pipelines are the lowest-level unit of code generation in the NoisePage query compilation architecture. Individual operators are assigned to a pipeline (some operators may be part of more than one pipeline e.g. `JOIN` operators). The pipeline defines a set of top-level bytecode functions to generate, and invokes a set of pre-defined member functions on each of its operators to populate the body of each of these functions. The specifics of each of the functions defined by each pipeline are described below. **Complications** @@ -32,7 +32,7 @@ The are several flavors of pipelines within NoisePage that differ slightly in th #### Serial Pipelines -We begin the discussion with serial pipelines because they are slightly less complicated. +We begin the discussion with serial pipelines because they are slightly less complicated than their parallel counterparts. **State Initialization** @@ -126,8 +126,6 @@ The reason that a pointer to the pipeline state is provided to the call in the l Parallel pipelines require different semantics from serial pipelines. Despite these differences, only the _Work_ function is affected by the change from a serial to a parallel pipeline. -TODO - ### References - [Efficiently Compiling Efficient Query Plans for Modern Hardware](https://15721.courses.cs.cmu.edu/spring2020/papers/14-compilation/p539-neumann.pdf) by Thomas Neumann. The paper that introduced the concept of data-centric code generation, among other techniques now considered standard best-practice in compiling query engines. diff --git a/src/execution/compiler/operator/static_aggregation_translator.cpp b/src/execution/compiler/operator/static_aggregation_translator.cpp index 43a81266b0..f2c8090842 100644 --- a/src/execution/compiler/operator/static_aggregation_translator.cpp +++ b/src/execution/compiler/operator/static_aggregation_translator.cpp @@ -24,6 +24,7 @@ StaticAggregationTranslator::StaticAggregationTranslator(const planner::Aggregat agg_row_var_(GetCodeGen()->MakeFreshIdentifier(AGG_ROW_VAR)), agg_payload_type_(MakeGlobalIdentifier(AGG_PAYLOAD_TYPE)), agg_values_type_(MakeGlobalIdentifier(AGG_VALUES_TYPE)), + merge_func_(MakeGlobalIdentifier(AGG_MERGE_FUNC)), build_pipeline_(this, Pipeline::Parallelism::Parallel) { NOISEPAGE_ASSERT(plan.GetGroupByTerms().empty(), "Global aggregations shouldn't have grouping keys"); NOISEPAGE_ASSERT(plan.GetChildrenSize() == 1, "Global aggregations should only have one child"); From 4847199b2dc12379b27776d932ad49e05742a48a Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Sun, 8 Aug 2021 00:06:54 -0400 Subject: [PATCH 099/139] fix implementation of function calls, going to refactor a few things in code generation to make them pretty and then call it for the features of this PR --- script/testing/junit/sql/udf.sql | 17 ++++++++ script/testing/junit/traces/udf.test | 30 +++++++++++++ src/execution/compiler/udf/udf_codegen.cpp | 43 ++++++++++++------- src/include/execution/ast/udf/udf_ast_nodes.h | 2 +- .../execution/compiler/udf/udf_codegen.h | 12 +++--- src/include/parser/udf/plpgsql_parser.h | 7 +++ src/parser/udf/plpgsql_parser.cpp | 33 +++++++------- 7 files changed, 104 insertions(+), 40 deletions(-) diff --git a/script/testing/junit/sql/udf.sql b/script/testing/junit/sql/udf.sql index bfd4f1b813..874f413270 100644 --- a/script/testing/junit/sql/udf.sql +++ b/script/testing/junit/sql/udf.sql @@ -297,3 +297,20 @@ END \ $$ LANGUAGE PLPGSQL; SELECT proc_fors_var(); + +-- ---------------------------------------------------------------------------- +-- proc_call_ret() + +CREATE FUNCTION proc_call_ret_callee() RETURNS INT AS $$ \ +BEGIN \ + RETURN 1; \ +END \ +$$ LANGUAGE PLPGSQL; + +CREATE FUNCTION proc_call_ret_caller() RETURNS INT AS $$ \ +BEGIN \ + RETURN proc_call_ret_callee(); \ +END \ +$$ LANGUAGE PLPGSQL; + +SELECT proc_call_ret_caller(); diff --git a/script/testing/junit/traces/udf.test b/script/testing/junit/traces/udf.test index 96de75061a..9c44b752a5 100644 --- a/script/testing/junit/traces/udf.test +++ b/script/testing/junit/traces/udf.test @@ -540,3 +540,33 @@ SELECT proc_fors_var(); 3 +statement ok + + +statement ok +-- ---------------------------------------------------------------------------- + +statement ok +-- proc_call_ret() + +statement ok + + +statement ok +CREATE FUNCTION proc_call_ret_callee() RETURNS INT AS $$ BEGIN RETURN 1; END $$ LANGUAGE PLPGSQL; + +statement ok + + +statement ok +CREATE FUNCTION proc_call_ret_caller() RETURNS INT AS $$ BEGIN RETURN proc_call_ret_callee(); END $$ LANGUAGE PLPGSQL; + +statement ok + + +query I rowsort +SELECT proc_call_ret_caller(); +---- +1 + + diff --git a/src/execution/compiler/udf/udf_codegen.cpp b/src/execution/compiler/udf/udf_codegen.cpp index 352122d21e..2247b3c84b 100644 --- a/src/execution/compiler/udf/udf_codegen.cpp +++ b/src/execution/compiler/udf/udf_codegen.cpp @@ -35,8 +35,7 @@ UdfCodegen::UdfCodegen(catalog::CatalogAccessor *accessor, FunctionBuilder *fb, udf_ast_context_{udf_ast_context}, codegen_{codegen}, db_oid_{db_oid}, - aux_decls_(codegen->GetAstContext()->GetRegion()), - needs_exec_ctx_{false} { + aux_decls_{codegen->GetAstContext()->GetRegion()} { for (auto i = 0UL; fb->GetParameterByPosition(i) != nullptr; ++i) { auto param = fb->GetParameterByPosition(i); const auto &name = param->As()->Name(); @@ -97,6 +96,13 @@ void UdfCodegen::Visit(ast::udf::CallExprAST *ast) { std::vector args_ast_region_vec{}; std::vector arg_types{}; + // First argument to UDF is an execution context + args_ast_region_vec.push_back(GetExecutionContext()); + + // TODO(Kyle): Is this the semantics we want? The execution + // context for the entire TPL program is shared? + + // TODO(Kyle): Clean up this logic for (auto &arg : ast->Args()) { arg->Accept(this); args_ast.push_back(dst_); @@ -106,12 +112,17 @@ void UdfCodegen::Visit(ast::udf::CallExprAST *ast) { NOISEPAGE_ASSERT(builtin->IsSqlValueType(), "Parameter must be a SQL value type"); arg_types.push_back(GetCatalogTypeOidFromSQLType(builtin->GetKind())); } - auto proc_oid = accessor_->GetProcOid(ast->Callee(), arg_types); - NOISEPAGE_ASSERT(proc_oid != catalog::INVALID_PROC_OID, "Invalid call"); + + const auto proc_oid = accessor_->GetProcOid(ast->Callee(), arg_types); + if (proc_oid == catalog::INVALID_PROC_OID) { + throw BINDER_EXCEPTION(fmt::format("Invalid function call '{}'", ast->Callee()), + common::ErrorCode::ERRCODE_PLPGSQL_ERROR); + } auto context = accessor_->GetProcCtxPtr(proc_oid); if (context->IsBuiltin()) { - fb_->Append(codegen_->MakeStmt(codegen_->CallBuiltin(context->GetBuiltin(), args_ast))); + ast::Expr *result = codegen_->CallBuiltin(context->GetBuiltin(), args_ast); + dst_ = result; } else { auto it = SymbolTable().find(ast->Callee()); ast::Identifier ident_expr; @@ -127,7 +138,8 @@ void UdfCodegen::Visit(ast::udf::CallExprAST *ast) { ident_expr = codegen_->MakeFreshIdentifier(file->Declarations().back()->Name().GetString()); SymbolTable()[file->Declarations().back()->Name().GetString()] = ident_expr; } - fb_->Append(codegen_->MakeStmt(codegen_->Call(ident_expr, args_ast_region_vec))); + ast::Expr *result = codegen_->Call(ident_expr, args_ast_region_vec); + dst_ = result; } } @@ -336,6 +348,7 @@ void UdfCodegen::Visit(ast::udf::WhileStmtAST *ast) { } void UdfCodegen::Visit(ast::udf::RetStmtAST *ast) { + // TODO(Kyle): Handle NULL returns ast->Return()->Accept(reinterpret_cast(this)); auto ret_expr = dst_; fb_->Append(codegen_->Return(ret_expr)); @@ -359,8 +372,7 @@ void UdfCodegen::Visit(ast::udf::ForIStmtAST *ast) { throw NOT_IMPLEMENTED_EXCEP void UdfCodegen::Visit(ast::udf::ForSStmtAST *ast) { // Executing a SQL query requires an execution context - needs_exec_ctx_ = true; - ast::Expr *exec_ctx = fb_->GetParameterByPosition(0); + ast::Expr *exec_ctx = GetExecutionContext(); // Bind the embedded query; must do this prior to attempting // to optimize to ensure correctness @@ -446,7 +458,7 @@ std::unique_ptr UdfCodegen::StartLambdaBindingToRecord( util::RegionVector parameters{codegen_->GetAstContext()->GetRegion()}; // The first parameter is always the execution context - ast::Expr *exec_ctx = fb_->GetParameterByPosition(0); + ast::Expr *exec_ctx = GetExecutionContext(); parameters.push_back( codegen_->MakeField(exec_ctx->As()->Name(), codegen_->PointerType(codegen_->BuiltinType(ast::BuiltinType::Kind::ExecutionContext)))); @@ -506,7 +518,7 @@ std::unique_ptr UdfCodegen::StartLambdaBindingToScalars( } // The first parameter is always the execution context - ast::Expr *exec_ctx = fb_->GetParameterByPosition(0); + ast::Expr *exec_ctx = GetExecutionContext(); parameters.push_back( codegen_->MakeField(exec_ctx->As()->Name(), codegen_->PointerType(codegen_->BuiltinType(ast::BuiltinType::Kind::ExecutionContext)))); @@ -543,8 +555,7 @@ std::unique_ptr UdfCodegen::StartLambdaBindingToScalars( void UdfCodegen::Visit(ast::udf::SQLStmtAST *ast) { // Executing a SQL query requires an execution context - needs_exec_ctx_ = true; - ast::Expr *exec_ctx = fb_->GetParameterByPosition(0); + ast::Expr *exec_ctx = GetExecutionContext(); // Bind the embedded query; must do this prior to attempting // to optimize to ensure correctness @@ -622,8 +633,7 @@ ast::LambdaExpr *UdfCodegen::MakeLambdaBindingToRecord(common::ManagedPointer parameters{codegen_->GetAstContext()->GetRegion()}; - // The first parameter is always the execution context - ast::Expr *exec_ctx = fb_->GetParameterByPosition(0); + ast::Expr *exec_ctx = GetExecutionContext(); parameters.push_back( codegen_->MakeField(exec_ctx->As()->Name(), codegen_->PointerType(codegen_->BuiltinType(ast::BuiltinType::Kind::ExecutionContext)))); @@ -670,8 +680,7 @@ ast::LambdaExpr *UdfCodegen::MakeLambdaBindingToScalars(common::ManagedPointer

captures{codegen_->GetAstContext()->GetRegion()}; - // The first parameter is always the execution context - ast::Expr *exec_ctx = fb_->GetParameterByPosition(0); + ast::Expr *exec_ctx = GetExecutionContext(); parameters.push_back( codegen_->MakeField(exec_ctx->As()->Name(), codegen_->PointerType(codegen_->BuiltinType(ast::BuiltinType::Kind::ExecutionContext)))); @@ -844,6 +853,8 @@ void UdfCodegen::CodegenTopLevelCalls(const ExecutableQuery *exec_query, ast::Id General Utilities ---------------------------------------------------------------------------- */ +ast::Expr *UdfCodegen::GetExecutionContext() { return fb_->GetParameterByPosition(0); } + type::TypeId UdfCodegen::GetVariableType(const std::string &name) const { auto type = udf_ast_context_->GetVariableType(name); if (!type.has_value()) { diff --git a/src/include/execution/ast/udf/udf_ast_nodes.h b/src/include/execution/ast/udf/udf_ast_nodes.h index 811dfeaa68..3546a2f07c 100644 --- a/src/include/execution/ast/udf/udf_ast_nodes.h +++ b/src/include/execution/ast/udf/udf_ast_nodes.h @@ -267,7 +267,7 @@ class CallExprAST : public ExprAST { private: /** The name of the called function */ - std::string callee_; + const std::string callee_; /** The arguments to the function call */ std::vector> args_; diff --git a/src/include/execution/compiler/udf/udf_codegen.h b/src/include/execution/compiler/udf/udf_codegen.h index 0c6209d88c..9c4a03afd1 100644 --- a/src/include/execution/compiler/udf/udf_codegen.h +++ b/src/include/execution/compiler/udf/udf_codegen.h @@ -61,6 +61,8 @@ class ForSStmtAST; namespace compiler::udf { +class ExpressionResultScope; + /** * The UdfCodegen class implements a visitor for UDF AST * nodes and encapsulates all of the logic required to generate @@ -421,17 +423,20 @@ class UdfCodegen : ast::udf::ASTNodeVisitor { static ast::Builtin AddParamBuiltinForParameterType(type::TypeId parameter_type); /** - * Sort the query + * TODO(Kyle): this */ static std::vector ParametersSortedByIndex( const std::unordered_map> ¶meter_map); /** - * + * TODO(Kyle): this */ static std::vector ColumnsSortedByIndex( const std::unordered_map> ¶meter_map); + /** @return The execution context provided to the function */ + ast::Expr *GetExecutionContext(); + private: /** The string identifier for internal declarations */ constexpr static const char INTERNAL_DECL_ID[] = "*internal*"; @@ -454,9 +459,6 @@ class UdfCodegen : ast::udf::ASTNodeVisitor { /** Auxiliary declarations */ execution::util::RegionVector aux_decls_; - /** Flag indicating whether this UDF requires an execution context */ - bool needs_exec_ctx_; - /** The current type during code generation */ type::TypeId current_type_{type::TypeId::INVALID}; diff --git a/src/include/parser/udf/plpgsql_parser.h b/src/include/parser/udf/plpgsql_parser.h index e034394cfd..d334ee05ae 100644 --- a/src/include/parser/udf/plpgsql_parser.h +++ b/src/include/parser/udf/plpgsql_parser.h @@ -64,6 +64,13 @@ class PLpgSQLParser { */ std::unique_ptr ParseBlock(const nlohmann::json &json); + /** + * Parse a return statement. + * @param json The input JSON object + * @return The AST for the return statement + */ + std::unique_ptr ParseReturn(const nlohmann::json &json); + /** * Parse a function statement. * @param json The input JSON object diff --git a/src/parser/udf/plpgsql_parser.cpp b/src/parser/udf/plpgsql_parser.cpp index c535970f65..3f492b0ae0 100644 --- a/src/parser/udf/plpgsql_parser.cpp +++ b/src/parser/udf/plpgsql_parser.cpp @@ -78,8 +78,7 @@ std::unique_ptr PLpgSQLParser::Parse(const std throw PARSER_EXCEPTION("Function list has size other than 1"); } - // TODO(Kyle): This is a zip(), can we add our own generic - // algorithms library somewhere for stuff like this? + // TODO(Kyle): This is a zip() std::size_t i = 0; for (const auto &udf_name : param_names) { udf_ast_context_->SetVariableType(udf_name, param_types[i++]); @@ -115,13 +114,7 @@ std::unique_ptr PLpgSQLParser::ParseBlock(const nl const StatementType statement_type = GetStatementType(statement.items().begin().key()); switch (statement_type) { case StatementType::RETURN: { - // TODO(Kyle): Handle RETURN without expression - if (statement[K_PLPGSQL_STMT_RETURN].empty()) { - throw NOT_IMPLEMENTED_EXCEPTION("PL/pgSQL Parser : RETURN without expression not implemented."); - } - auto expr = ParseExprFromSQLString( - statement[K_PLPGSQL_STMT_RETURN][K_EXPR][K_PLPGSQL_EXPR][K_QUERY].get()); - statements.push_back(std::make_unique(std::move(expr))); + statements.push_back(ParseReturn(statement[K_PLPGSQL_STMT_RETURN])); break; } case StatementType::IF: { @@ -129,7 +122,6 @@ std::unique_ptr PLpgSQLParser::ParseBlock(const nl break; } case StatementType::ASSIGN: { - // TODO(Kyle): Need to fix Assignment expression / statement statements.push_back(ParseAssign(statement[K_PLPGSQL_STMT_ASSIGN])); break; } @@ -163,6 +155,15 @@ std::unique_ptr PLpgSQLParser::ParseBlock(const nl return std::make_unique(std::move(statements)); } +std::unique_ptr PLpgSQLParser::ParseReturn(const nlohmann::json &json) { + // TODO(Kyle): Handle RETURN without expression + if (json.empty()) { + throw NOT_IMPLEMENTED_EXCEPTION("PL/pgSQL Parser : RETURN without expression not implemented."); + } + auto expr = ParseExprFromSQLString(json[K_EXPR][K_PLPGSQL_EXPR][K_QUERY].get()); + return std::make_unique(std::move(expr)); +} + std::unique_ptr PLpgSQLParser::ParseDecl(const nlohmann::json &json) { const auto &declaration_type = json.items().begin().key(); if (declaration_type == K_PLPGSQL_VAR) { @@ -417,9 +418,6 @@ std::optional> PLpgSQLParser::TryP std::move(*lhs), std::move(*rhs))); } - // TODO(Kyle): I am not a fan of non-exhaustive switch statements; - // is there a way that we can refactor this logic to make it better? - switch (expr->GetExpressionType()) { case parser::ExpressionType::COLUMN_VALUE: { auto cve = expr.CastManagedPointerTo(); @@ -433,13 +431,12 @@ std::optional> PLpgSQLParser::TryP case parser::ExpressionType::FUNCTION: { auto func_expr = expr.CastManagedPointerTo(); std::vector> args{}; - auto num_args = func_expr->GetChildrenSize(); - for (std::size_t idx = 0; idx < num_args; ++idx) { - auto arg = TryParseExprFromAbstract(func_expr->GetChild(idx)); - if (!arg.has_value()) { + for (auto child : func_expr->GetChildren()) { + auto argument = TryParseExprFromAbstract(child); + if (!argument.has_value()) { return std::nullopt; } - args.push_back(std::move(*arg)); + args.push_back(std::move(*argument)); } return std::make_optional( std::make_unique(func_expr->GetFuncName(), std::move(args))); From f347ffac646c00674ebfb4442a74197b8f52d4da Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Sun, 8 Aug 2021 07:56:07 -0400 Subject: [PATCH 100/139] refactor execution result handling in udf code generation, slightly more readable now --- src/execution/compiler/udf/udf_codegen.cpp | 92 ++++++++++--------- .../execution/compiler/udf/udf_codegen.h | 21 ++++- 2 files changed, 66 insertions(+), 47 deletions(-) diff --git a/src/execution/compiler/udf/udf_codegen.cpp b/src/execution/compiler/udf/udf_codegen.cpp index 2247b3c84b..0d6a25ef31 100644 --- a/src/execution/compiler/udf/udf_codegen.cpp +++ b/src/execution/compiler/udf/udf_codegen.cpp @@ -104,10 +104,10 @@ void UdfCodegen::Visit(ast::udf::CallExprAST *ast) { // TODO(Kyle): Clean up this logic for (auto &arg : ast->Args()) { - arg->Accept(this); - args_ast.push_back(dst_); - args_ast_region_vec.push_back(dst_); - auto *builtin = dst_->GetType()->SafeAs(); + ast::Expr *result = EvaluateExpression(arg.get()); + args_ast.push_back(result); + args_ast_region_vec.push_back(result); + auto *builtin = result->GetType()->SafeAs(); NOISEPAGE_ASSERT(builtin != nullptr, "Parameter must be a built-in type"); NOISEPAGE_ASSERT(builtin->IsSqlValueType(), "Parameter must be a SQL value type"); arg_types.push_back(GetCatalogTypeOidFromSQLType(builtin->GetKind())); @@ -122,7 +122,7 @@ void UdfCodegen::Visit(ast::udf::CallExprAST *ast) { auto context = accessor_->GetProcCtxPtr(proc_oid); if (context->IsBuiltin()) { ast::Expr *result = codegen_->CallBuiltin(context->GetBuiltin(), args_ast); - dst_ = result; + SetExecutionResult(result); } else { auto it = SymbolTable().find(ast->Callee()); ast::Identifier ident_expr; @@ -139,7 +139,7 @@ void UdfCodegen::Visit(ast::udf::CallExprAST *ast) { SymbolTable()[file->Declarations().back()->Name().GetString()] = ident_expr; } ast::Expr *result = codegen_->Call(ident_expr, args_ast_region_vec); - dst_ = result; + SetExecutionResult(result); } } @@ -180,8 +180,8 @@ void UdfCodegen::Visit(ast::udf::DeclStmtAST *ast) { } current_type_ = ast->Type(); if (ast->Initial() != nullptr) { - ast->Initial()->Accept(this); - fb_->Append(codegen_->DeclareVar(identifier, tpl_type, dst_)); + ast::Expr *initializer = EvaluateExpression(ast->Initial()); + fb_->Append(codegen_->DeclareVar(identifier, tpl_type, initializer)); } else { fb_->Append(codegen_->DeclareVarNoInit(identifier, tpl_type)); } @@ -198,49 +198,51 @@ void UdfCodegen::Visit(ast::udf::FunctionAST *ast) { void UdfCodegen::Visit(ast::udf::VariableExprAST *ast) { auto it = SymbolTable().find(ast->Name()); NOISEPAGE_ASSERT(it != SymbolTable().end(), "Variable not declared"); - dst_ = codegen_->MakeExpr(it->second); + SetExecutionResult(codegen_->MakeExpr(it->second)); } void UdfCodegen::Visit(ast::udf::ValueExprAST *ast) { auto val = common::ManagedPointer(ast->Value()).CastManagedPointerTo(); if (val->IsNull()) { - dst_ = codegen_->ConstNull(current_type_); + SetExecutionResult(codegen_->ConstNull(current_type_)); return; } + + ast::Expr *expr; auto type_id = sql::GetTypeId(val->GetReturnValueType()); switch (type_id) { case sql::TypeId::Boolean: - dst_ = codegen_->BoolToSql(val->GetBoolVal().val_); + expr = codegen_->BoolToSql(val->GetBoolVal().val_); break; case sql::TypeId::TinyInt: case sql::TypeId::SmallInt: case sql::TypeId::Integer: case sql::TypeId::BigInt: - dst_ = codegen_->IntToSql(val->GetInteger().val_); + expr = codegen_->IntToSql(val->GetInteger().val_); break; case sql::TypeId::Float: case sql::TypeId::Double: - dst_ = codegen_->FloatToSql(val->GetReal().val_); + expr = codegen_->FloatToSql(val->GetReal().val_); case sql::TypeId::Date: - dst_ = codegen_->DateToSql(val->GetDateVal().val_); + expr = codegen_->DateToSql(val->GetDateVal().val_); break; case sql::TypeId::Timestamp: - dst_ = codegen_->TimestampToSql(val->GetTimestampVal().val_); + expr = codegen_->TimestampToSql(val->GetTimestampVal().val_); break; case sql::TypeId::Varchar: - dst_ = codegen_->StringToSql(val->GetStringVal().StringView()); + expr = codegen_->StringToSql(val->GetStringVal().StringView()); break; default: throw NOT_IMPLEMENTED_EXCEPTION("Unsupported type in UDF codegen"); } + SetExecutionResult(expr); } void UdfCodegen::Visit(ast::udf::AssignStmtAST *ast) { const type::TypeId left_type = GetVariableType(ast->Destination()->Name()); current_type_ = left_type; - reinterpret_cast(ast->Source())->Accept(this); - auto rhs_expr = dst_; + ast::Expr *rhs_expr = EvaluateExpression(ast->Source()); auto it = SymbolTable().find(ast->Destination()->Name()); NOISEPAGE_ASSERT(it != SymbolTable().end(), "Variable not found"); @@ -299,23 +301,16 @@ void UdfCodegen::Visit(ast::udf::BinaryExprAST *ast) { // TODO(Kyle): Figure out concatenation operation from expressions? UNREACHABLE("Unsupported expression"); } - ast->Left()->Accept(this); - auto lhs_expr = dst_; - - ast->Right()->Accept(this); - auto rhs_expr = dst_; - if (compare) { - dst_ = codegen_->Compare(op_token, lhs_expr, rhs_expr); - } else { - dst_ = codegen_->BinaryOp(op_token, lhs_expr, rhs_expr); - } + ast::Expr *lhs_expr = EvaluateExpression(ast->Left()); + ast::Expr *rhs_expr = EvaluateExpression(ast->Right()); + ast::Expr *result = + compare ? codegen_->Compare(op_token, lhs_expr, rhs_expr) : codegen_->BinaryOp(op_token, lhs_expr, rhs_expr); + SetExecutionResult(result); } void UdfCodegen::Visit(ast::udf::IfStmtAST *ast) { - ast->Condition()->Accept(this); - auto cond = dst_; - - If branch(fb_, cond); + ast::Expr *condition = EvaluateExpression(ast->Condition()); + If branch(fb_, condition); ast->Then()->Accept(this); if (ast->Else() != nullptr) { branch.Else(); @@ -325,11 +320,11 @@ void UdfCodegen::Visit(ast::udf::IfStmtAST *ast) { } void UdfCodegen::Visit(ast::udf::IsNullExprAST *ast) { - ast->Child()->Accept(this); - auto chld = dst_; - dst_ = codegen_->CallBuiltin(ast::Builtin::IsValNull, {chld}); + ast::Expr *child = EvaluateExpression(ast->Child()); + ast::Expr *null_check = codegen_->CallBuiltin(ast::Builtin::IsValNull, {child}); + SetExecutionResult(null_check); if (!ast->IsNullCheck()) { - dst_ = codegen_->UnaryOp(parsing::Token::Type::BANG, dst_); + SetExecutionResult(codegen_->UnaryOp(parsing::Token::Type::BANG, null_check)); } } @@ -340,24 +335,22 @@ void UdfCodegen::Visit(ast::udf::SeqStmtAST *ast) { } void UdfCodegen::Visit(ast::udf::WhileStmtAST *ast) { - ast->Condition()->Accept(this); - auto cond = dst_; - Loop loop(fb_, cond); + ast::Expr *condition = EvaluateExpression(ast->Condition()); + Loop loop(fb_, condition); ast->Body()->Accept(this); loop.EndLoop(); } void UdfCodegen::Visit(ast::udf::RetStmtAST *ast) { // TODO(Kyle): Handle NULL returns - ast->Return()->Accept(reinterpret_cast(this)); - auto ret_expr = dst_; - fb_->Append(codegen_->Return(ret_expr)); + ast::Expr *return_expr = EvaluateExpression(ast->Return()); + fb_->Append(codegen_->Return(return_expr)); } void UdfCodegen::Visit(ast::udf::MemberExprAST *ast) { - ast->Object()->Accept(reinterpret_cast(this)); - auto object = dst_; - dst_ = codegen_->AccessStructMember(object, codegen_->MakeIdentifier(ast->FieldName())); + ast::Expr *object = EvaluateExpression(ast->Object()); + ast::Expr *access = codegen_->AccessStructMember(object, codegen_->MakeIdentifier(ast->FieldName())); + SetExecutionResult(access); } /* ---------------------------------------------------------------------------- @@ -855,6 +848,15 @@ void UdfCodegen::CodegenTopLevelCalls(const ExecutableQuery *exec_query, ast::Id ast::Expr *UdfCodegen::GetExecutionContext() { return fb_->GetParameterByPosition(0); } +ast::Expr *UdfCodegen::GetExecutionResult() { return execution_result_; } + +void UdfCodegen::SetExecutionResult(ast::Expr *result) { execution_result_ = result; } + +ast::Expr *UdfCodegen::EvaluateExpression(ast::udf::ExprAST *expr) { + expr->Accept(this); + return GetExecutionResult(); +} + type::TypeId UdfCodegen::GetVariableType(const std::string &name) const { auto type = udf_ast_context_->GetVariableType(name); if (!type.has_value()) { diff --git a/src/include/execution/compiler/udf/udf_codegen.h b/src/include/execution/compiler/udf/udf_codegen.h index 9c4a03afd1..5002b7fa2f 100644 --- a/src/include/execution/compiler/udf/udf_codegen.h +++ b/src/include/execution/compiler/udf/udf_codegen.h @@ -437,6 +437,23 @@ class UdfCodegen : ast::udf::ASTNodeVisitor { /** @return The execution context provided to the function */ ast::Expr *GetExecutionContext(); + /** @return The current execution result expression */ + ast::Expr *GetExecutionResult(); + + /** + * Set the current execution result expression. + * @param The execution result expression + */ + void SetExecutionResult(ast::Expr *result); + + /** + * Stage evaluation of the expression `expr` by generating + * code to perform the evaluation (at runtime). + * @param expr The expression to evaluate + * @return The result of evaluating the expression + */ + ast::Expr *EvaluateExpression(ast::udf::ExprAST *expr); + private: /** The string identifier for internal declarations */ constexpr static const char INTERNAL_DECL_ID[] = "*internal*"; @@ -462,8 +479,8 @@ class UdfCodegen : ast::udf::ASTNodeVisitor { /** The current type during code generation */ type::TypeId current_type_{type::TypeId::INVALID}; - /** The destination expression */ - execution::ast::Expr *dst_; + /** The current execution result expression */ + execution::ast::Expr *execution_result_; /** Map from human-readable string identifier to internal identifier */ std::unordered_map symbol_table_; From 08327aad9aeaa79194a6f4a9164f1f9f3ed7cd2c Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Sun, 8 Aug 2021 19:00:31 -0400 Subject: [PATCH 101/139] more integration tests for function calls --- script/testing/junit/sql/udf.sql | 45 ++++++++++++++++++++------ script/testing/junit/traces/udf.test | 47 +++++++++++++++++++++++++--- 2 files changed, 78 insertions(+), 14 deletions(-) diff --git a/script/testing/junit/sql/udf.sql b/script/testing/junit/sql/udf.sql index 874f413270..464edaf2a1 100644 --- a/script/testing/junit/sql/udf.sql +++ b/script/testing/junit/sql/udf.sql @@ -299,18 +299,43 @@ $$ LANGUAGE PLPGSQL; SELECT proc_fors_var(); -- ---------------------------------------------------------------------------- --- proc_call_ret() +-- proc_call_*() -CREATE FUNCTION proc_call_ret_callee() RETURNS INT AS $$ \ -BEGIN \ - RETURN 1; \ -END \ +CREATE FUNCTION proc_call_callee() RETURNS INT AS $$ \ +BEGIN \ + RETURN 1; \ +END \ $$ LANGUAGE PLPGSQL; -CREATE FUNCTION proc_call_ret_caller() RETURNS INT AS $$ \ -BEGIN \ - RETURN proc_call_ret_callee(); \ -END \ +-- Just RETURN the result of call +CREATE FUNCTION proc_call_ret() RETURNS INT AS $$ \ +BEGIN \ + RETURN proc_call_callee(); \ +END \ +$$ LANGUAGE PLPGSQL; + +SELECT proc_call_ret(); + +-- Assign the result of call to variable +CREATE FUNCTION proc_call_assign() RETURNS INT AS $$ \ +DECLARE \ + v INT; \ +BEGIN \ + v = proc_call_callee(); \ + RETURN v; \ +END \ +$$ LANGUAGE PLPGSQl; + +SELECT proc_call_assign(); + +-- SELECT the result of call into variable +CREATE FUNCTION proc_call_select() RETURNS INT AS $$ \ +DECLARE \ + v INT; \ +BEGIN \ + SELECT proc_call_callee() INTO v; \ + RETURN v; \ +END \ $$ LANGUAGE PLPGSQL; -SELECT proc_call_ret_caller(); +SELECT proc_call_select(); diff --git a/script/testing/junit/traces/udf.test b/script/testing/junit/traces/udf.test index 9c44b752a5..2d99216478 100644 --- a/script/testing/junit/traces/udf.test +++ b/script/testing/junit/traces/udf.test @@ -547,25 +547,64 @@ statement ok -- ---------------------------------------------------------------------------- statement ok --- proc_call_ret() +-- proc_call_*() statement ok statement ok -CREATE FUNCTION proc_call_ret_callee() RETURNS INT AS $$ BEGIN RETURN 1; END $$ LANGUAGE PLPGSQL; +CREATE FUNCTION proc_call_callee() RETURNS INT AS $$ BEGIN RETURN 1; END $$ LANGUAGE PLPGSQL; statement ok statement ok -CREATE FUNCTION proc_call_ret_caller() RETURNS INT AS $$ BEGIN RETURN proc_call_ret_callee(); END $$ LANGUAGE PLPGSQL; +-- Just RETURN the result of call + +statement ok +CREATE FUNCTION proc_call_ret() RETURNS INT AS $$ BEGIN RETURN proc_call_callee(); END $$ LANGUAGE PLPGSQL; + +statement ok + + +query I rowsort +SELECT proc_call_ret(); +---- +1 + + +statement ok + + +statement ok +-- Assign the result of call to variable + +statement ok +CREATE FUNCTION proc_call_assign() RETURNS INT AS $$ DECLARE v INT; BEGIN v = proc_call_callee(); RETURN v; END $$ LANGUAGE PLPGSQl; + +statement ok + + +query I rowsort +SELECT proc_call_assign(); +---- +1 + + +statement ok + + +statement ok +-- SELECT the result of call into variable + +statement ok +CREATE FUNCTION proc_call_select() RETURNS INT AS $$ DECLARE v INT; BEGIN SELECT proc_call_callee() INTO v; RETURN v; END $$ LANGUAGE PLPGSQL; statement ok query I rowsort -SELECT proc_call_ret_caller(); +SELECT proc_call_select(); ---- 1 From d28c2849c4cfa26e6b7ce3aae00ea3848e2f80f1 Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Sun, 8 Aug 2021 23:23:51 -0400 Subject: [PATCH 102/139] finally figured out how to get type names from the postgres parser after some RE work, implementing DROP FUNCTION --- src/binder/bind_node_visitor.cpp | 3 ++ src/include/network/network_defs.h | 1 + src/include/parser/drop_statement.h | 34 +++++++++++++++++- src/include/parser/nodes.h | 15 ++++++++ src/include/parser/postgresparser.h | 1 + .../query_to_operator_transformer.cpp | 1 + src/parser/postgresparser.cpp | 35 +++++++++++++++++++ src/traffic_cop/traffic_cop_util.cpp | 2 ++ 8 files changed, 91 insertions(+), 1 deletion(-) diff --git a/src/binder/bind_node_visitor.cpp b/src/binder/bind_node_visitor.cpp index ab572e7777..a2c24c9210 100644 --- a/src/binder/bind_node_visitor.cpp +++ b/src/binder/bind_node_visitor.cpp @@ -323,6 +323,9 @@ void BindNodeVisitor::Visit(common::ManagedPointer node) common::ErrorCode::ERRCODE_UNDEFINED_OBJECT); } break; + case parser::DropStatement::DropType::kFunction: + ValidateDatabaseName(node->GetDatabaseName()); + throw NOT_IMPLEMENTED_EXCEPTION("DROP FUNCTION Not Implemented"); case parser::DropStatement::DropType::kTrigger: // TODO(Ling): Get Trigger OID in catalog? case parser::DropStatement::DropType::kSchema: diff --git a/src/include/network/network_defs.h b/src/include/network/network_defs.h index d15410de68..ed29f2fa7b 100644 --- a/src/include/network/network_defs.h +++ b/src/include/network/network_defs.h @@ -113,6 +113,7 @@ enum class QueryType : uint8_t { QUERY_CREATE_VIEW, QUERY_DROP_TABLE, QUERY_DROP_DB, + QUERY_DROP_FUNCTION, QUERY_DROP_INDEX, QUERY_DROP_TRIGGER, QUERY_DROP_SCHEMA, diff --git a/src/include/parser/drop_statement.h b/src/include/parser/drop_statement.h index 946cd6aa29..59f707b496 100644 --- a/src/include/parser/drop_statement.h +++ b/src/include/parser/drop_statement.h @@ -3,6 +3,7 @@ #include #include #include +#include #include "binder/sql_node_visitor.h" #include "parser/sql_statement.h" @@ -15,7 +16,12 @@ namespace parser { class DropStatement : public TableRefStatement { public: /** Drop statement type. */ - enum class DropType { kDatabase, kTable, kSchema, kIndex, kView, kPreparedStatement, kTrigger }; + enum class DropType { kDatabase, kTable, kSchema, kIndex, kView, kPreparedStatement, kTrigger, kFunction }; + + // TODO(Kyle): This class is becoming overly-overloaded. + // For instance, I can't define the interface for a ctor + // for DROP FUNCTION that is identical to DROP INDEX. + // Additionally, we carry a bunch of useless state around. /** * DROP DATABASE, DROP TABLE @@ -26,6 +32,19 @@ class DropStatement : public TableRefStatement { DropStatement(std::unique_ptr table_info, DropType type, bool if_exists) : TableRefStatement(StatementType::DROP, std::move(table_info)), type_(type), if_exists_(if_exists) {} + /** + * DROP FUNCTION + * @param table_info table information + * @param function_name function name + * @param function_args function argument types + */ + DropStatement(std::unique_ptr table_info, std::string function_name, + std::vector &&function_args) + : TableRefStatement(StatementType::DROP, std::move(table_info)), + type_(DropType::kFunction), + function_name_(std::move(function_name)), + function_args_(std::move(function_args)) {} + /** * DROP INDEX * @param table_info table information @@ -79,6 +98,15 @@ class DropStatement : public TableRefStatement { /** @return trigger name for [DROP TRIGGER] */ std::string GetTriggerName() { return trigger_name_; } + /** @return function name for [DROP FUNCTION] */ + std::string GetFunctionName() { return function_name_; } + + /** @return function arguments for [DROP FUNCTION] */ + const std::vector &GetFunctionArguments() const { return function_args_; } + + // TODO(Kyle): Why are we returning all of these strings by value? + // It appears that we can just use const references here... + private: const DropType type_; @@ -93,6 +121,10 @@ class DropStatement : public TableRefStatement { // DROP TRIGGER const std::string trigger_name_; + + // DROP FUNCTION + const std::string function_name_; + std::vector function_args_; }; } // namespace parser diff --git a/src/include/parser/nodes.h b/src/include/parser/nodes.h index c004a40ef0..bf5bd68f18 100644 --- a/src/include/parser/nodes.h +++ b/src/include/parser/nodes.h @@ -15,3 +15,18 @@ using value = struct Value { char *str_; /**< string */ } val_; /**< value */ }; + +/** + * A typename parsenode as produced by the Postgres parser + */ +using typname = struct TypName { + NodeTag type_; + List *names_; + Oid typeOid_; + bool setof_; + bool pct_type_; + List *typmods_; + int32_t typemod_; + List *arrayBounds_; + int location_; +}; diff --git a/src/include/parser/postgresparser.h b/src/include/parser/postgresparser.h index 82c9a2aa6c..99a3aab2ab 100644 --- a/src/include/parser/postgresparser.h +++ b/src/include/parser/postgresparser.h @@ -164,6 +164,7 @@ class PostgresParser { // DROP statements static std::unique_ptr DropTransform(ParseResult *parse_result, DropStmt *root); static std::unique_ptr DropDatabaseTransform(ParseResult *parse_result, DropDatabaseStmt *root); + static std::unique_ptr DropFunctionTransform(ParseResult *parse_result, DropStmt *root); static std::unique_ptr DropIndexTransform(ParseResult *parse_result, DropStmt *root); static std::unique_ptr DropSchemaTransform(ParseResult *parse_result, DropStmt *root); static std::unique_ptr DropTableTransform(ParseResult *parse_result, DropStmt *root); diff --git a/src/optimizer/query_to_operator_transformer.cpp b/src/optimizer/query_to_operator_transformer.cpp index 2d9a03fd22..f14953b78c 100644 --- a/src/optimizer/query_to_operator_transformer.cpp +++ b/src/optimizer/query_to_operator_transformer.cpp @@ -636,6 +636,7 @@ void QueryToOperatorTransformer::Visit(common::ManagedPointer>{}, txn_context); break; + case parser::DropStatement::DropType::kFunction: case parser::DropStatement::DropType::kTrigger: case parser::DropStatement::DropType::kView: case parser::DropStatement::DropType::kPreparedStatement: diff --git a/src/parser/postgresparser.cpp b/src/parser/postgresparser.cpp index a89cef3dfb..ddb5c4e588 100644 --- a/src/parser/postgresparser.cpp +++ b/src/parser/postgresparser.cpp @@ -48,6 +48,10 @@ std::unique_ptr PostgresParser::BuildParseTree(const std::s auto result = pg_query_parse(text); // Parse the query string with the Postgres parser. + + // TODO(Kyle): Syntax "DROP FUNCTION fun;" fails in the + // Postgres parser, do we need to update the version? + if (result.error != nullptr) { PARSER_LOG_DEBUG("BuildParseTree error: msg {}, curpos {}", result.error->message, result.error->cursorpos); @@ -1750,6 +1754,9 @@ std::unique_ptr PostgresParser::DeleteTransform(ParseResult *pa // Postgres.DropStmt -> noisepage.DropStatement std::unique_ptr PostgresParser::DropTransform(ParseResult *parse_result, DropStmt *root) { switch (root->remove_type_) { + case ObjectType::OBJECT_FUNCTION: { + return DropFunctionTransform(parse_result, root); + } case ObjectType::OBJECT_INDEX: { return DropIndexTransform(parse_result, root); } @@ -1778,6 +1785,34 @@ std::unique_ptr PostgresParser::DropDatabaseTransform(ParseResult return result; } +// Postgres.DropStmt -> noisepage.DropStatement +std::unique_ptr PostgresParser::DropFunctionTransform(ParseResult *parse_result, DropStmt *root) { + // Grab the function name + auto objects = reinterpret_cast(root->objects_->head->data.ptr_value); + std::string function_name = reinterpret_cast(objects->head->data.ptr_value)->val_.str_; + + // Grab the argument types from the function signature + auto arguments = reinterpret_cast(root->arguments_->head->data.ptr_value); + + std::vector function_args{}; + function_args.reserve(arguments->length); + for (auto *cell = arguments->head; cell != nullptr; cell = cell->next) { + // The descriptor for some types consists of a head node with + // "pg_catalog" as the string value, so we need to skip over + auto *descriptor = reinterpret_cast(cell->data.ptr_value)->names_; + if (descriptor->length > 1) { + std::string type = reinterpret_cast(descriptor->head->next->data.ptr_value)->val_.str_; + function_args.emplace_back(std::move(type)); + } else { + std::string type = reinterpret_cast(descriptor->head->data.ptr_value)->val_.str_; + function_args.emplace_back(std::move(type)); + } + } + + return std::make_unique(std::make_unique("", "", ""), std::move(function_name), + std::move(function_args)); +} + // Postgres.DropStmt -> noisepage.DropStatement std::unique_ptr PostgresParser::DropIndexTransform(ParseResult *parse_result, DropStmt *root) { // TODO(WAN): There are unimplemented DROP INDEX options. diff --git a/src/traffic_cop/traffic_cop_util.cpp b/src/traffic_cop/traffic_cop_util.cpp index cab061983d..4480d1253b 100644 --- a/src/traffic_cop/traffic_cop_util.cpp +++ b/src/traffic_cop/traffic_cop_util.cpp @@ -163,6 +163,8 @@ network::QueryType TrafficCopUtil::QueryTypeForStatement(const common::ManagedPo return network::QueryType::QUERY_DROP_PREPARED_STATEMENT; case parser::DropStatement::DropType::kTrigger: return network::QueryType::QUERY_DROP_TRIGGER; + case parser::DropStatement::DropType::kFunction: + return network::QueryType::QUERY_DROP_FUNCTION; } } case parser::StatementType::VARIABLE_SET: From a403aa84bc0a0a79564e824a1546674edff8e889 Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Mon, 9 Aug 2021 19:34:32 -0400 Subject: [PATCH 103/139] add DROP FUNCTION, some double-free bug when destroying the UDF proc context --- src/binder/bind_node_visitor.cpp | 11 +- src/catalog/catalog_accessor.cpp | 37 ++++-- src/catalog/database_catalog.cpp | 23 ++-- src/catalog/postgres/pg_proc_impl.cpp | 8 +- src/execution/compiler/udf/udf_codegen.cpp | 2 +- src/execution/sql/ddl_executors.cpp | 12 +- src/include/catalog/catalog_accessor.h | 43 ++++--- src/include/catalog/database_catalog.h | 11 +- src/include/execution/sql/ddl_executors.h | 11 +- .../optimizer/child_property_deriver.h | 6 + src/include/optimizer/logical_operators.h | 37 ++++++ src/include/optimizer/operator_visitor.h | 14 +++ src/include/optimizer/physical_operators.h | 34 ++++++ src/include/optimizer/plan_generator.h | 6 + src/include/optimizer/rule.h | 1 + .../optimizer/rules/implementation_rules.h | 27 +++++ .../parser/create_function_statement.h | 34 ++---- src/include/parser/drop_statement.h | 35 +++--- src/include/parser/postgresparser.h | 5 + .../plannodes/create_function_plan_node.h | 2 +- .../plannodes/drop_function_plan_node.h | 109 ++++++++++++++++++ .../planner/plannodes/plan_node_defs.h | 1 + src/include/planner/plannodes/plan_visitor.h | 7 ++ .../postgres/postgres_packet_writer.cpp | 3 + src/optimizer/child_property_deriver.cpp | 5 + src/optimizer/logical_operators.cpp | 30 +++++ src/optimizer/physical_operators.cpp | 30 +++++ src/optimizer/plan_generator.cpp | 9 ++ .../query_to_operator_transformer.cpp | 4 + src/optimizer/rule.cpp | 1 + src/optimizer/rules/implementation_rules.cpp | 24 ++++ src/parser/postgresparser.cpp | 31 ++--- src/planner/plannodes/abstract_plan_node.cpp | 5 + .../plannodes/drop_function_plan_node.cpp | 65 +++++++++++ src/planner/plannodes/plan_node_defs.cpp | 2 + src/storage/recovery/recovery_manager.cpp | 2 +- src/traffic_cop/traffic_cop.cpp | 11 +- 37 files changed, 575 insertions(+), 123 deletions(-) create mode 100644 src/include/planner/plannodes/drop_function_plan_node.h create mode 100644 src/planner/plannodes/drop_function_plan_node.cpp diff --git a/src/binder/bind_node_visitor.cpp b/src/binder/bind_node_visitor.cpp index a2c24c9210..8c7f683ca6 100644 --- a/src/binder/bind_node_visitor.cpp +++ b/src/binder/bind_node_visitor.cpp @@ -323,9 +323,15 @@ void BindNodeVisitor::Visit(common::ManagedPointer node) common::ErrorCode::ERRCODE_UNDEFINED_OBJECT); } break; - case parser::DropStatement::DropType::kFunction: + case parser::DropStatement::DropType::kFunction: { ValidateDatabaseName(node->GetDatabaseName()); - throw NOT_IMPLEMENTED_EXCEPTION("DROP FUNCTION Not Implemented"); + if (catalog_accessor_->GetProcOid(node->GetFunctionName(), node->GetFunctionArguments()) == + catalog::INVALID_PROC_OID) { + throw BINDER_EXCEPTION(fmt::format("function \"{}\" does not exist", node->GetFunctionName()), + common::ErrorCode::ERRCODE_UNDEFINED_OBJECT); + } + break; + } case parser::DropStatement::DropType::kTrigger: // TODO(Ling): Get Trigger OID in catalog? case parser::DropStatement::DropType::kSchema: @@ -1167,5 +1173,4 @@ void BindNodeVisitor::AddUDFVariableReference(common::ManagedPointerSetParamIdx(index); } } - } // namespace noisepage::binder diff --git a/src/catalog/catalog_accessor.cpp b/src/catalog/catalog_accessor.cpp index ce1c5b98ce..38af4f6146 100644 --- a/src/catalog/catalog_accessor.cpp +++ b/src/catalog/catalog_accessor.cpp @@ -192,6 +192,15 @@ proc_oid_t CatalogAccessor::CreateProcedure(const std::string &procname, languag bool CatalogAccessor::DropProcedure(proc_oid_t proc_oid) { return dbc_->DropProcedure(txn_, proc_oid); } +proc_oid_t CatalogAccessor::GetProcOid(const std::string &procname, const std::vector &arg_types) { + // Transform the string type identifiers to internal type IDs + std::vector types{}; + types.reserve(arg_types.size()); + std::transform(arg_types.cbegin(), arg_types.cend(), std::back_inserter(types), + [this](const std::string &name) { return TypeNameToType(name); }); + return GetProcOid(procname, types); +} + proc_oid_t CatalogAccessor::GetProcOid(const std::string &procname, const std::vector &arg_types) { proc_oid_t ret; for (auto ns_oid : search_path_) { @@ -203,20 +212,15 @@ proc_oid_t CatalogAccessor::GetProcOid(const std::string &procname, const std::v return catalog::INVALID_PROC_OID; } -common::ManagedPointer CatalogAccessor::GetProcCtxPtr( - const proc_oid_t proc_oid) { - return dbc_->GetProcCtxPtr(txn_, proc_oid); -} - -bool CatalogAccessor::SetFunctionContextPointer(proc_oid_t proc_oid, - const execution::functions::FunctionContext *func_context) { - return dbc_->SetFunctionContextPointer(txn_, proc_oid, func_context); -} - common::ManagedPointer CatalogAccessor::GetFunctionContext(proc_oid_t proc_oid) { return dbc_->GetFunctionContext(txn_, proc_oid); } +bool CatalogAccessor::SetFunctionContext(proc_oid_t proc_oid, + const execution::functions::FunctionContext *func_context) { + return dbc_->SetFunctionContext(txn_, proc_oid, func_context); +} + std::unique_ptr CatalogAccessor::GetColumnStatistics(table_oid_t table_oid, col_oid_t col_oid) { return dbc_->GetColumnStatistics(txn_, table_oid, col_oid); @@ -239,4 +243,17 @@ void CatalogAccessor::RegisterTempTable(table_oid_t table_oid, const common::Man temp_tables_[table_oid] = table; } +type_oid_t CatalogAccessor::TypeNameToType(const std::string &type_name) { + // TODO(Kyle): Complete this function + type_oid_t type; + if (type_name == "int4") { + type = GetTypeOidFromTypeId(type::TypeId::INTEGER); + } else if (type_name == "bool") { + type = GetTypeOidFromTypeId(type::TypeId::BOOLEAN); + } else { + type = GetTypeOidFromTypeId(type::TypeId::INVALID); + } + return type; +} + } // namespace noisepage::catalog diff --git a/src/catalog/database_catalog.cpp b/src/catalog/database_catalog.cpp index df6651ef29..54c433ae27 100644 --- a/src/catalog/database_catalog.cpp +++ b/src/catalog/database_catalog.cpp @@ -364,9 +364,14 @@ bool DatabaseCatalog::CreateIndexEntry(const common::ManagedPointer txn, - proc_oid_t proc_oid, - const execution::functions::FunctionContext *func_context) { +common::ManagedPointer DatabaseCatalog::GetFunctionContext( + common::ManagedPointer txn, proc_oid_t proc_oid) { + return pg_proc_.GetProcCtxPtr(txn, proc_oid); +} + +bool DatabaseCatalog::SetFunctionContext(common::ManagedPointer txn, + proc_oid_t proc_oid, + const execution::functions::FunctionContext *func_context) { NOISEPAGE_ASSERT( write_lock_.load() == txn->FinishTime(), "Setting the object's pointer should only be done after successful DDL change request. i.e. this txn " @@ -380,13 +385,6 @@ bool DatabaseCatalog::SetFunctionContextPointer(common::ManagedPointer DatabaseCatalog::GetFunctionContext( - common::ManagedPointer txn, proc_oid_t proc_oid) { - auto proc_ctx = pg_proc_.GetProcCtxPtr(txn, proc_oid); - NOISEPAGE_ASSERT(proc_ctx != nullptr, "Dynamically added UDFs are currently not supported."); - return proc_ctx; -} - std::unique_ptr DatabaseCatalog::GetColumnStatistics( common::ManagedPointer txn, table_oid_t table_oid, col_oid_t col_oid) { return pg_stat_.GetColumnStatistics(txn, common::ManagedPointer(this), table_oid, col_oid); @@ -472,11 +470,6 @@ proc_oid_t DatabaseCatalog::GetProcOid(common::ManagedPointer DatabaseCatalog::GetProcCtxPtr( - common::ManagedPointer txn, proc_oid_t proc_oid) { - return pg_proc_.GetProcCtxPtr(txn, proc_oid); -} - template bool DatabaseCatalog::SetClassPointer(const common::ManagedPointer txn, const ClassOid oid, const Ptr *const pointer, const col_oid_t class_col) { diff --git a/src/catalog/postgres/pg_proc_impl.cpp b/src/catalog/postgres/pg_proc_impl.cpp index d6093b0813..67f6b209b6 100644 --- a/src/catalog/postgres/pg_proc_impl.cpp +++ b/src/catalog/postgres/pg_proc_impl.cpp @@ -203,7 +203,8 @@ bool PgProcImpl::DropProcedure(const common::ManagedPointerGet(proc_pm[PgProc::PRONAME.oid_], nullptr); auto proc_ns = *table_pr->Get(proc_pm[PgProc::PRONAMESPACE.oid_], nullptr); - auto ctx_ptr = table_pr->AccessWithNullCheck(proc_pm[PgProc::PRO_CTX_PTR.oid_]); + auto *ctx_ptr = reinterpret_cast( + table_pr->AccessWithNullCheck(proc_pm[PgProc::PRO_CTX_PTR.oid_])); // Delete from pg_proc_name_index. { @@ -276,8 +277,7 @@ common::ManagedPointer PgProcImpl::GetPro NOISEPAGE_ASSERT(result, "Index already verified visibility. This shouldn't fail."); auto *ptr_ptr = (reinterpret_cast(select_pr->AccessWithNullCheck(0))); - NOISEPAGE_ASSERT(nullptr != ptr_ptr, - "GetFunctionContext called on an invalid OID or before SetFunctionContextPointer."); + NOISEPAGE_ASSERT(nullptr != ptr_ptr, "GetFunctionContext called on an invalid OID or before SetFunctionContext."); execution::functions::FunctionContext *ptr = *reinterpret_cast(ptr_ptr); delete[] buffer; @@ -473,7 +473,7 @@ void PgProcImpl::BootstrapProcContext(const common::ManagedPointerSetFunctionContextPointer(txn, proc_oid, func_context); + const auto retval UNUSED_ATTRIBUTE = dbc->SetFunctionContext(txn, proc_oid, func_context); NOISEPAGE_ASSERT(retval, "Bootstrap operations should not fail"); } diff --git a/src/execution/compiler/udf/udf_codegen.cpp b/src/execution/compiler/udf/udf_codegen.cpp index 0d6a25ef31..6f86b851e3 100644 --- a/src/execution/compiler/udf/udf_codegen.cpp +++ b/src/execution/compiler/udf/udf_codegen.cpp @@ -119,7 +119,7 @@ void UdfCodegen::Visit(ast::udf::CallExprAST *ast) { common::ErrorCode::ERRCODE_PLPGSQL_ERROR); } - auto context = accessor_->GetProcCtxPtr(proc_oid); + auto context = accessor_->GetFunctionContext(proc_oid); if (context->IsBuiltin()) { ast::Expr *result = codegen_->CallBuiltin(context->GetBuiltin(), args_ast); SetExecutionResult(result); diff --git a/src/execution/sql/ddl_executors.cpp b/src/execution/sql/ddl_executors.cpp index 03c34b95e1..b348f932ca 100644 --- a/src/execution/sql/ddl_executors.cpp +++ b/src/execution/sql/ddl_executors.cpp @@ -24,6 +24,7 @@ #include "planner/plannodes/create_namespace_plan_node.h" #include "planner/plannodes/create_table_plan_node.h" #include "planner/plannodes/drop_database_plan_node.h" +#include "planner/plannodes/drop_function_plan_node.h" #include "planner/plannodes/drop_index_plan_node.h" #include "planner/plannodes/drop_namespace_plan_node.h" #include "planner/plannodes/drop_table_plan_node.h" @@ -142,14 +143,12 @@ bool DDLExecutors::CreateFunctionExecutor(const common::ManagedPointer( node->GetFunctionName(), parser::ReturnType::DataTypeToTypeId(node->GetReturnType()), std::move(types), std::unique_ptr(region), std::move(ast_context), file); - if (!accessor->SetFunctionContextPointer(proc_id, udf_context.get())) { + if (!accessor->SetFunctionContext(proc_id, udf_context.get())) { return false; } - // TODO(Kyle): Not quite sure how abort actions work, but is - // the implication here that we leak in the event that we do - // not abort and the associated transaction completes? accessor->GetTxn()->RegisterAbortAction([udf_context = udf_context.release()]() { delete udf_context; }); + return true; } @@ -280,4 +279,9 @@ bool DDLExecutors::CreateIndex(const common::ManagedPointer node, + common::ManagedPointer accessor) { + return accessor->DropProcedure(node->GetProcedureOid()); +} + } // namespace noisepage::execution::sql diff --git a/src/include/catalog/catalog_accessor.h b/src/include/catalog/catalog_accessor.h index e6d2247b15..b4bf640a17 100644 --- a/src/include/catalog/catalog_accessor.h +++ b/src/include/catalog/catalog_accessor.h @@ -342,21 +342,32 @@ class EXPORT CatalogAccessor { bool DropProcedure(proc_oid_t proc_oid); /** - * Gets the oid of a procedure from pg_proc given a requested name and namespace + * Get the OID of the procedure from pg_proc given a requested name and argument + * types as string identifiers. + * This lookup with return the first one found through a sequential scan through + * the current search path. + * @param procname name of the proc to lookup + * @param arg_types vector of type identifiers for the arguments of the procedure + * @return The OID of the resolved procedure if found, else `INVALID_PROC_OID` + */ + proc_oid_t GetProcOid(const std::string &procname, const std::vector &arg_types); + + /** + * Gets the OID of a procedure from pg_proc given a requested name and resolved argument types. * This lookup will return the first one found through a sequential scan through * the current search path * @param procname name of the proc to lookup - * @param all_arg_types vector of types of arguments of procedure to look up - * @return the oid of the found proc if found else INVALID_PROC_OID + * @param arg_types vector of types of arguments of procedure to look up + * @return The OID of the resolved procedure if found, else `INVALID_PROC_OID` */ - proc_oid_t GetProcOid(const std::string &procname, const std::vector &all_arg_types); + proc_oid_t GetProcOid(const std::string &procname, const std::vector &arg_types); /** - * TODO(Kyle): Document. + * Gets the proc context pointer column of proc_oid + * @param proc_oid The proc_oid whose pointer column we are getting here + * @return nullptr if proc_oid is either invalid or there is no context object set for this proc_oid */ - common::ManagedPointer GetProcCtxPtr(proc_oid_t proc_oid); - - // TODO(Kyle): Make these functions consistent + common::ManagedPointer GetFunctionContext(proc_oid_t proc_oid); /** * Sets the proc context pointer column of proc_oid to func_context @@ -364,14 +375,7 @@ class EXPORT CatalogAccessor { * @param func_context The context object to set to * @return False if the given proc_oid is invalid, True if else */ - bool SetFunctionContextPointer(proc_oid_t proc_oid, const execution::functions::FunctionContext *func_context); - - /** - * Gets the proc context pointer column of proc_oid - * @param proc_oid The proc_oid whose pointer column we are getting here - * @return nullptr if proc_oid is either invalid or there is no context object set for this proc_oid - */ - common::ManagedPointer GetFunctionContext(proc_oid_t proc_oid); + bool SetFunctionContext(proc_oid_t proc_oid, const execution::functions::FunctionContext *func_context); /** * Gets the statistics of a column from pg_statistic @@ -461,6 +465,13 @@ class EXPORT CatalogAccessor { static void NormalizeObjectName(std::string *name) { std::transform(name->begin(), name->end(), name->begin(), [](auto &&c) { return std::tolower(c); }); } + + /** + * Resolve a string type name identifier to a catalog type. + * @param type_name The type name + * @return The internal catalog type identifier for the type + */ + type_oid_t TypeNameToType(const std::string &type_name); }; } // namespace noisepage::catalog diff --git a/src/include/catalog/database_catalog.h b/src/include/catalog/database_catalog.h index cd8f3e12e0..24c605a438 100644 --- a/src/include/catalog/database_catalog.h +++ b/src/include/catalog/database_catalog.h @@ -175,15 +175,12 @@ class DatabaseCatalog { /** @brief Get the OID of the specified procedure. @see PgProcImpl::GetProcOid */ proc_oid_t GetProcOid(common::ManagedPointer txn, namespace_oid_t procns, const std::string &procname, const std::vector &all_arg_types); - /** @brief Get the procedure context pointer column of the specified procedure */ - common::ManagedPointer GetProcCtxPtr( - common::ManagedPointer txn, proc_oid_t proc_oid); - /** @brief Set the procedure context for the specified procedure. @see PgProcImpl::SetFunctionContextPointer */ - bool SetFunctionContextPointer(common::ManagedPointer txn, proc_oid_t proc_oid, - const execution::functions::FunctionContext *func_context); - /** @brief Get the procedure context for the specified procedure. @see PgProcImpl::GetFunctionContext */ + /** @brief Get the procedure context for the specified procedure. @see PgProcImpl::GetProcCtxPtr */ common::ManagedPointer GetFunctionContext( common::ManagedPointer txn, proc_oid_t proc_oid); + /** @brief Set the procedure context for the specified procedure. @see PgProcImpl::SetProcCtxPtr */ + bool SetFunctionContext(common::ManagedPointer txn, proc_oid_t proc_oid, + const execution::functions::FunctionContext *func_context); /** @brief Get the statistics for the specified column. @see PgStatisticImpl::GetColumnStatistics */ std::unique_ptr GetColumnStatistics( diff --git a/src/include/execution/sql/ddl_executors.h b/src/include/execution/sql/ddl_executors.h index 75272baa3d..8ece11b335 100644 --- a/src/include/execution/sql/ddl_executors.h +++ b/src/include/execution/sql/ddl_executors.h @@ -11,9 +11,10 @@ class CreateNamespacePlanNode; class CreateTablePlanNode; class CreateIndexPlanNode; class CreateViewPlanNode; +class CreateFunctionPlanNode; class DropDatabasePlanNode; class DropNamespacePlanNode; -class CreateFunctionPlanNode; +class DropFunctionPlanNode; class DropTablePlanNode; class DropIndexPlanNode; } // namespace noisepage::planner @@ -108,6 +109,14 @@ class DDLExecutors { static bool DropIndexExecutor(common::ManagedPointer node, common::ManagedPointer accessor); + /** + * @param node node to execute + * @param accessor accessor to use for execution + * @return `true` if operation succeeds, `false` otherwise + */ + static bool DropFunctionExecutor(common::ManagedPointer node, + common::ManagedPointer accessor); + private: static bool CreateIndex(common::ManagedPointer accessor, catalog::namespace_oid_t ns, const std::string &name, catalog::table_oid_t table, diff --git a/src/include/optimizer/child_property_deriver.h b/src/include/optimizer/child_property_deriver.h index a080b92217..358cd3e12d 100644 --- a/src/include/optimizer/child_property_deriver.h +++ b/src/include/optimizer/child_property_deriver.h @@ -276,6 +276,12 @@ class ChildPropertyDeriver : public OperatorVisitor { */ void Visit(const DropView *drop_view) override; + /** + * Visit a DropFunction operator + * @param drop_function operator + */ + void Visit(const DropFunction *drop_function) override; + /** * Visit an Analyze operator * @param analyze analyze operator diff --git a/src/include/optimizer/logical_operators.h b/src/include/optimizer/logical_operators.h index 14331e7336..6f1eff7ec5 100644 --- a/src/include/optimizer/logical_operators.h +++ b/src/include/optimizer/logical_operators.h @@ -1903,6 +1903,43 @@ class LogicalDropView : public OperatorNodeContents { bool if_exists_; }; +/** + * Logical operator for DropFunction + */ +class LogicalDropFunction : public OperatorNodeContents { + public: + /** + * @param database_oid OID of the database + * @param proc_oid OID of the function to be dropped + * @return LogicalDropFunction + */ + static Operator Make(catalog::db_oid_t database_oid, catalog::proc_oid_t proc_oid); + + /** + * Copy + * @returns copy of this + */ + BaseOperatorNodeContents *Copy() const override; + + /** Comparison operator */ + bool operator==(const BaseOperatorNodeContents &r) override; + + /** @return The hash of the instance */ + common::hash_t Hash() const override; + + /** @return The OID of the database */ + catalog::db_oid_t GetDatabaseOid() const { return database_oid_; } + + /** @return The OID of the function to drop */ + catalog::proc_oid_t GetFunctionOid() const { return proc_oid_; } + + private: + /** OID of the database */ + catalog::db_oid_t database_oid_; + /** OID of the function to drop */ + catalog::proc_oid_t proc_oid_; +}; + /** * Logical operator for Analyze */ diff --git a/src/include/optimizer/operator_visitor.h b/src/include/optimizer/operator_visitor.h index 69f36ae655..f5862053ff 100644 --- a/src/include/optimizer/operator_visitor.h +++ b/src/include/optimizer/operator_visitor.h @@ -42,6 +42,7 @@ class DropIndex; class DropNamespace; class DropTrigger; class DropView; +class DropFunction; class Analyze; class LogicalGet; class LogicalExternalFileGet; @@ -77,6 +78,7 @@ class LogicalDropIndex; class LogicalDropNamespace; class LogicalDropTrigger; class LogicalDropView; +class LogicalDropFunction; class LogicalAnalyze; class LogicalCteScan; @@ -320,6 +322,12 @@ class OperatorVisitor { */ virtual void Visit(const DropView *drop_view) {} + /** + * Visit a DropFunction operator + * @param drop_function operator + */ + virtual void Visit(const DropFunction *drop_function) {} + /** * Visit a Analyze operator * @param analyze operator @@ -530,6 +538,12 @@ class OperatorVisitor { */ virtual void Visit(const LogicalDropView *logical_drop_view) {} + /** + * Visit a LogicalDropFunction operator + * @param logical_drop_function + */ + virtual void Visit(const LogicalDropFunction *logical_drop_function) {} + /** * Visit a LogicalAnalyze operator * @param logical_analyze operator diff --git a/src/include/optimizer/physical_operators.h b/src/include/optimizer/physical_operators.h index df3b7781af..0b03dfee9b 100644 --- a/src/include/optimizer/physical_operators.h +++ b/src/include/optimizer/physical_operators.h @@ -2115,6 +2115,40 @@ class DropView : public OperatorNodeContents { bool if_exists_; }; +/** + * Physical operator for DropFunction + */ +class DropFunction : public OperatorNodeContents { + public: + /** + * @param database_oid OID of database + * @param proc_oid OID of view to drop + * @return + */ + static Operator Make(catalog::db_oid_t database_oid, catalog::proc_oid_t proc_oid); + + /** @return A copy of this */ + BaseOperatorNodeContents *Copy() const override; + + /** Comparison operator */ + bool operator==(const BaseOperatorNodeContents &r) override; + + /** @return The hash of this instance */ + common::hash_t Hash() const override; + + /** @return The OID of the database */ + catalog::db_oid_t GetDatabaseOid() const { return database_oid_; } + + /** @return The OID of the function to drop */ + catalog::proc_oid_t GetFunctionOID() const { return proc_oid_; } + + private: + /** OID of the database */ + catalog::db_oid_t database_oid_; + /** OID of the view to drop */ + catalog::proc_oid_t proc_oid_; +}; + /** * Physical operator for Analyze */ diff --git a/src/include/optimizer/plan_generator.h b/src/include/optimizer/plan_generator.h index 70f14e44ec..6edcb130ab 100644 --- a/src/include/optimizer/plan_generator.h +++ b/src/include/optimizer/plan_generator.h @@ -303,6 +303,12 @@ class PlanGenerator : public OperatorVisitor { */ void Visit(const DropView *drop_view) override; + /** + * Visit a DropFunction operator + * @param drop_function operator + */ + void Visit(const DropFunction *drop_function) override; + /** * Visit a Analyze operator * @param analyze operator diff --git a/src/include/optimizer/rule.h b/src/include/optimizer/rule.h index 0b0d94facb..d1cdcc1bf6 100644 --- a/src/include/optimizer/rule.h +++ b/src/include/optimizer/rule.h @@ -59,6 +59,7 @@ enum class RuleType : uint32_t { DROP_NAMESPACE_TO_PHYSICAL, DROP_TRIGGER_TO_PHYSICAL, DROP_VIEW_TO_PHYSICAL, + DROP_FUNCTION_TO_PHYSICAL, // Don't move this one RewriteDelimiter, diff --git a/src/include/optimizer/rules/implementation_rules.h b/src/include/optimizer/rules/implementation_rules.h index 314dd4d943..805c6a8b09 100644 --- a/src/include/optimizer/rules/implementation_rules.h +++ b/src/include/optimizer/rules/implementation_rules.h @@ -923,6 +923,33 @@ class LogicalDropViewToPhysicalDropView : public Rule { OptimizationContext *context) const override; }; +/** + * Rule transforms Logical DropFunction -> Physical DropFunction + */ +class LogicalDropFunctionToPhysicalDropFunction : public Rule { + public: + /** Constructor */ + LogicalDropFunctionToPhysicalDropFunction(); + + /** + * Checks whether the given rule can be applied + * @param plan AbstractOptimizerNode to check + * @param context Current OptimizationContext executing under + * @returns Whether the input AbstractOptimizerNode passes the check + */ + bool Check(common::ManagedPointer plan, OptimizationContext *context) const override; + + /** + * Transforms the input expression using the given rule + * @param input Input AbstractOptimizerNode to transform + * @param transformed Vector of transformed AbstractOptimizerNodes + * @param context Current OptimizationContext executing under + */ + void Transform(common::ManagedPointer input, + std::vector> *transformed, + OptimizationContext *context) const override; +}; + /** * Rule transforms Logical Analyze -> Physical Analyze */ diff --git a/src/include/parser/create_function_statement.h b/src/include/parser/create_function_statement.h index 64221d5790..96cc5f3b8e 100644 --- a/src/include/parser/create_function_statement.h +++ b/src/include/parser/create_function_statement.h @@ -10,8 +10,7 @@ #include "parser/sql_statement.h" // TODO(WAN): this file is messy -namespace noisepage { -namespace parser { +namespace noisepage::parser { /** Base function parameter. */ struct BaseFunctionParameter { // TODO(WAN): there used to be a FuncParamMode that was never used? @@ -134,29 +133,19 @@ class CreateFunctionStatement : public SQLStatement { void Accept(common::ManagedPointer v) override { v->Visit(common::ManagedPointer(this)); } - /** - * @return true if this function should replace existing definitions - */ + /** @return `true` if this function should replace existing definitions */ bool ShouldReplace() { return replace_; } - /** - * @return function name - */ + /** @return The function name */ std::string GetFuncName() { return func_name_; } - /** - * @return return type - */ + /** @return The function return type */ common::ManagedPointer GetFuncReturnType() { return common::ManagedPointer(return_type_); } - /** - * @return function body - */ + /** @return The function body */ std::vector GetFuncBody() { return func_body_; } - /** - * @return function parameters - */ + /** @return The function parameters */ std::vector> GetFuncParameters() { std::vector> params; params.reserve(func_parameters_.size()); @@ -166,14 +155,10 @@ class CreateFunctionStatement : public SQLStatement { return params; } - /** - * @return programming language type - */ + /** @return The programming language type */ PLType GetPLType() { return pl_type_; } - /** - * @return as type (executable or query string) - */ + /** @return As type (executable or query string) */ AsType GetAsType() { return as_type_; } private: @@ -186,5 +171,4 @@ class CreateFunctionStatement : public SQLStatement { const AsType as_type_; }; -} // namespace parser -} // namespace noisepage +} // namespace noisepage::parser diff --git a/src/include/parser/drop_statement.h b/src/include/parser/drop_statement.h index 59f707b496..b9647ac6fd 100644 --- a/src/include/parser/drop_statement.h +++ b/src/include/parser/drop_statement.h @@ -18,11 +18,6 @@ class DropStatement : public TableRefStatement { /** Drop statement type. */ enum class DropType { kDatabase, kTable, kSchema, kIndex, kView, kPreparedStatement, kTrigger, kFunction }; - // TODO(Kyle): This class is becoming overly-overloaded. - // For instance, I can't define the interface for a ctor - // for DROP FUNCTION that is identical to DROP INDEX. - // Additionally, we carry a bunch of useless state around. - /** * DROP DATABASE, DROP TABLE * @param table_info table information @@ -32,11 +27,21 @@ class DropStatement : public TableRefStatement { DropStatement(std::unique_ptr table_info, DropType type, bool if_exists) : TableRefStatement(StatementType::DROP, std::move(table_info)), type_(type), if_exists_(if_exists) {} + /** + * DROP INDEX + * @param table_info table information + * @param index_name index name + */ + DropStatement(std::unique_ptr table_info, std::string index_name) + : TableRefStatement(StatementType::DROP, std::move(table_info)), + type_(DropType::kIndex), + index_name_(std::move(index_name)) {} + /** * DROP FUNCTION * @param table_info table information * @param function_name function name - * @param function_args function argument types + * @param function_args function argument type identifiers */ DropStatement(std::unique_ptr table_info, std::string function_name, std::vector &&function_args) @@ -45,16 +50,6 @@ class DropStatement : public TableRefStatement { function_name_(std::move(function_name)), function_args_(std::move(function_args)) {} - /** - * DROP INDEX - * @param table_info table information - * @param index_name index name - */ - DropStatement(std::unique_ptr table_info, std::string index_name) - : TableRefStatement(StatementType::DROP, std::move(table_info)), - type_(DropType::kIndex), - index_name_(std::move(index_name)) {} - /** * DROP SCHEMA * @param table_info table information @@ -101,15 +96,15 @@ class DropStatement : public TableRefStatement { /** @return function name for [DROP FUNCTION] */ std::string GetFunctionName() { return function_name_; } - /** @return function arguments for [DROP FUNCTION] */ + /** @return function argument types for [DROP FUNCTION] */ const std::vector &GetFunctionArguments() const { return function_args_; } - // TODO(Kyle): Why are we returning all of these strings by value? - // It appears that we can just use const references here... - private: const DropType type_; + // TODO(Kyle): Maybe use a std::variant here to make + // the overloading of this type less wasteful? + // DROP DATABASE, SCHEMA const bool if_exists_ = false; diff --git a/src/include/parser/postgresparser.h b/src/include/parser/postgresparser.h index 99a3aab2ab..ebbd7ff88a 100644 --- a/src/include/parser/postgresparser.h +++ b/src/include/parser/postgresparser.h @@ -77,6 +77,11 @@ class PostgresParser { } } + /** + * Determine if the function identified by `fun_name` is an aggregate function. + * @param fun_name The function name + * @return `true` if the function is an aggregation, `false` otherwise + */ static bool IsAggregateFunction(const std::string &fun_name) { return (fun_name == "min" || fun_name == "max" || fun_name == "count" || fun_name == "avg" || fun_name == "sum"); } diff --git a/src/include/planner/plannodes/create_function_plan_node.h b/src/include/planner/plannodes/create_function_plan_node.h index b4cf421c8a..589872cd2b 100644 --- a/src/include/planner/plannodes/create_function_plan_node.h +++ b/src/include/planner/plannodes/create_function_plan_node.h @@ -13,7 +13,7 @@ namespace noisepage::planner { /** - * Plan node for creating user defined functions + * Plan node for creating user-defined functions */ class CreateFunctionPlanNode : public AbstractPlanNode { public: diff --git a/src/include/planner/plannodes/drop_function_plan_node.h b/src/include/planner/plannodes/drop_function_plan_node.h new file mode 100644 index 0000000000..1d66604076 --- /dev/null +++ b/src/include/planner/plannodes/drop_function_plan_node.h @@ -0,0 +1,109 @@ +#pragma once + +#include +#include + +#include "parser/drop_statement.h" +#include "parser/parser_defs.h" +#include "planner/plannodes/abstract_plan_node.h" +#include "planner/plannodes/plan_visitor.h" + +namespace noisepage::planner { + +/** + * Plan node for dropping user-defined functions. + */ +class DropFunctionPlanNode : public AbstractPlanNode { + public: + /** + * Builder for an create function plan node + */ + class Builder : public AbstractPlanNode::Builder { + public: + Builder() = default; + + /** + * Don't allow builder to be copied or moved + */ + DISALLOW_COPY_AND_MOVE(Builder); + + /** + * @param database_oid The OID of the database + * @return builder object + */ + Builder &SetDatabaseOid(catalog::db_oid_t database_oid) { + database_oid_ = database_oid; + return *this; + } + + /** + * @param proc_oid The OID of the procedure + * @return builder object + */ + Builder &SetProcedureOid(catalog::proc_oid_t proc_oid) { + proc_oid_ = proc_oid; + return *this; + } + + /** + * Build the drop function plan node + * @return plan node + */ + std::unique_ptr Build(); + + protected: + /** OID of the database */ + catalog::db_oid_t database_oid_; + /** OID of the procedure */ + catalog::proc_oid_t proc_oid_; + }; + + private: + /** + * @param children child plan nodes + * @param output_schema Schema representing the structure of the output of this plan node + * @param database_oid OID of the database + * @param proc_oid OID of the procedure + * @param plan_node_id Plan node ID + */ + DropFunctionPlanNode(std::vector> &&children, + std::unique_ptr output_schema, catalog::db_oid_t database_oid, + catalog::proc_oid_t proc_oid, plan_node_id_t plan_node_id); + + public: + /** Default constructor used for deserialization */ + DropFunctionPlanNode() = default; + + DISALLOW_COPY_AND_MOVE(DropFunctionPlanNode) + + /** @return the type of this plan node */ + PlanNodeType GetPlanNodeType() const override { return PlanNodeType::DROP_FUNC; } + + /** @return OID of the database */ + catalog::db_oid_t GetDatabaseOid() const { return database_oid_; } + + /** @return OID of the procedure */ + catalog::proc_oid_t GetProcedureOid() const { return proc_oid_; } + + /** @return the hashed value of this plan node */ + common::hash_t Hash() const override; + + bool operator==(const AbstractPlanNode &rhs) const override; + + void Accept(common::ManagedPointer v) const override { v->Visit(this); } + + /** Serialize to JSON representation */ + nlohmann::json ToJson() const override; + /** Deserialize from JSON representation */ + std::vector> FromJson(const nlohmann::json &j) override; + + private: + /** OID of database */ + catalog::db_oid_t database_oid_; + /** OID of procedure */ + catalog::proc_oid_t proc_oid_; +}; + +DEFINE_JSON_HEADER_DECLARATIONS(DropFunctionPlanNode); + +} // namespace noisepage::planner diff --git a/src/include/planner/plannodes/plan_node_defs.h b/src/include/planner/plannodes/plan_node_defs.h index 4e71d20da7..f79c807153 100644 --- a/src/include/planner/plannodes/plan_node_defs.h +++ b/src/include/planner/plannodes/plan_node_defs.h @@ -62,6 +62,7 @@ enum class PlanNodeType { DROP_NAMESPACE, DROP_TABLE, DROP_INDEX, + DROP_FUNC, DROP_TRIGGER, DROP_VIEW, ANALYZE, diff --git a/src/include/planner/plannodes/plan_visitor.h b/src/include/planner/plannodes/plan_visitor.h index c055c9587a..0f8fa51ef1 100644 --- a/src/include/planner/plannodes/plan_visitor.h +++ b/src/include/planner/plannodes/plan_visitor.h @@ -19,6 +19,7 @@ class DropNamespacePlanNode; class DropTablePlanNode; class DropTriggerPlanNode; class DropViewPlanNode; +class DropFunctionPlanNode; class ExportExternalFilePlanNode; class HashJoinPlanNode; class IndexJoinPlanNode; @@ -143,6 +144,12 @@ class PlanVisitor { */ virtual void Visit(UNUSED_ATTRIBUTE const DropViewPlanNode *plan) {} + /** + * Visit a DropFunctionPlanNode + * @param plan DropFunctionPlanNode + */ + virtual void Visit(UNUSED_ATTRIBUTE const DropFunctionPlanNode *plan) {} + /** * Visit an ExportExternalFilePlanNode * @param plan ExportExternalFilePlanNode diff --git a/src/network/postgres/postgres_packet_writer.cpp b/src/network/postgres/postgres_packet_writer.cpp index cbd693b480..71a635e851 100644 --- a/src/network/postgres/postgres_packet_writer.cpp +++ b/src/network/postgres/postgres_packet_writer.cpp @@ -177,6 +177,9 @@ void PostgresPacketWriter::WriteCommandComplete(const QueryType query_type, cons case QueryType::QUERY_DROP_SCHEMA: WriteCommandComplete("DROP SCHEMA"); break; + case QueryType::QUERY_DROP_FUNCTION: + WriteCommandComplete("DROP FUNCTION"); + break; case QueryType::QUERY_EXPLAIN: WriteCommandComplete("EXPLAIN"); break; diff --git a/src/optimizer/child_property_deriver.cpp b/src/optimizer/child_property_deriver.cpp index cd926f66a1..1cff543a86 100644 --- a/src/optimizer/child_property_deriver.cpp +++ b/src/optimizer/child_property_deriver.cpp @@ -268,6 +268,11 @@ void ChildPropertyDeriver::Visit(UNUSED_ATTRIBUTE const DropView *drop_view) { output_.emplace_back(new PropertySet(), std::vector{}); } +void ChildPropertyDeriver::Visit(UNUSED_ATTRIBUTE const DropFunction *drop_function) { + // Operator does not provide any properties + output_.emplace_back(new PropertySet(), std::vector{}); +} + void ChildPropertyDeriver::Visit(UNUSED_ATTRIBUTE const Analyze *analyze) { // Analyze does not provide any properties output_.emplace_back(new PropertySet(), std::vector{new PropertySet()}); diff --git a/src/optimizer/logical_operators.cpp b/src/optimizer/logical_operators.cpp index 37e326575c..67e7cb9728 100644 --- a/src/optimizer/logical_operators.cpp +++ b/src/optimizer/logical_operators.cpp @@ -1211,6 +1211,32 @@ bool LogicalDropView::operator==(const BaseOperatorNodeContents &r) { return if_exists_ == node.if_exists_; } +//===--------------------------------------------------------------------===// +// LogicalDropFunction +//===--------------------------------------------------------------------===// +BaseOperatorNodeContents *LogicalDropFunction::Copy() const { return new LogicalDropFunction(*this); } + +Operator LogicalDropFunction::Make(catalog::db_oid_t database_oid, catalog::proc_oid_t proc_oid) { + auto *op = new LogicalDropFunction(); + op->database_oid_ = database_oid; + op->proc_oid_ = proc_oid; + return Operator(common::ManagedPointer(op)); +} + +common::hash_t LogicalDropFunction::Hash() const { + common::hash_t hash = BaseOperatorNodeContents::Hash(); + hash = common::HashUtil::CombineHashes(hash, common::HashUtil::Hash(database_oid_)); + hash = common::HashUtil::CombineHashes(hash, common::HashUtil::Hash(proc_oid_)); + return hash; +} + +bool LogicalDropFunction::operator==(const BaseOperatorNodeContents &r) { + if (r.GetOpType() != OpType::LOGICALDROPFUNCTION) return false; + const LogicalDropFunction &node = *dynamic_cast(&r); + if (database_oid_ != node.database_oid_) return false; + return proc_oid_ == node.proc_oid_; +} + //===--------------------------------------------------------------------===// // LogicalAnalyze //===--------------------------------------------------------------------===// @@ -1384,6 +1410,8 @@ const char *OperatorNodeContents::name = "LogicalDropTrigger template <> const char *OperatorNodeContents::name = "LogicalDropView"; template <> +const char *OperatorNodeContents::name = "LogicalDropFunction"; +template <> const char *OperatorNodeContents::name = "LogicalAnalyze"; template <> const char *OperatorNodeContents::name = "LogicalCteScan"; @@ -1460,6 +1488,8 @@ OpType OperatorNodeContents::type = OpType::LOGICALDROPTRIGG template <> OpType OperatorNodeContents::type = OpType::LOGICALDROPVIEW; template <> +OpType OperatorNodeContents::type = OpType::LOGICALDROPFUNCTION; +template <> OpType OperatorNodeContents::type = OpType::LOGICALANALYZE; template <> OpType OperatorNodeContents::type = OpType::LOGICALCTESCAN; diff --git a/src/optimizer/physical_operators.cpp b/src/optimizer/physical_operators.cpp index 7c8bd22c1e..406d99f5f0 100644 --- a/src/optimizer/physical_operators.cpp +++ b/src/optimizer/physical_operators.cpp @@ -1331,6 +1331,32 @@ bool DropView::operator==(const BaseOperatorNodeContents &r) { return if_exists_ == node.if_exists_; } +//===--------------------------------------------------------------------===// +// DropFunction +//===--------------------------------------------------------------------===// +BaseOperatorNodeContents *DropFunction::Copy() const { return new DropFunction(*this); } + +Operator DropFunction::Make(catalog::db_oid_t database_oid, catalog::proc_oid_t proc_oid) { + auto *op = new DropFunction(); + op->database_oid_ = database_oid; + op->proc_oid_ = proc_oid; + return Operator(common::ManagedPointer(op)); +} + +common::hash_t DropFunction::Hash() const { + common::hash_t hash = BaseOperatorNodeContents::Hash(); + hash = common::HashUtil::CombineHashes(hash, common::HashUtil::Hash(database_oid_)); + hash = common::HashUtil::CombineHashes(hash, common::HashUtil::Hash(proc_oid_)); + return hash; +} + +bool DropFunction::operator==(const BaseOperatorNodeContents &r) { + if (r.GetOpType() != OpType::DROPFUNCTION) return false; + const DropFunction &node = *dynamic_cast(&r); + if (database_oid_ != node.database_oid_) return false; + return proc_oid_ == node.proc_oid_; +} + //===--------------------------------------------------------------------===// // Analyze //===--------------------------------------------------------------------===// @@ -1472,6 +1498,8 @@ const char *OperatorNodeContents::name = "DropTrigger"; template <> const char *OperatorNodeContents::name = "DropView"; template <> +const char *OperatorNodeContents::name = "DropFunction"; +template <> const char *OperatorNodeContents::name = "Analyze"; template <> const char *OperatorNodeContents::name = "CteScan"; @@ -1554,6 +1582,8 @@ OpType OperatorNodeContents::type = OpType::DROPTRIGGER; template <> OpType OperatorNodeContents::type = OpType::DROPVIEW; template <> +OpType OperatorNodeContents::type = OpType::DROPFUNCTION; +template <> OpType OperatorNodeContents::type = OpType::ANALYZE; template <> OpType OperatorNodeContents::type = OpType::CTESCAN; diff --git a/src/optimizer/plan_generator.cpp b/src/optimizer/plan_generator.cpp index 0fc7686c9a..f41d7582a6 100644 --- a/src/optimizer/plan_generator.cpp +++ b/src/optimizer/plan_generator.cpp @@ -31,6 +31,7 @@ #include "planner/plannodes/cte_scan_plan_node.h" #include "planner/plannodes/delete_plan_node.h" #include "planner/plannodes/drop_database_plan_node.h" +#include "planner/plannodes/drop_function_plan_node.h" #include "planner/plannodes/drop_index_plan_node.h" #include "planner/plannodes/drop_namespace_plan_node.h" #include "planner/plannodes/drop_table_plan_node.h" @@ -1136,6 +1137,14 @@ void PlanGenerator::Visit(const DropView *drop_view) { .Build(); } +void PlanGenerator::Visit(const DropFunction *drop_function) { + output_plan_ = planner::DropFunctionPlanNode::Builder() + .SetPlanNodeId(GetNextPlanNodeID()) + .SetDatabaseOid(drop_function->GetDatabaseOid()) + .SetProcedureOid(drop_function->GetFunctionOID()) + .Build(); +} + void PlanGenerator::Visit(const Analyze *analyze) { NOISEPAGE_ASSERT(children_plans_.size() == 1, "Analyze should have 1 child plan"); output_plan_ = planner::AnalyzePlanNode::Builder() diff --git a/src/optimizer/query_to_operator_transformer.cpp b/src/optimizer/query_to_operator_transformer.cpp index f14953b78c..9253ba6fbb 100644 --- a/src/optimizer/query_to_operator_transformer.cpp +++ b/src/optimizer/query_to_operator_transformer.cpp @@ -637,6 +637,10 @@ void QueryToOperatorTransformer::Visit(common::ManagedPointer>{}, txn_context); break; case parser::DropStatement::DropType::kFunction: + drop_expr = std::make_unique( + LogicalDropFunction::Make(db_oid_, accessor_->GetProcOid(op->GetFunctionName(), op->GetFunctionArguments())) + .RegisterWithTxnContext(txn_context), + std::vector>{}, txn_context); case parser::DropStatement::DropType::kTrigger: case parser::DropStatement::DropType::kView: case parser::DropStatement::DropType::kPreparedStatement: diff --git a/src/optimizer/rule.cpp b/src/optimizer/rule.cpp index a781cadefc..3fd811bc51 100644 --- a/src/optimizer/rule.cpp +++ b/src/optimizer/rule.cpp @@ -52,6 +52,7 @@ RuleSet::RuleSet() { AddRule(RuleSetName::PHYSICAL_IMPLEMENTATION, new LogicalDropNamespaceToPhysicalDropNamespace()); AddRule(RuleSetName::PHYSICAL_IMPLEMENTATION, new LogicalDropTriggerToPhysicalDropTrigger()); AddRule(RuleSetName::PHYSICAL_IMPLEMENTATION, new LogicalDropViewToPhysicalDropView()); + AddRule(RuleSetName::PHYSICAL_IMPLEMENTATION, new LogicalDropFunctionToPhysicalDropFunction()); AddRule(RuleSetName::PHYSICAL_IMPLEMENTATION, new LogicalAnalyzeToPhysicalAnalyze()); AddRule(RuleSetName::PHYSICAL_IMPLEMENTATION, new LogicalCteScanToPhysicalCteScan()); AddRule(RuleSetName::PHYSICAL_IMPLEMENTATION, new LogicalCteScanToPhysicalEmptyCteScan()); diff --git a/src/optimizer/rules/implementation_rules.cpp b/src/optimizer/rules/implementation_rules.cpp index 74916a944a..29c8419a93 100644 --- a/src/optimizer/rules/implementation_rules.cpp +++ b/src/optimizer/rules/implementation_rules.cpp @@ -1185,6 +1185,30 @@ void LogicalDropViewToPhysicalDropView::Transform(common::ManagedPointeremplace_back(std::move(op)); } +LogicalDropFunctionToPhysicalDropFunction::LogicalDropFunctionToPhysicalDropFunction() { + type_ = RuleType::DROP_FUNCTION_TO_PHYSICAL; + match_pattern_ = new Pattern(OpType::LOGICALDROPFUNCTION); +} + +bool LogicalDropFunctionToPhysicalDropFunction::Check(common::ManagedPointer plan, + OptimizationContext *context) const { + return true; +} + +void LogicalDropFunctionToPhysicalDropFunction::Transform( + common::ManagedPointer input, + std::vector> *transformed, + UNUSED_ATTRIBUTE OptimizationContext *context) const { + auto df_op = input->Contents()->GetContentsAs(); + NOISEPAGE_ASSERT(input->GetChildren().empty(), "LogicalDropFunction should have 0 children"); + + auto op = std::make_unique(DropFunction::Make(df_op->GetDatabaseOid(), df_op->GetFunctionOid()) + .RegisterWithTxnContext(context->GetOptimizerContext()->GetTxn()), + std::vector>(), + context->GetOptimizerContext()->GetTxn()); + transformed->emplace_back(std::move(op)); +} + LogicalAnalyzeToPhysicalAnalyze::LogicalAnalyzeToPhysicalAnalyze() { type_ = RuleType::ANALYZE_TO_PHYSICAL; match_pattern_ = new Pattern(OpType::LOGICALANALYZE); diff --git a/src/parser/postgresparser.cpp b/src/parser/postgresparser.cpp index ddb5c4e588..06fa60fc37 100644 --- a/src/parser/postgresparser.cpp +++ b/src/parser/postgresparser.cpp @@ -50,7 +50,8 @@ std::unique_ptr PostgresParser::BuildParseTree(const std::s // Parse the query string with the Postgres parser. // TODO(Kyle): Syntax "DROP FUNCTION fun;" fails in the - // Postgres parser, do we need to update the version? + // Postgres parser, do we need to update the version to + // add support for the shorthand syntax? if (result.error != nullptr) { PARSER_LOG_DEBUG("BuildParseTree error: msg {}, curpos {}", result.error->message, result.error->cursorpos); @@ -1792,20 +1793,22 @@ std::unique_ptr PostgresParser::DropFunctionTransform(ParseResult std::string function_name = reinterpret_cast(objects->head->data.ptr_value)->val_.str_; // Grab the argument types from the function signature - auto arguments = reinterpret_cast(root->arguments_->head->data.ptr_value); - std::vector function_args{}; - function_args.reserve(arguments->length); - for (auto *cell = arguments->head; cell != nullptr; cell = cell->next) { - // The descriptor for some types consists of a head node with - // "pg_catalog" as the string value, so we need to skip over - auto *descriptor = reinterpret_cast(cell->data.ptr_value)->names_; - if (descriptor->length > 1) { - std::string type = reinterpret_cast(descriptor->head->next->data.ptr_value)->val_.str_; - function_args.emplace_back(std::move(type)); - } else { - std::string type = reinterpret_cast(descriptor->head->data.ptr_value)->val_.str_; - function_args.emplace_back(std::move(type)); + + auto *arguments = reinterpret_cast(root->arguments_->head->data.ptr_value); + if (arguments != NULL) { + function_args.reserve(arguments->length); + for (auto *cell = arguments->head; cell != nullptr; cell = cell->next) { + // The descriptor for some types consists of a head node with + // "pg_catalog" as the string value, so we need to skip over + auto *descriptor = reinterpret_cast(cell->data.ptr_value)->names_; + if (descriptor->length > 1) { + std::string type = reinterpret_cast(descriptor->head->next->data.ptr_value)->val_.str_; + function_args.emplace_back(std::move(type)); + } else { + std::string type = reinterpret_cast(descriptor->head->data.ptr_value)->val_.str_; + function_args.emplace_back(std::move(type)); + } } } diff --git a/src/planner/plannodes/abstract_plan_node.cpp b/src/planner/plannodes/abstract_plan_node.cpp index 30e86899dd..282b8e5aeb 100644 --- a/src/planner/plannodes/abstract_plan_node.cpp +++ b/src/planner/plannodes/abstract_plan_node.cpp @@ -18,6 +18,7 @@ #include "planner/plannodes/csv_scan_plan_node.h" #include "planner/plannodes/delete_plan_node.h" #include "planner/plannodes/drop_database_plan_node.h" +#include "planner/plannodes/drop_function_plan_node.h" #include "planner/plannodes/drop_index_plan_node.h" #include "planner/plannodes/drop_namespace_plan_node.h" #include "planner/plannodes/drop_table_plan_node.h" @@ -216,6 +217,10 @@ JSONDeserializeNodeIntermediate DeserializePlanNode(const nlohmann::json &json) plan_node = std::make_unique(); break; } + case PlanNodeType::DROP_FUNC: { + plan_node = std::make_unique(); + break; + } case PlanNodeType::EXPORT_EXTERNAL_FILE: { plan_node = std::make_unique(); break; diff --git a/src/planner/plannodes/drop_function_plan_node.cpp b/src/planner/plannodes/drop_function_plan_node.cpp new file mode 100644 index 0000000000..a9c34f18df --- /dev/null +++ b/src/planner/plannodes/drop_function_plan_node.cpp @@ -0,0 +1,65 @@ +#include "planner/plannodes/drop_function_plan_node.h" + +#include +#include +#include + +#include "common/json.h" +#include "planner/plannodes/output_schema.h" + +namespace noisepage::planner { + +std::unique_ptr DropFunctionPlanNode::Builder::Build() { + return std::unique_ptr(new DropFunctionPlanNode(std::move(children_), std::move(output_schema_), + database_oid_, proc_oid_, plan_node_id_)); +} + +DropFunctionPlanNode::DropFunctionPlanNode(std::vector> &&children, + std::unique_ptr output_schema, catalog::db_oid_t database_oid, + catalog::proc_oid_t proc_oid, plan_node_id_t plan_node_id) + : AbstractPlanNode(std::move(children), std::move(output_schema), plan_node_id), + database_oid_(database_oid), + proc_oid_(proc_oid) {} + +common::hash_t DropFunctionPlanNode::Hash() const { + common::hash_t hash = AbstractPlanNode::Hash(); + // Hash database_oid + hash = common::HashUtil::CombineHashes(hash, common::HashUtil::Hash(database_oid_)); + // Hash procedure oid + hash = common::HashUtil::CombineHashes(hash, common::HashUtil::Hash(proc_oid_)); + return hash; +} + +bool DropFunctionPlanNode::operator==(const AbstractPlanNode &rhs) const { + if (!AbstractPlanNode::operator==(rhs)) return false; + + auto &other = dynamic_cast(rhs); + + // Database OID + if (database_oid_ != other.database_oid_) return false; + + // Namespace OID + if (proc_oid_ != other.proc_oid_) return false; + + return true; +} + +nlohmann::json DropFunctionPlanNode::ToJson() const { + nlohmann::json j = AbstractPlanNode::ToJson(); + j["database_oid"] = database_oid_; + j["proc_oid"] = proc_oid_; + return j; +} + +std::vector> DropFunctionPlanNode::FromJson(const nlohmann::json &j) { + std::vector> exprs; + auto e1 = AbstractPlanNode::FromJson(j); + exprs.insert(exprs.end(), std::make_move_iterator(e1.begin()), std::make_move_iterator(e1.end())); + database_oid_ = j.at("database_oid").get(); + proc_oid_ = j.at("proc_oid").get(); + return exprs; +} + +DEFINE_JSON_BODY_DECLARATIONS(DropFunctionPlanNode); + +} // namespace noisepage::planner diff --git a/src/planner/plannodes/plan_node_defs.cpp b/src/planner/plannodes/plan_node_defs.cpp index c42ab91451..fd955e35a4 100644 --- a/src/planner/plannodes/plan_node_defs.cpp +++ b/src/planner/plannodes/plan_node_defs.cpp @@ -52,6 +52,8 @@ std::string PlanNodeTypeToString(PlanNodeType type) { return "DropTrigger"; case PlanNodeType::DROP_VIEW: return "DropView"; + case PlanNodeType::DROP_FUNC: + return "DropFunction"; case PlanNodeType::ANALYZE: return "Analyze"; case PlanNodeType::AGGREGATE: diff --git a/src/storage/recovery/recovery_manager.cpp b/src/storage/recovery/recovery_manager.cpp index 22c81fcb73..87dfaeeda9 100644 --- a/src/storage/recovery/recovery_manager.cpp +++ b/src/storage/recovery/recovery_manager.cpp @@ -1111,7 +1111,7 @@ uint32_t RecoveryManager::ProcessSpecialCasePGProcRecord( auto result UNUSED_ATTRIBUTE = catalog_->GetDatabaseCatalog(common::ManagedPointer(txn), redo_record->GetDatabaseOid()) - ->SetFunctionContextPointer(common::ManagedPointer(txn), proc_oid, nullptr); + ->SetFunctionContext(common::ManagedPointer(txn), proc_oid, nullptr); NOISEPAGE_ASSERT(result, "Setting to null did not work"); return 0; // No additional records processed } diff --git a/src/traffic_cop/traffic_cop.cpp b/src/traffic_cop/traffic_cop.cpp index 9716a88ca4..b2f480e3db 100644 --- a/src/traffic_cop/traffic_cop.cpp +++ b/src/traffic_cop/traffic_cop.cpp @@ -41,6 +41,7 @@ #include "planner/plannodes/create_namespace_plan_node.h" #include "planner/plannodes/create_table_plan_node.h" #include "planner/plannodes/drop_database_plan_node.h" +#include "planner/plannodes/drop_function_plan_node.h" #include "planner/plannodes/drop_index_plan_node.h" #include "planner/plannodes/drop_namespace_plan_node.h" #include "planner/plannodes/drop_table_plan_node.h" @@ -323,7 +324,8 @@ TrafficCopResult TrafficCop::ExecuteDropStatement( NOISEPAGE_ASSERT( query_type == network::QueryType::QUERY_DROP_TABLE || query_type == network::QueryType::QUERY_DROP_SCHEMA || query_type == network::QueryType::QUERY_DROP_INDEX || query_type == network::QueryType::QUERY_DROP_DB || - query_type == network::QueryType::QUERY_DROP_VIEW || query_type == network::QueryType::QUERY_DROP_TRIGGER, + query_type == network::QueryType::QUERY_DROP_VIEW || query_type == network::QueryType::QUERY_DROP_TRIGGER || + query_type == network::QueryType::QUERY_DROP_FUNCTION, "ExecuteDropStatement called with invalid QueryType."); switch (query_type) { case network::QueryType::QUERY_DROP_TABLE: { @@ -355,6 +357,13 @@ TrafficCopResult TrafficCop::ExecuteDropStatement( } break; } + case network::QueryType::QUERY_DROP_FUNCTION: { + if (execution::sql::DDLExecutors::DropFunctionExecutor( + physical_plan.CastManagedPointerTo(), connection_ctx->Accessor())) { + return {ResultType::COMPLETE, 0U}; + } + break; + } default: { return {ResultType::ERROR, common::ErrorData(common::ErrorSeverity::ERROR, "unsupported DROP statement type", common::ErrorCode::ERRCODE_FEATURE_NOT_SUPPORTED)}; From dc3ca9a5a7060b02e365bf915aac5b25dbed0a8f Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Tue, 10 Aug 2021 11:15:41 -0400 Subject: [PATCH 104/139] fix bug in PgProcImpl::DropProcedure, DROP FUNCTION now works --- src/catalog/postgres/pg_proc_impl.cpp | 6 ++++-- src/execution/sql/ddl_executors.cpp | 16 +++++++++------- 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/src/catalog/postgres/pg_proc_impl.cpp b/src/catalog/postgres/pg_proc_impl.cpp index 67f6b209b6..6159edfbb2 100644 --- a/src/catalog/postgres/pg_proc_impl.cpp +++ b/src/catalog/postgres/pg_proc_impl.cpp @@ -203,8 +203,10 @@ bool PgProcImpl::DropProcedure(const common::ManagedPointerGet(proc_pm[PgProc::PRONAME.oid_], nullptr); auto proc_ns = *table_pr->Get(proc_pm[PgProc::PRONAMESPACE.oid_], nullptr); - auto *ctx_ptr = reinterpret_cast( - table_pr->AccessWithNullCheck(proc_pm[PgProc::PRO_CTX_PTR.oid_])); + + auto *ptr_ptr = reinterpret_cast(table_pr->AccessWithNullCheck(proc_pm[PgProc::PRO_CTX_PTR.oid_])); + NOISEPAGE_ASSERT(ptr_ptr != nullptr, "DropProcedure called on an invalid OID or before SetFunctionContext."); + auto *ctx_ptr = *reinterpret_cast(ptr_ptr); // Delete from pg_proc_name_index. { diff --git a/src/execution/sql/ddl_executors.cpp b/src/execution/sql/ddl_executors.cpp index b348f932ca..2991df853b 100644 --- a/src/execution/sql/ddl_executors.cpp +++ b/src/execution/sql/ddl_executors.cpp @@ -83,11 +83,10 @@ bool DDLExecutors::CreateFunctionExecutor(const common::ManagedPointerGetFunctionName()); - sema::ErrorReporter error_reporter{region}; + auto region = std::make_unique(node->GetFunctionName()); + sema::ErrorReporter error_reporter{region.get()}; - auto ast_context = std::make_unique(region, &error_reporter); + auto ast_context = std::make_unique(region.get(), &error_reporter); compiler::CodeGen codegen{ast_context.get(), accessor.Get()}; util::RegionVector fn_params{codegen.GetAstContext()->GetRegion()}; @@ -142,12 +141,15 @@ bool DDLExecutors::CreateFunctionExecutor(const common::ManagedPointer( node->GetFunctionName(), parser::ReturnType::DataTypeToTypeId(node->GetReturnType()), std::move(types), - std::unique_ptr(region), std::move(ast_context), file); - if (!accessor->SetFunctionContext(proc_id, udf_context.get())) { + std::move(region), std::move(ast_context), file); + if (!accessor->SetFunctionContext(proc_id, udf_context.release())) { return false; } - accessor->GetTxn()->RegisterAbortAction([udf_context = udf_context.release()]() { delete udf_context; }); + // TODO(Kyle): We used to manually register an abort action here to destroy the + // function context in the event the transaction aborts, but this is already + // done in the catalog (in the call to CatalogAccessor::SetFunctionContext), is + // this the "ownership model" for transaction abort that we want? return true; } From 43cef633fb099741073853ce723c9f1894fbf4ee Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Tue, 10 Aug 2021 11:24:58 -0400 Subject: [PATCH 105/139] fix some bugs found by clang tidy --- src/execution/compiler/udf/udf_codegen.cpp | 1 + src/execution/sql/ddl_executors.cpp | 5 +---- src/include/parser/nodes.h | 4 ++-- src/parser/postgresparser.cpp | 2 +- 4 files changed, 5 insertions(+), 7 deletions(-) diff --git a/src/execution/compiler/udf/udf_codegen.cpp b/src/execution/compiler/udf/udf_codegen.cpp index 6f86b851e3..5398d26059 100644 --- a/src/execution/compiler/udf/udf_codegen.cpp +++ b/src/execution/compiler/udf/udf_codegen.cpp @@ -223,6 +223,7 @@ void UdfCodegen::Visit(ast::udf::ValueExprAST *ast) { case sql::TypeId::Float: case sql::TypeId::Double: expr = codegen_->FloatToSql(val->GetReal().val_); + break; case sql::TypeId::Date: expr = codegen_->DateToSql(val->GetDateVal().val_); break; diff --git a/src/execution/sql/ddl_executors.cpp b/src/execution/sql/ddl_executors.cpp index 2991df853b..0de8f4710c 100644 --- a/src/execution/sql/ddl_executors.cpp +++ b/src/execution/sql/ddl_executors.cpp @@ -142,16 +142,13 @@ bool DDLExecutors::CreateFunctionExecutor(const common::ManagedPointer( node->GetFunctionName(), parser::ReturnType::DataTypeToTypeId(node->GetReturnType()), std::move(types), std::move(region), std::move(ast_context), file); - if (!accessor->SetFunctionContext(proc_id, udf_context.release())) { - return false; - } // TODO(Kyle): We used to manually register an abort action here to destroy the // function context in the event the transaction aborts, but this is already // done in the catalog (in the call to CatalogAccessor::SetFunctionContext), is // this the "ownership model" for transaction abort that we want? - return true; + return accessor->SetFunctionContext(proc_id, udf_context.release()); } bool DDLExecutors::CreateTableExecutor(const common::ManagedPointer node, diff --git a/src/include/parser/nodes.h b/src/include/parser/nodes.h index bf5bd68f18..f534bf7db4 100644 --- a/src/include/parser/nodes.h +++ b/src/include/parser/nodes.h @@ -22,11 +22,11 @@ using value = struct Value { using typname = struct TypName { NodeTag type_; List *names_; - Oid typeOid_; + Oid type_oid_; bool setof_; bool pct_type_; List *typmods_; int32_t typemod_; - List *arrayBounds_; + List *array_bounds_; int location_; }; diff --git a/src/parser/postgresparser.cpp b/src/parser/postgresparser.cpp index 06fa60fc37..92a46bdf8b 100644 --- a/src/parser/postgresparser.cpp +++ b/src/parser/postgresparser.cpp @@ -1796,7 +1796,7 @@ std::unique_ptr PostgresParser::DropFunctionTransform(ParseResult std::vector function_args{}; auto *arguments = reinterpret_cast(root->arguments_->head->data.ptr_value); - if (arguments != NULL) { + if (arguments != nullptr) { function_args.reserve(arguments->length); for (auto *cell = arguments->head; cell != nullptr; cell = cell->next) { // The descriptor for some types consists of a head node with From 5c8f21e79718dcde2d61ae5479077b1ca07402e5 Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Tue, 10 Aug 2021 16:00:54 -0400 Subject: [PATCH 106/139] add required information for DROP FUNCTION IF EXISTS, just not tied to the interface because of limitations of the current binder API --- src/binder/bind_node_visitor.cpp | 3 +++ src/catalog/postgres/pg_proc_impl.cpp | 2 +- src/include/optimizer/logical_operators.h | 8 +++++++- src/include/optimizer/physical_operators.h | 10 ++++++++-- src/include/parser/drop_statement.h | 6 ++++-- .../plannodes/drop_function_plan_node.h | 15 ++++++++++++++- src/optimizer/logical_operators.cpp | 7 +++++-- src/optimizer/physical_operators.cpp | 7 +++++-- src/optimizer/plan_generator.cpp | 3 ++- .../query_to_operator_transformer.cpp | 3 ++- src/optimizer/rules/implementation_rules.cpp | 8 ++++---- src/parser/postgresparser.cpp | 4 ++-- .../plannodes/drop_function_plan_node.cpp | 18 +++++++++++++----- 13 files changed, 70 insertions(+), 24 deletions(-) diff --git a/src/binder/bind_node_visitor.cpp b/src/binder/bind_node_visitor.cpp index 8c7f683ca6..d2ddabff52 100644 --- a/src/binder/bind_node_visitor.cpp +++ b/src/binder/bind_node_visitor.cpp @@ -327,6 +327,9 @@ void BindNodeVisitor::Visit(common::ManagedPointer node) ValidateDatabaseName(node->GetDatabaseName()); if (catalog_accessor_->GetProcOid(node->GetFunctionName(), node->GetFunctionArguments()) == catalog::INVALID_PROC_OID) { + // TODO(Kyle): We have all of the information needed for DROP FUNCTION IF EXISTS, + // but it does not seem that there is a way to communicate a non-error failure + // condition during binding, maybe we need to add an error severity to the exception? throw BINDER_EXCEPTION(fmt::format("function \"{}\" does not exist", node->GetFunctionName()), common::ErrorCode::ERRCODE_UNDEFINED_OBJECT); } diff --git a/src/catalog/postgres/pg_proc_impl.cpp b/src/catalog/postgres/pg_proc_impl.cpp index 6159edfbb2..508e2a5eae 100644 --- a/src/catalog/postgres/pg_proc_impl.cpp +++ b/src/catalog/postgres/pg_proc_impl.cpp @@ -160,7 +160,7 @@ bool PgProcImpl::CreateProcedure(const common::ManagedPointer txn, proc_oid_t proc) { - NOISEPAGE_ASSERT(proc != INVALID_PROC_OID, "Invalid oid passed"); + NOISEPAGE_ASSERT(proc != INVALID_PROC_OID, "DropProcedure called with invalid procedure OID"); const auto &name_pri = procs_name_index_->GetProjectedRowInitializer(); const auto &oid_pri = procs_oid_index_->GetProjectedRowInitializer(); diff --git a/src/include/optimizer/logical_operators.h b/src/include/optimizer/logical_operators.h index 6f1eff7ec5..48074687b3 100644 --- a/src/include/optimizer/logical_operators.h +++ b/src/include/optimizer/logical_operators.h @@ -1911,9 +1911,10 @@ class LogicalDropFunction : public OperatorNodeContents { /** * @param database_oid OID of the database * @param proc_oid OID of the function to be dropped + * @param if_exists `true` if `IF EXISTS` specified * @return LogicalDropFunction */ - static Operator Make(catalog::db_oid_t database_oid, catalog::proc_oid_t proc_oid); + static Operator Make(catalog::db_oid_t database_oid, catalog::proc_oid_t proc_oid, bool if_exists); /** * Copy @@ -1933,11 +1934,16 @@ class LogicalDropFunction : public OperatorNodeContents { /** @return The OID of the function to drop */ catalog::proc_oid_t GetFunctionOid() const { return proc_oid_; } + /** @return `true` if `IF EXISTS` specified */ + bool GetIfExists() const { return if_exists_; } + private: /** OID of the database */ catalog::db_oid_t database_oid_; /** OID of the function to drop */ catalog::proc_oid_t proc_oid_; + /** `true` if `IF EXISTS` specified */ + bool if_exists_; }; /** diff --git a/src/include/optimizer/physical_operators.h b/src/include/optimizer/physical_operators.h index 0b03dfee9b..40898c1253 100644 --- a/src/include/optimizer/physical_operators.h +++ b/src/include/optimizer/physical_operators.h @@ -2123,9 +2123,10 @@ class DropFunction : public OperatorNodeContents { /** * @param database_oid OID of database * @param proc_oid OID of view to drop + * @param if_exists `true` if `IF_EXISTS` specified * @return */ - static Operator Make(catalog::db_oid_t database_oid, catalog::proc_oid_t proc_oid); + static Operator Make(catalog::db_oid_t database_oid, catalog::proc_oid_t proc_oid, bool if_exists); /** @return A copy of this */ BaseOperatorNodeContents *Copy() const override; @@ -2140,13 +2141,18 @@ class DropFunction : public OperatorNodeContents { catalog::db_oid_t GetDatabaseOid() const { return database_oid_; } /** @return The OID of the function to drop */ - catalog::proc_oid_t GetFunctionOID() const { return proc_oid_; } + catalog::proc_oid_t GetFunctionOid() const { return proc_oid_; } + + /** @return `true` if `IF EXISTS` specified */ + bool GetIfExists() const { return if_exists_; } private: /** OID of the database */ catalog::db_oid_t database_oid_; /** OID of the view to drop */ catalog::proc_oid_t proc_oid_; + /** `true` if `IF EXISTS` specified */ + bool if_exists_; }; /** diff --git a/src/include/parser/drop_statement.h b/src/include/parser/drop_statement.h index b9647ac6fd..69e773fbd8 100644 --- a/src/include/parser/drop_statement.h +++ b/src/include/parser/drop_statement.h @@ -42,11 +42,13 @@ class DropStatement : public TableRefStatement { * @param table_info table information * @param function_name function name * @param function_args function argument type identifiers + * @param if_exists `true` if `IF EXISTS` specified, `false` otherwise */ DropStatement(std::unique_ptr table_info, std::string function_name, - std::vector &&function_args) + std::vector &&function_args, bool if_exists) : TableRefStatement(StatementType::DROP, std::move(table_info)), type_(DropType::kFunction), + if_exists_(if_exists), function_name_(std::move(function_name)), function_args_(std::move(function_args)) {} @@ -105,7 +107,7 @@ class DropStatement : public TableRefStatement { // TODO(Kyle): Maybe use a std::variant here to make // the overloading of this type less wasteful? - // DROP DATABASE, SCHEMA + // DROP DATABASE, SCHEMA, FUNCTION const bool if_exists_ = false; // DROP INDEX diff --git a/src/include/planner/plannodes/drop_function_plan_node.h b/src/include/planner/plannodes/drop_function_plan_node.h index 1d66604076..3120ce1614 100644 --- a/src/include/planner/plannodes/drop_function_plan_node.h +++ b/src/include/planner/plannodes/drop_function_plan_node.h @@ -45,6 +45,11 @@ class DropFunctionPlanNode : public AbstractPlanNode { return *this; } + Builder &SetIfExists(bool if_exists) { + if_exists_ = if_exists; + return *this; + } + /** * Build the drop function plan node * @return plan node @@ -56,6 +61,8 @@ class DropFunctionPlanNode : public AbstractPlanNode { catalog::db_oid_t database_oid_; /** OID of the procedure */ catalog::proc_oid_t proc_oid_; + /** `true` if `IF EXISTS` specified */ + bool if_exists_; }; private: @@ -64,11 +71,12 @@ class DropFunctionPlanNode : public AbstractPlanNode { * @param output_schema Schema representing the structure of the output of this plan node * @param database_oid OID of the database * @param proc_oid OID of the procedure + * @param if_exists `true` if `IF EXISTS` specified * @param plan_node_id Plan node ID */ DropFunctionPlanNode(std::vector> &&children, std::unique_ptr output_schema, catalog::db_oid_t database_oid, - catalog::proc_oid_t proc_oid, plan_node_id_t plan_node_id); + catalog::proc_oid_t proc_oid, bool if_exists, plan_node_id_t plan_node_id); public: /** Default constructor used for deserialization */ @@ -85,6 +93,9 @@ class DropFunctionPlanNode : public AbstractPlanNode { /** @return OID of the procedure */ catalog::proc_oid_t GetProcedureOid() const { return proc_oid_; } + /** @return `true` if `IF EXISTS` is specified */ + bool GetIfExists() const { return if_exists_; } + /** @return the hashed value of this plan node */ common::hash_t Hash() const override; @@ -102,6 +113,8 @@ class DropFunctionPlanNode : public AbstractPlanNode { catalog::db_oid_t database_oid_; /** OID of procedure */ catalog::proc_oid_t proc_oid_; + /** `true` if `IF EXISTS` specified */ + bool if_exists_; }; DEFINE_JSON_HEADER_DECLARATIONS(DropFunctionPlanNode); diff --git a/src/optimizer/logical_operators.cpp b/src/optimizer/logical_operators.cpp index 67e7cb9728..9cc2014e2c 100644 --- a/src/optimizer/logical_operators.cpp +++ b/src/optimizer/logical_operators.cpp @@ -1216,10 +1216,11 @@ bool LogicalDropView::operator==(const BaseOperatorNodeContents &r) { //===--------------------------------------------------------------------===// BaseOperatorNodeContents *LogicalDropFunction::Copy() const { return new LogicalDropFunction(*this); } -Operator LogicalDropFunction::Make(catalog::db_oid_t database_oid, catalog::proc_oid_t proc_oid) { +Operator LogicalDropFunction::Make(catalog::db_oid_t database_oid, catalog::proc_oid_t proc_oid, bool if_exists) { auto *op = new LogicalDropFunction(); op->database_oid_ = database_oid; op->proc_oid_ = proc_oid; + op->if_exists_ = if_exists; return Operator(common::ManagedPointer(op)); } @@ -1227,6 +1228,7 @@ common::hash_t LogicalDropFunction::Hash() const { common::hash_t hash = BaseOperatorNodeContents::Hash(); hash = common::HashUtil::CombineHashes(hash, common::HashUtil::Hash(database_oid_)); hash = common::HashUtil::CombineHashes(hash, common::HashUtil::Hash(proc_oid_)); + hash = common::HashUtil::CombineHashes(hash, common::HashUtil::Hash(if_exists_)); return hash; } @@ -1234,7 +1236,8 @@ bool LogicalDropFunction::operator==(const BaseOperatorNodeContents &r) { if (r.GetOpType() != OpType::LOGICALDROPFUNCTION) return false; const LogicalDropFunction &node = *dynamic_cast(&r); if (database_oid_ != node.database_oid_) return false; - return proc_oid_ == node.proc_oid_; + if (proc_oid_ == node.proc_oid_) return false; + return if_exists_ == node.if_exists_; } //===--------------------------------------------------------------------===// diff --git a/src/optimizer/physical_operators.cpp b/src/optimizer/physical_operators.cpp index 406d99f5f0..79b02efd5e 100644 --- a/src/optimizer/physical_operators.cpp +++ b/src/optimizer/physical_operators.cpp @@ -1336,10 +1336,11 @@ bool DropView::operator==(const BaseOperatorNodeContents &r) { //===--------------------------------------------------------------------===// BaseOperatorNodeContents *DropFunction::Copy() const { return new DropFunction(*this); } -Operator DropFunction::Make(catalog::db_oid_t database_oid, catalog::proc_oid_t proc_oid) { +Operator DropFunction::Make(catalog::db_oid_t database_oid, catalog::proc_oid_t proc_oid, bool if_exists) { auto *op = new DropFunction(); op->database_oid_ = database_oid; op->proc_oid_ = proc_oid; + op->if_exists_ = if_exists; return Operator(common::ManagedPointer(op)); } @@ -1347,6 +1348,7 @@ common::hash_t DropFunction::Hash() const { common::hash_t hash = BaseOperatorNodeContents::Hash(); hash = common::HashUtil::CombineHashes(hash, common::HashUtil::Hash(database_oid_)); hash = common::HashUtil::CombineHashes(hash, common::HashUtil::Hash(proc_oid_)); + hash = common::HashUtil::CombineHashes(hash, common::HashUtil::Hash(if_exists_)); return hash; } @@ -1354,7 +1356,8 @@ bool DropFunction::operator==(const BaseOperatorNodeContents &r) { if (r.GetOpType() != OpType::DROPFUNCTION) return false; const DropFunction &node = *dynamic_cast(&r); if (database_oid_ != node.database_oid_) return false; - return proc_oid_ == node.proc_oid_; + if (proc_oid_ != node.proc_oid_) return false; + return if_exists_ == node.if_exists_; } //===--------------------------------------------------------------------===// diff --git a/src/optimizer/plan_generator.cpp b/src/optimizer/plan_generator.cpp index f41d7582a6..c987856877 100644 --- a/src/optimizer/plan_generator.cpp +++ b/src/optimizer/plan_generator.cpp @@ -1141,7 +1141,8 @@ void PlanGenerator::Visit(const DropFunction *drop_function) { output_plan_ = planner::DropFunctionPlanNode::Builder() .SetPlanNodeId(GetNextPlanNodeID()) .SetDatabaseOid(drop_function->GetDatabaseOid()) - .SetProcedureOid(drop_function->GetFunctionOID()) + .SetProcedureOid(drop_function->GetFunctionOid()) + .SetIfExists(drop_function->GetIfExists()) .Build(); } diff --git a/src/optimizer/query_to_operator_transformer.cpp b/src/optimizer/query_to_operator_transformer.cpp index 9253ba6fbb..8fa213d6e4 100644 --- a/src/optimizer/query_to_operator_transformer.cpp +++ b/src/optimizer/query_to_operator_transformer.cpp @@ -638,7 +638,8 @@ void QueryToOperatorTransformer::Visit(common::ManagedPointer( - LogicalDropFunction::Make(db_oid_, accessor_->GetProcOid(op->GetFunctionName(), op->GetFunctionArguments())) + LogicalDropFunction::Make(db_oid_, accessor_->GetProcOid(op->GetFunctionName(), op->GetFunctionArguments()), + op->IsIfExists()) .RegisterWithTxnContext(txn_context), std::vector>{}, txn_context); case parser::DropStatement::DropType::kTrigger: diff --git a/src/optimizer/rules/implementation_rules.cpp b/src/optimizer/rules/implementation_rules.cpp index 29c8419a93..9e21cc762b 100644 --- a/src/optimizer/rules/implementation_rules.cpp +++ b/src/optimizer/rules/implementation_rules.cpp @@ -1202,10 +1202,10 @@ void LogicalDropFunctionToPhysicalDropFunction::Transform( auto df_op = input->Contents()->GetContentsAs(); NOISEPAGE_ASSERT(input->GetChildren().empty(), "LogicalDropFunction should have 0 children"); - auto op = std::make_unique(DropFunction::Make(df_op->GetDatabaseOid(), df_op->GetFunctionOid()) - .RegisterWithTxnContext(context->GetOptimizerContext()->GetTxn()), - std::vector>(), - context->GetOptimizerContext()->GetTxn()); + auto op = std::make_unique( + DropFunction::Make(df_op->GetDatabaseOid(), df_op->GetFunctionOid(), df_op->GetIfExists()) + .RegisterWithTxnContext(context->GetOptimizerContext()->GetTxn()), + std::vector>(), context->GetOptimizerContext()->GetTxn()); transformed->emplace_back(std::move(op)); } diff --git a/src/parser/postgresparser.cpp b/src/parser/postgresparser.cpp index 92a46bdf8b..3661a17fdd 100644 --- a/src/parser/postgresparser.cpp +++ b/src/parser/postgresparser.cpp @@ -1811,9 +1811,9 @@ std::unique_ptr PostgresParser::DropFunctionTransform(ParseResult } } } - + const auto if_exists = root->missing_ok_; return std::make_unique(std::make_unique("", "", ""), std::move(function_name), - std::move(function_args)); + std::move(function_args), if_exists); } // Postgres.DropStmt -> noisepage.DropStatement diff --git a/src/planner/plannodes/drop_function_plan_node.cpp b/src/planner/plannodes/drop_function_plan_node.cpp index a9c34f18df..7cc4ce9a3c 100644 --- a/src/planner/plannodes/drop_function_plan_node.cpp +++ b/src/planner/plannodes/drop_function_plan_node.cpp @@ -10,16 +10,17 @@ namespace noisepage::planner { std::unique_ptr DropFunctionPlanNode::Builder::Build() { - return std::unique_ptr(new DropFunctionPlanNode(std::move(children_), std::move(output_schema_), - database_oid_, proc_oid_, plan_node_id_)); + return std::unique_ptr(new DropFunctionPlanNode( + std::move(children_), std::move(output_schema_), database_oid_, proc_oid_, if_exists_, plan_node_id_)); } DropFunctionPlanNode::DropFunctionPlanNode(std::vector> &&children, std::unique_ptr output_schema, catalog::db_oid_t database_oid, - catalog::proc_oid_t proc_oid, plan_node_id_t plan_node_id) + catalog::proc_oid_t proc_oid, bool if_exists, plan_node_id_t plan_node_id) : AbstractPlanNode(std::move(children), std::move(output_schema), plan_node_id), database_oid_(database_oid), - proc_oid_(proc_oid) {} + proc_oid_(proc_oid), + if_exists_(if_exists) {} common::hash_t DropFunctionPlanNode::Hash() const { common::hash_t hash = AbstractPlanNode::Hash(); @@ -27,6 +28,8 @@ common::hash_t DropFunctionPlanNode::Hash() const { hash = common::HashUtil::CombineHashes(hash, common::HashUtil::Hash(database_oid_)); // Hash procedure oid hash = common::HashUtil::CombineHashes(hash, common::HashUtil::Hash(proc_oid_)); + // Hash `IF EXISTS` + hash = common::HashUtil::CombineHashes(hash, common::HashUtil::Hash(if_exists_)); return hash; } @@ -38,9 +41,12 @@ bool DropFunctionPlanNode::operator==(const AbstractPlanNode &rhs) const { // Database OID if (database_oid_ != other.database_oid_) return false; - // Namespace OID + // Procedure OID if (proc_oid_ != other.proc_oid_) return false; + // IF EXISTS + if (if_exists_ != other.if_exists_) return false; + return true; } @@ -48,6 +54,7 @@ nlohmann::json DropFunctionPlanNode::ToJson() const { nlohmann::json j = AbstractPlanNode::ToJson(); j["database_oid"] = database_oid_; j["proc_oid"] = proc_oid_; + j["if_exists"] = if_exists_; return j; } @@ -57,6 +64,7 @@ std::vector> DropFunctionPlanNode::F exprs.insert(exprs.end(), std::make_move_iterator(e1.begin()), std::make_move_iterator(e1.end())); database_oid_ = j.at("database_oid").get(); proc_oid_ = j.at("proc_oid").get(); + if_exists_ = j.at("if_exists").get(); return exprs; } From 9b982a568b655fb3f009331a741a13bafd2d3370 Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Wed, 11 Aug 2021 12:57:44 -0400 Subject: [PATCH 107/139] clean up some of the incomplete type enumerations --- src/catalog/catalog_accessor.cpp | 20 ++++- src/include/parser/udf/plpgsql_parser.h | 8 ++ src/parser/postgresparser.cpp | 5 ++ src/parser/udf/plpgsql_parser.cpp | 97 ++++++++++++++----------- 4 files changed, 86 insertions(+), 44 deletions(-) diff --git a/src/catalog/catalog_accessor.cpp b/src/catalog/catalog_accessor.cpp index 6c36a01dfb..855197f2af 100644 --- a/src/catalog/catalog_accessor.cpp +++ b/src/catalog/catalog_accessor.cpp @@ -256,15 +256,31 @@ void CatalogAccessor::RegisterTempTable(table_oid_t table_oid, const common::Man } type_oid_t CatalogAccessor::TypeNameToType(const std::string &type_name) { - // TODO(Kyle): Complete this function type_oid_t type; - if (type_name == "int4") { + if (type_name == "int2") { + type = GetTypeOidFromTypeId(execution::sql::SqlTypeId::SmallInt); + } else if (type_name == "int4") { type = GetTypeOidFromTypeId(execution::sql::SqlTypeId::Integer); + } else if (type_name == "int8") { + type = GetTypeOidFromTypeId(execution::sql::SqlTypeId::BigInt); } else if (type_name == "bool") { type = GetTypeOidFromTypeId(execution::sql::SqlTypeId::Boolean); + } else if (type_name == "float4") { + type = GetTypeOidFromTypeId(execution::sql::SqlTypeId::Real); + } else if (type_name == "float8") { + type = GetTypeOidFromTypeId(execution::sql::SqlTypeId::Double); + } else if (type_name == "numeric") { + type = GetTypeOidFromTypeId(execution::sql::SqlTypeId::Decimal); + } else if (type_name == "bpchar") { + type = GetTypeOidFromTypeId(execution::sql::SqlTypeId::Char); + } else if (type_name == "varchar" || type_name == "text") { + type = GetTypeOidFromTypeId(execution::sql::SqlTypeId::Varchar); + } else if (type_name == "varbinary") { + type = GetTypeOidFromTypeId(execution::sql::SqlTypeId::Varbinary); } else { type = GetTypeOidFromTypeId(execution::sql::SqlTypeId::Invalid); } + return type; } diff --git a/src/include/parser/udf/plpgsql_parser.h b/src/include/parser/udf/plpgsql_parser.h index b4f2f6803f..c95ec9beb2 100644 --- a/src/include/parser/udf/plpgsql_parser.h +++ b/src/include/parser/udf/plpgsql_parser.h @@ -238,6 +238,14 @@ class PLpgSQLParser { */ static bool HasEnclosingQuery(ParseResult *parse_result); + /** + * Get the internal type identifier for given type name. + * @param type_name The typename + * @return The type identifier for the type, or empty std::optional + * in the case of an unsupported or unrecognized type + */ + static std::optional TypeNameToType(const std::string &type_name); + private: /** The UDF AST context */ common::ManagedPointer udf_ast_context_; diff --git a/src/parser/postgresparser.cpp b/src/parser/postgresparser.cpp index 6927ae40fb..8c1dd194b9 100644 --- a/src/parser/postgresparser.cpp +++ b/src/parser/postgresparser.cpp @@ -1824,6 +1824,11 @@ std::unique_ptr PostgresParser::DropFunctionTransform(ParseResult } } } + + for (const auto &t : function_args) { + std::cout << t << std::endl; + } + const auto if_exists = root->missing_ok_; return std::make_unique(std::make_unique("", "", ""), std::move(function_name), std::move(function_args), if_exists); diff --git a/src/parser/udf/plpgsql_parser.cpp b/src/parser/udf/plpgsql_parser.cpp index f7b4381d3c..b45d248040 100644 --- a/src/parser/udf/plpgsql_parser.cpp +++ b/src/parser/udf/plpgsql_parser.cpp @@ -50,12 +50,27 @@ static constexpr const char K_UPPER[] = "upper"; static constexpr const char K_STEP[] = "step"; static constexpr const char K_VAR[] = "var"; -/** Variable declaration type identifiers */ +/** Integral types */ +static constexpr const char DECL_TYPE_ID_SMALLINT[] = "smallint"; static constexpr const char DECL_TYPE_ID_INT[] = "int"; static constexpr const char DECL_TYPE_ID_INTEGER[] = "integer"; +static constexpr const char DECL_TYPE_ID_BIGINT[] = "bigint"; + +/** Variable-precision floating point */ +static constexpr const char DECL_TYPE_ID_REAL[] = "real"; +static constexpr const char DECL_TYPE_ID_FLOAT[] = "float"; static constexpr const char DECL_TYPE_ID_DOUBLE[] = "double"; + +/** Arbitrary-precision floating point */ static constexpr const char DECL_TYPE_ID_NUMERIC[] = "numeric"; +static constexpr const char DECL_TYPE_ID_DECIMAL[] = "decimal"; + +/** Character types */ +static constexpr const char DECL_TYPE_ID_CHAR[] = "char"; static constexpr const char DECL_TYPE_ID_VARCHAR[] = "varchar"; +static constexpr const char DECL_TYPE_ID_TEXT[] = "text"; + +/** Other */ static constexpr const char DECL_TYPE_ID_DATE[] = "date"; static constexpr const char DECL_TYPE_ID_RECORD[] = "record"; @@ -172,10 +187,6 @@ std::unique_ptr PLpgSQLParser::ParseDecl(const nlo // Track the local variable (for assignment) udf_ast_context_->AddLocal(var_name); - // Grab the type identifier from the PL/pgSQL parser - const std::string type = StringUtils::Strip( - StringUtils::Lower(json[K_PLPGSQL_VAR][K_DATATYPE][K_PLPGSQL_TYPE][K_TYPENAME].get())); - // Parse the initializer, if present std::unique_ptr initial{nullptr}; if (json[K_PLPGSQL_VAR].find(K_DEFAULT_VAL) != json[K_PLPGSQL_VAR].end()) { @@ -192,45 +203,17 @@ std::unique_ptr PLpgSQLParser::ParseDecl(const nlo // Otherwise, we perform a string comparison with the type identifier // for the variable to determine the type for the declaration - if ((type == DECL_TYPE_ID_INT) || (type == DECL_TYPE_ID_INTEGER)) { - udf_ast_context_->SetVariableType(var_name, execution::sql::SqlTypeId::Integer); - return std::make_unique(var_name, execution::sql::SqlTypeId::Integer, - std::move(initial)); - } - if ((type == DECL_TYPE_ID_DOUBLE) || (type == DECL_TYPE_ID_NUMERIC)) { - // TODO(Kyle): type.rfind("numeric") - // TODO(Kyle): Should this support FLOAT and DECMIAL as well?? - udf_ast_context_->SetVariableType(var_name, execution::sql::SqlTypeId::Decimal); - return std::make_unique(var_name, execution::sql::SqlTypeId::Decimal, - std::move(initial)); - } - if (type == DECL_TYPE_ID_VARCHAR) { - udf_ast_context_->SetVariableType(var_name, execution::sql::SqlTypeId::Varchar); - return std::make_unique(var_name, execution::sql::SqlTypeId::Varchar, - std::move(initial)); - } - if (type == DECL_TYPE_ID_DATE) { - udf_ast_context_->SetVariableType(var_name, execution::sql::SqlTypeId::Date); - return std::make_unique(var_name, execution::sql::SqlTypeId::Date, - std::move(initial)); - } - if (type == DECL_TYPE_ID_RECORD) { - // TODO(Kyle): I don't like modeling RECORD types with the Invalid - // SqlTypeId, need to find a better way to integrate the type system - udf_ast_context_->SetVariableType(var_name, execution::sql::SqlTypeId::Invalid); - return std::make_unique(var_name, execution::sql::SqlTypeId::Invalid, - std::move(initial)); + // Grab the type identifier from the PL/pgSQL parser + const std::string type_name = StringUtils::Strip( + StringUtils::Lower(json[K_PLPGSQL_VAR][K_DATATYPE][K_PLPGSQL_TYPE][K_TYPENAME].get())); + auto type = TypeNameToType(type_name); + if (!type.has_value()) { + throw PARSER_EXCEPTION( + fmt::format("PL/pgSQL Parser : unsupported type '{}' for variable '{}'", type_name, var_name)); } - throw PARSER_EXCEPTION(fmt::format("PL/pgSQL Parser : unsupported type '{}' for variable '{}'", type, var_name)); - } - - // TODO(Kyle): Support row types later - if (declaration_type == K_PLPGSQL_ROW) { - const auto var_name = json[K_PLPGSQL_ROW][K_REFNAME].get(); - NOISEPAGE_ASSERT(var_name == "*internal*", "Unexpected refname"); - udf_ast_context_->SetVariableType(var_name, execution::sql::SqlTypeId::Invalid); - return std::make_unique(var_name, execution::sql::SqlTypeId::Invalid, nullptr); + udf_ast_context_->SetVariableType(var_name, type.value()); + return std::make_unique(var_name, type.value(), std::move(initial)); } // TODO(Kyle): Need to handle other types like row, table etc; @@ -586,4 +569,34 @@ bool PLpgSQLParser::HasEnclosingQuery(ParseResult *parse_result) { return (target->GetExpressionType() == parser::ExpressionType::ROW_SUBQUERY); } +std::optional PLpgSQLParser::TypeNameToType(const std::string &type_name) { + // TODO(Kyle): This is awkward control flow because we + // model RECORD types with the SqlTypeId::Invalid type + execution::sql::SqlTypeId type; + if (type_name == DECL_TYPE_ID_SMALLINT) { + type = execution::sql::SqlTypeId::SmallInt; + } else if (type_name == DECL_TYPE_ID_INT || type_name == DECL_TYPE_ID_INTEGER) { + type = execution::sql::SqlTypeId::Integer; + } else if (type_name == DECL_TYPE_ID_BIGINT) { + type = execution::sql::SqlTypeId::BigInt; + } else if (type_name == DECL_TYPE_ID_REAL || type_name == DECL_TYPE_ID_FLOAT) { + type = execution::sql::SqlTypeId::Real; + } else if (type_name == DECL_TYPE_ID_DOUBLE) { + type = execution::sql::SqlTypeId::Double; + } else if (type_name == DECL_TYPE_ID_NUMERIC || type_name == DECL_TYPE_ID_DECIMAL) { + type = execution::sql::SqlTypeId::Decimal; + } else if (type_name == DECL_TYPE_ID_CHAR) { + type = execution::sql::SqlTypeId::Char; + } else if (type_name == DECL_TYPE_ID_VARCHAR || type_name == DECL_TYPE_ID_TEXT) { + type = execution::sql::SqlTypeId::Varchar; + } else if (type_name == DECL_TYPE_ID_DATE) { + type = execution::sql::SqlTypeId::Date; + } else if (type_name == DECL_TYPE_ID_RECORD) { + type = execution::sql::SqlTypeId::Invalid; + } else { + return std::nullopt; + } + return std::make_optional(type); +} + } // namespace noisepage::parser::udf From a592f37e3add07c335135fcca013f897f6b1ca2e Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Wed, 11 Aug 2021 17:18:34 -0400 Subject: [PATCH 108/139] add DROP FUNCTION to integration tests --- script/testing/junit/sql/udf.sql | 43 +++++++++- script/testing/junit/traces/udf.test | 122 ++++++++++++++++++++++++++- src/parser/postgresparser.cpp | 4 - src/parser/udf/plpgsql_parser.cpp | 17 +++- 4 files changed, 176 insertions(+), 10 deletions(-) diff --git a/script/testing/junit/sql/udf.sql b/script/testing/junit/sql/udf.sql index 464edaf2a1..ca73e22e1b 100644 --- a/script/testing/junit/sql/udf.sql +++ b/script/testing/junit/sql/udf.sql @@ -23,6 +23,8 @@ $$ LANGUAGE PLPGSQL; SELECT return_constant(); +DROP FUNCTION return_constant(); + -- ---------------------------------------------------------------------------- -- return_input() @@ -34,6 +36,8 @@ $$ LANGUAGE PLPGSQL; SELECT x, return_input(x) FROM integers; +DROP FUNCTION return_input(INT); + -- ---------------------------------------------------------------------------- -- return_sum() @@ -45,6 +49,8 @@ $$ LANGUAGE PLPGSQL; SELECT x, y, return_sum(x, y) FROM integers; +DROP FUNCTION return_sum(INT, INT); + -- ---------------------------------------------------------------------------- -- return_prod() @@ -56,6 +62,8 @@ $$ LANGUAGE PLPGSQL; SELECT x, y, return_product(x, y) FROM integers; +DROP FUNCTION return_product(INT, INT); + -- ---------------------------------------------------------------------------- -- integer_decl() @@ -69,6 +77,8 @@ $$ LANGUAGE PLPGSQL; SELECT integer_decl(); +DROP FUNCTION integer_decl(); + -- ---------------------------------------------------------------------------- -- conditional() -- @@ -88,6 +98,8 @@ $$ LANGUAGE PLPGSQL; SELECT x, conditional(x) FROM integers; +DROP FUNCTION conditional(INT); + -- ---------------------------------------------------------------------------- -- proc_while() @@ -104,6 +116,8 @@ $$ LANGUAGE PLPGSQL; SELECT proc_while(); +DROP FUNCTION proc_while(); + -- ---------------------------------------------------------------------------- -- proc_fori() -- @@ -136,10 +150,12 @@ $$ LANGUAGE PLPGSQL; SELECT sql_select_single_constant(); +DROP FUNCTION sql_select_single_constant(); + -- ---------------------------------------------------------------------------- -- sql_select_mutliple_constants() -CREATE FUNCTION sql_select_mutliple_constants() RETURNS INT AS $$ \ +CREATE FUNCTION sql_select_multiple_constants() RETURNS INT AS $$ \ DECLARE \ x INT; \ y INT; \ @@ -149,7 +165,9 @@ BEGIN \ END \ $$ LANGUAGE PLPGSQL; -SELECT sql_select_mutliple_constants(); +SELECT sql_select_multiple_constants(); + +DROP FUNCTION sql_select_multiple_constants(); -- ---------------------------------------------------------------------------- -- sql_select_constant_assignment() @@ -167,6 +185,8 @@ $$ LANGUAGE PLPGSQL; SELECT sql_select_constant_assignment(); +DROP FUNCTION sql_select_constant_assignment(); + -- ---------------------------------------------------------------------------- -- sql_embedded_agg_count() @@ -181,6 +201,8 @@ $$ LANGUAGE PLPGSQL; SELECT sql_embedded_agg_count(); +DROP FUNCTION sql_embedded_agg_count(); + -- ---------------------------------------------------------------------------- -- sql_embedded_agg_min() @@ -195,6 +217,8 @@ $$ LANGUAGE PLPGSQL; SELECT sql_embedded_agg_min(); +DROP FUNCTION sql_embedded_agg_min(); + -- ---------------------------------------------------------------------------- -- sql_embedded_agg_max() @@ -209,6 +233,8 @@ $$ LANGUAGE PLPGSQL; SELECT sql_embedded_agg_max(); +DROP FUNCTION sql_embedded_agg_max(); + -- ---------------------------------------------------------------------------- -- sql_embedded_agg_multi() @@ -223,6 +249,8 @@ BEGIN \ END; \ $$ LANGUAGE PLPGSQL; +DROP FUNCTION sql_embedded_agg_multi(); + -- ---------------------------------------------------------------------------- -- proc_fors_constant_var() @@ -241,6 +269,8 @@ $$ LANGUAGE PLPGSQL; SELECT proc_fors_constant_var(); +DROP FUNCTION proc_fors_constant_var(); + -- ---------------------------------------------------------------------------- -- proc_fors_constant_vars() @@ -260,6 +290,8 @@ $$ LANGUAGE PLPGSQL; SELECT proc_fors_constant_vars(); +DROP FUNCTION proc_fors_constant_vars(); + -- ---------------------------------------------------------------------------- -- proc_fors_rec() -- @@ -298,6 +330,8 @@ $$ LANGUAGE PLPGSQL; SELECT proc_fors_var(); +DROP FUNCTION proc_fors_var(); + -- ---------------------------------------------------------------------------- -- proc_call_*() @@ -339,3 +373,8 @@ END \ $$ LANGUAGE PLPGSQL; SELECT proc_call_select(); + +DROP FUNCTION proc_call_callee(); +DROP FUNCTION proc_call_ret(); +DROP FUNCTION proc_call_assign(); +DROP FUNCTION proc_call_select(); diff --git a/script/testing/junit/traces/udf.test b/script/testing/junit/traces/udf.test index 2d99216478..69cc723530 100644 --- a/script/testing/junit/traces/udf.test +++ b/script/testing/junit/traces/udf.test @@ -64,6 +64,12 @@ SELECT return_constant(); statement ok +statement ok +DROP FUNCTION return_constant(); + +statement ok + + statement ok -- ---------------------------------------------------------------------------- @@ -93,6 +99,12 @@ SELECT x, return_input(x) FROM integers; statement ok +statement ok +DROP FUNCTION return_input(INT); + +statement ok + + statement ok -- ---------------------------------------------------------------------------- @@ -125,6 +137,12 @@ SELECT x, y, return_sum(x, y) FROM integers; statement ok +statement ok +DROP FUNCTION return_sum(INT, INT); + +statement ok + + statement ok -- ---------------------------------------------------------------------------- @@ -157,6 +175,12 @@ SELECT x, y, return_product(x, y) FROM integers; statement ok +statement ok +DROP FUNCTION return_product(INT, INT); + +statement ok + + statement ok -- ---------------------------------------------------------------------------- @@ -181,6 +205,12 @@ SELECT integer_decl(); statement ok +statement ok +DROP FUNCTION integer_decl(); + +statement ok + + statement ok -- ---------------------------------------------------------------------------- @@ -219,6 +249,12 @@ SELECT x, conditional(x) FROM integers; statement ok +statement ok +DROP FUNCTION conditional(INT); + +statement ok + + statement ok -- ---------------------------------------------------------------------------- @@ -243,6 +279,12 @@ SELECT proc_while(); statement ok +statement ok +DROP FUNCTION proc_while(); + +statement ok + + statement ok -- ---------------------------------------------------------------------------- @@ -294,6 +336,12 @@ SELECT sql_select_single_constant(); statement ok +statement ok +DROP FUNCTION sql_select_single_constant(); + +statement ok + + statement ok -- ---------------------------------------------------------------------------- @@ -304,13 +352,13 @@ statement ok statement ok -CREATE FUNCTION sql_select_mutliple_constants() RETURNS INT AS $$ DECLARE x INT; y INT; BEGIN SELECT 1, 2 INTO x, y; RETURN x + y; END $$ LANGUAGE PLPGSQL; +CREATE FUNCTION sql_select_multiple_constants() RETURNS INT AS $$ DECLARE x INT; y INT; BEGIN SELECT 1, 2 INTO x, y; RETURN x + y; END $$ LANGUAGE PLPGSQL; statement ok query I rowsort -SELECT sql_select_mutliple_constants(); +SELECT sql_select_multiple_constants(); ---- 3 @@ -318,6 +366,12 @@ SELECT sql_select_mutliple_constants(); statement ok +statement ok +DROP FUNCTION sql_select_multiple_constants(); + +statement ok + + statement ok -- ---------------------------------------------------------------------------- @@ -342,6 +396,12 @@ SELECT sql_select_constant_assignment(); statement ok +statement ok +DROP FUNCTION sql_select_constant_assignment(); + +statement ok + + statement ok -- ---------------------------------------------------------------------------- @@ -366,6 +426,12 @@ SELECT sql_embedded_agg_count(); statement ok +statement ok +DROP FUNCTION sql_embedded_agg_count(); + +statement ok + + statement ok -- ---------------------------------------------------------------------------- @@ -390,6 +456,12 @@ SELECT sql_embedded_agg_min(); statement ok +statement ok +DROP FUNCTION sql_embedded_agg_min(); + +statement ok + + statement ok -- ---------------------------------------------------------------------------- @@ -414,6 +486,12 @@ SELECT sql_embedded_agg_max(); statement ok +statement ok +DROP FUNCTION sql_embedded_agg_max(); + +statement ok + + statement ok -- ---------------------------------------------------------------------------- @@ -429,6 +507,12 @@ CREATE FUNCTION sql_embedded_agg_multi() RETURNS INT AS $$ DECLARE minimum INT; statement ok +statement ok +DROP FUNCTION sql_embedded_agg_multi(); + +statement ok + + statement ok -- ---------------------------------------------------------------------------- @@ -456,6 +540,12 @@ SELECT proc_fors_constant_var(); statement ok +statement ok +DROP FUNCTION proc_fors_constant_var(); + +statement ok + + statement ok -- ---------------------------------------------------------------------------- @@ -483,6 +573,12 @@ SELECT proc_fors_constant_vars(); statement ok +statement ok +DROP FUNCTION proc_fors_constant_vars(); + +statement ok + + statement ok -- ---------------------------------------------------------------------------- @@ -543,6 +639,12 @@ SELECT proc_fors_var(); statement ok +statement ok +DROP FUNCTION proc_fors_var(); + +statement ok + + statement ok -- ---------------------------------------------------------------------------- @@ -609,3 +711,19 @@ SELECT proc_call_select(); 1 +statement ok + + +statement ok +DROP FUNCTION proc_call_callee(); + +statement ok +DROP FUNCTION proc_call_ret(); + +statement ok +DROP FUNCTION proc_call_assign(); + +statement ok +DROP FUNCTION proc_call_select(); + +statement ok diff --git a/src/parser/postgresparser.cpp b/src/parser/postgresparser.cpp index 8c1dd194b9..3f813f5a6f 100644 --- a/src/parser/postgresparser.cpp +++ b/src/parser/postgresparser.cpp @@ -1825,10 +1825,6 @@ std::unique_ptr PostgresParser::DropFunctionTransform(ParseResult } } - for (const auto &t : function_args) { - std::cout << t << std::endl; - } - const auto if_exists = root->missing_ok_; return std::make_unique(std::make_unique("", "", ""), std::move(function_name), std::move(function_args), if_exists); diff --git a/src/parser/udf/plpgsql_parser.cpp b/src/parser/udf/plpgsql_parser.cpp index b45d248040..c476176cf4 100644 --- a/src/parser/udf/plpgsql_parser.cpp +++ b/src/parser/udf/plpgsql_parser.cpp @@ -109,11 +109,16 @@ std::unique_ptr PLpgSQLParser::ParseFunction(const const auto function_body = json[K_ACTION][K_PLPGSQL_STMT_BLOCK][K_BODY]; std::vector> statements{}; - // Skip the first declaration in the datums list + // Skip the first declaration in the datums list; parse all declarations std::transform(declarations.cbegin() + 1, declarations.cend(), std::back_inserter(statements), [this](const nlohmann::json &declaration) -> std::unique_ptr { return ParseDecl(declaration); }); + // Remove the invalid declarations + statements.erase( + std::remove_if(statements.begin(), statements.end(), + [](std::unique_ptr &stmt) { return !static_cast(stmt); }), + statements.end()); statements.push_back(ParseBlock(function_body)); return std::make_unique(std::move(statements)); } @@ -216,7 +221,15 @@ std::unique_ptr PLpgSQLParser::ParseDecl(const nlo return std::make_unique(var_name, type.value(), std::move(initial)); } - // TODO(Kyle): Need to handle other types like row, table etc; + if (declaration_type == K_PLPGSQL_ROW && json[K_PLPGSQL_ROW][K_REFNAME].get() == "*internal*") { + // For query-variant for-loop structures (For-S in PL/pgSQL parlance) + // the Postgres parser generates a dummy internal declaration for the + // variable that is a target of the `SELECT INTO`, we can elide this + return std::unique_ptr{}; + } + + // TODO(Kyle): Handle RECORD declarations + // TODO(Kyle): Handle table row declarations throw PARSER_EXCEPTION(fmt::format("PL/pgSQL Parser : declaration type '{}' not supported", declaration_type)); } From 3bfb561c1d8e6208440d860d266bb0e68f59a4e6 Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Wed, 11 Aug 2021 17:39:31 -0400 Subject: [PATCH 109/139] fix small bug in catalog test from refactor of DropProcedure --- src/catalog/postgres/pg_proc_impl.cpp | 6 ++-- test/catalog/catalog_test.cpp | 51 +++++++++++++++------------ 2 files changed, 31 insertions(+), 26 deletions(-) diff --git a/src/catalog/postgres/pg_proc_impl.cpp b/src/catalog/postgres/pg_proc_impl.cpp index be8f121cf3..ef1957541a 100644 --- a/src/catalog/postgres/pg_proc_impl.cpp +++ b/src/catalog/postgres/pg_proc_impl.cpp @@ -228,9 +228,8 @@ bool PgProcImpl::DropProcedure(const common::ManagedPointerGet(proc_pm[PgProc::PRONAME.oid_], nullptr); auto proc_ns = *table_pr->Get(proc_pm[PgProc::PRONAMESPACE.oid_], nullptr); + // Grab a pointer to the procedure context (if present) auto *ptr_ptr = reinterpret_cast(table_pr->AccessWithNullCheck(proc_pm[PgProc::PRO_CTX_PTR.oid_])); - NOISEPAGE_ASSERT(ptr_ptr != nullptr, "DropProcedure called on an invalid OID or before SetFunctionContext."); - auto *ctx_ptr = *reinterpret_cast(ptr_ptr); // Delete from pg_proc_name_index. { @@ -242,7 +241,8 @@ bool PgProcImpl::DropProcedure(const common::ManagedPointer(ptr_ptr); txn->RegisterCommitAction([=](transaction::DeferredActionManager *deferred_action_manager) { deferred_action_manager->RegisterDeferredAction( [=]() { deferred_action_manager->RegisterDeferredAction([=]() { delete ctx_ptr; }); }); diff --git a/test/catalog/catalog_test.cpp b/test/catalog/catalog_test.cpp index 75f84280ac..3674bae72b 100644 --- a/test/catalog/catalog_test.cpp +++ b/test/catalog/catalog_test.cpp @@ -122,56 +122,61 @@ TEST_F(CatalogTests, ProcTest) { // Check visibility to me VerifyCatalogTables(*accessor); - auto lan_oid = accessor->CreateLanguage("test_language"); - auto ns_oid = accessor->GetDefaultNamespace(); - - EXPECT_NE(lan_oid, catalog::INVALID_LANGUAGE_OID); + const auto language_oid = accessor->CreateLanguage("test_language"); + const auto namespace_oid = accessor->GetDefaultNamespace(); + EXPECT_NE(language_oid, catalog::INVALID_LANGUAGE_OID); + EXPECT_NE(namespace_oid, catalog::INVALID_NAMESPACE_OID); txn_manager_->Commit(txn, transaction::TransactionUtil::EmptyCallback, nullptr); + /** User-defined procedure */ + + // Create the procedure txn = txn_manager_->BeginTransaction(); accessor = catalog_->GetAccessor(common::ManagedPointer(txn), db_, DISABLED); - // create a sample proc - auto procname = "sample"; - std::vector args = {"arg1", "arg2", "arg3"}; - std::vector arg_types = {accessor->GetTypeOidFromTypeId(execution::sql::SqlTypeId::Integer), - accessor->GetTypeOidFromTypeId(execution::sql::SqlTypeId::Boolean), - accessor->GetTypeOidFromTypeId(execution::sql::SqlTypeId::SmallInt)}; - - auto src = "int sample(arg1, arg2, arg3){return 2;}"; + const std::string procname{"sample"}; + const std::vector args{"arg1", "arg2", "arg3"}; + const std::vector arg_types{accessor->GetTypeOidFromTypeId(execution::sql::SqlTypeId::Integer), + accessor->GetTypeOidFromTypeId(execution::sql::SqlTypeId::Boolean), + accessor->GetTypeOidFromTypeId(execution::sql::SqlTypeId::SmallInt)}; + const std::string src{"int sample(arg1, arg2, arg3){return 2;}"}; auto proc_oid = accessor->CreateProcedure( - procname, lan_oid, ns_oid, catalog::INVALID_TYPE_OID, args, arg_types, {}, {}, + procname, language_oid, namespace_oid, catalog::INVALID_TYPE_OID, args, arg_types, {}, {}, catalog::type_oid_t(static_cast(execution::sql::SqlTypeId::Integer)), src, false); EXPECT_NE(proc_oid, catalog::INVALID_PROC_OID); txn_manager_->Commit(txn, transaction::TransactionUtil::EmptyCallback, nullptr); + // Query the catalog for the procedure txn = txn_manager_->BeginTransaction(); accessor = catalog_->GetAccessor(common::ManagedPointer(txn), db_, DISABLED); - // make sure we didn't find this proc that we never added - auto found_oid = accessor->GetProcOid("bad_proc", arg_types); - EXPECT_EQ(found_oid, catalog::INVALID_PROC_OID); + // Make sure we didn't find this proc that we never added + EXPECT_EQ(accessor->GetProcOid("bad_proc", arg_types), catalog::INVALID_PROC_OID); + + // Look for proc that we actually added + const auto found_oid = accessor->GetProcOid(procname, arg_types); + EXPECT_EQ(found_oid, proc_oid); + EXPECT_TRUE(accessor->DropProcedure(found_oid)); - // look for proc that we actually added - found_oid = accessor->GetProcOid(procname, arg_types); + /** Builting procedure */ - auto sin_oid = accessor->GetProcOid("sin", {accessor->GetTypeOidFromTypeId(execution::sql::SqlTypeId::Double)}); + // The procedure should already exist + const auto sin_oid = accessor->GetProcOid("sin", {accessor->GetTypeOidFromTypeId(execution::sql::SqlTypeId::Double)}); EXPECT_NE(sin_oid, catalog::INVALID_PROC_OID); + // The function context should already exist auto sin_context = accessor->GetFunctionContext(sin_oid); EXPECT_TRUE(sin_context->IsBuiltin()); EXPECT_EQ(sin_context->GetBuiltin(), execution::ast::Builtin::Sin); EXPECT_EQ(sin_context->GetFunctionReturnType(), execution::sql::SqlTypeId::Double); - auto sin_args = sin_context->GetFunctionArgsType(); + + auto sin_args = sin_context->GetFunctionArgTypes(); EXPECT_EQ(sin_args.size(), 1); EXPECT_EQ(sin_args.back(), execution::sql::SqlTypeId::Double); EXPECT_EQ(sin_context->GetFunctionName(), "sin"); - EXPECT_EQ(found_oid, proc_oid); - auto result = accessor->DropProcedure(found_oid); - EXPECT_TRUE(result); txn_manager_->Commit(txn, transaction::TransactionUtil::EmptyCallback, nullptr); } From 8aa7ad3a5cb98f974b9d76d04ee9085ffb8a8d80 Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Wed, 11 Aug 2021 19:01:40 -0400 Subject: [PATCH 110/139] make doxygen happy --- src/include/execution/ast/udf/udf_ast_nodes.h | 5 +- .../execution/compiler/compilation_context.h | 1 - .../execution/compiler/function_builder.h | 2 +- src/include/parser/udf/plpgsql_parser.h | 1 - .../plannodes/drop_function_plan_node.h | 4 + src/parser/postgresparser.cpp | 95 ++++++++----------- 6 files changed, 51 insertions(+), 57 deletions(-) diff --git a/src/include/execution/ast/udf/udf_ast_nodes.h b/src/include/execution/ast/udf/udf_ast_nodes.h index 9fe25067a2..422e3f1d1e 100644 --- a/src/include/execution/ast/udf/udf_ast_nodes.h +++ b/src/include/execution/ast/udf/udf_ast_nodes.h @@ -410,7 +410,10 @@ class ForIStmtAST : public StmtAST { /** * Construct a new ForIStmtAST instance. - * @param variables The collection of variables in the loop + * @param variable The loop induction variable + * @param lower The loop lower bound + * @param upper The loop upper bound + * @param step The loop step * @param body The body of the loop */ ForIStmtAST(std::string variable, std::unique_ptr lower, std::unique_ptr upper, diff --git a/src/include/execution/compiler/compilation_context.h b/src/include/execution/compiler/compilation_context.h index 770db3c8b1..1ab178089d 100644 --- a/src/include/execution/compiler/compilation_context.h +++ b/src/include/execution/compiler/compilation_context.h @@ -55,7 +55,6 @@ class CompilationContext { * @param mode The compilation mode. * @param override_qid Optional indicating how to override the plan's query id * @param plan_meta_data Query plan meta data (stores cardinality information) - * @param query_text The SQL query string (temporary) * @param output_callback The lambda utilized as the output callback for the query * @param context The AST context for the query */ diff --git a/src/include/execution/compiler/function_builder.h b/src/include/execution/compiler/function_builder.h index 50269ffce6..3353977b15 100644 --- a/src/include/execution/compiler/function_builder.h +++ b/src/include/execution/compiler/function_builder.h @@ -39,7 +39,7 @@ class FunctionBuilder { * Construct a new FunctionBuilder instance for a closure. * @param codegen The code generation instance * @param params The function parameters - * @param closures The function closures + * @param captures The function captures * @param return_type The return type representation of the function */ FunctionBuilder(CodeGen *codegen, util::RegionVector &¶ms, diff --git a/src/include/parser/udf/plpgsql_parser.h b/src/include/parser/udf/plpgsql_parser.h index c95ec9beb2..3ff99ec376 100644 --- a/src/include/parser/udf/plpgsql_parser.h +++ b/src/include/parser/udf/plpgsql_parser.h @@ -49,7 +49,6 @@ class PLpgSQLParser { * @param param_names The names of the function parameters * @param param_types The types of the function parameters * @param func_body The input source for the function - * @param ast_context The AST context to use during parsing * @return The abstract syntax tree for the source function */ std::unique_ptr Parse(const std::vector ¶m_names, diff --git a/src/include/planner/plannodes/drop_function_plan_node.h b/src/include/planner/plannodes/drop_function_plan_node.h index 3120ce1614..1b58493523 100644 --- a/src/include/planner/plannodes/drop_function_plan_node.h +++ b/src/include/planner/plannodes/drop_function_plan_node.h @@ -45,6 +45,10 @@ class DropFunctionPlanNode : public AbstractPlanNode { return *this; } + /** + * @param if_exists `true` if `IF EXISTS` is specified + * @return builder object + */ Builder &SetIfExists(bool if_exists) { if_exists_ = if_exists; return *this; diff --git a/src/parser/postgresparser.cpp b/src/parser/postgresparser.cpp index 3f813f5a6f..828e23d308 100644 --- a/src/parser/postgresparser.cpp +++ b/src/parser/postgresparser.cpp @@ -3,6 +3,7 @@ #include #include #include +#include #include #include #include @@ -88,6 +89,41 @@ void PostgresParser::ListTransform(ParseResult *parse_result, List *root, const } } +/** + * Get the data type for the specified type name. + * @param name The type name (as C-style string) + * @return The data type + */ +static std::optional TypeNameToDataType(const char *name) { + BaseFunctionParameter::DataType data_type; + if ((strcmp(name, "int") == 0) || (strcmp(name, "int4") == 0)) { + data_type = BaseFunctionParameter::DataType::INT; + } else if (strcmp(name, "varchar") == 0) { + data_type = BaseFunctionParameter::DataType::VARCHAR; + } else if (strcmp(name, "int8") == 0) { + data_type = BaseFunctionParameter::DataType::BIGINT; + } else if (strcmp(name, "int2") == 0) { + data_type = BaseFunctionParameter::DataType::SMALLINT; + } else if ((strcmp(name, "double") == 0) || (strcmp(name, "float8") == 0)) { + data_type = BaseFunctionParameter::DataType::DOUBLE; + } else if ((strcmp(name, "real") == 0) || (strcmp(name, "float4") == 0)) { + data_type = BaseFunctionParameter::DataType::FLOAT; + } else if (strcmp(name, "text") == 0) { + data_type = BaseFunctionParameter::DataType::TEXT; + } else if (strcmp(name, "bpchar") == 0) { + data_type = BaseFunctionParameter::DataType::CHAR; + } else if (strcmp(name, "tinyint") == 0) { + data_type = BaseFunctionParameter::DataType::TINYINT; + } else if (strcmp(name, "bool") == 0) { + data_type = BaseFunctionParameter::DataType::BOOL; + } else if (strcmp(name, "date") == 0) { + data_type = BaseFunctionParameter::DataType::DATE; + } else { + return std::nullopt; + } + return std::make_optional(data_type); +} + std::unique_ptr PostgresParser::NodeTransform(ParseResult *parse_result, Node *node, const std::string &query_string) { // TODO(WAN): Document what input is parsed to nullptr @@ -1668,70 +1704,23 @@ std::unique_ptr PostgresParser::FunctionParameterTransform(ParseR FunctionParameter *root) { // TODO(WAN): significant code duplication, refactor out char* -> DataType char *name = (reinterpret_cast(root->arg_type_->names_->tail->data.ptr_value)->val_.str_); - parser::FuncParameter::DataType data_type; - - if ((strcmp(name, "int") == 0) || (strcmp(name, "int4") == 0)) { - data_type = BaseFunctionParameter::DataType::INT; - } else if (strcmp(name, "varchar") == 0) { - data_type = BaseFunctionParameter::DataType::VARCHAR; - } else if (strcmp(name, "int8") == 0) { - data_type = BaseFunctionParameter::DataType::BIGINT; - } else if (strcmp(name, "int2") == 0) { - data_type = BaseFunctionParameter::DataType::SMALLINT; - } else if ((strcmp(name, "double") == 0) || (strcmp(name, "float8") == 0)) { - data_type = BaseFunctionParameter::DataType::DOUBLE; - } else if ((strcmp(name, "real") == 0) || (strcmp(name, "float4") == 0)) { - data_type = BaseFunctionParameter::DataType::FLOAT; - } else if (strcmp(name, "text") == 0) { - data_type = BaseFunctionParameter::DataType::TEXT; - } else if (strcmp(name, "bpchar") == 0) { - data_type = BaseFunctionParameter::DataType::CHAR; - } else if (strcmp(name, "tinyint") == 0) { - data_type = BaseFunctionParameter::DataType::TINYINT; - } else if (strcmp(name, "bool") == 0) { - data_type = BaseFunctionParameter::DataType::BOOL; - } else { + auto data_type = TypeNameToDataType(name); + if (!data_type.has_value()) { PARSER_LOG_AND_THROW("FunctionParameterTransform", "DataType", name); } auto param_name = root->name_ != nullptr ? root->name_ : ""; - auto result = std::make_unique(data_type, param_name); - return result; + return std::make_unique(data_type.value(), param_name); } // Postgres.TypeName -> noisepage.ReturnType std::unique_ptr PostgresParser::ReturnTypeTransform(ParseResult *parse_result, TypeName *root) { char *name = (reinterpret_cast(root->names_->tail->data.ptr_value)->val_.str_); - ReturnType::DataType data_type; - - if ((strcmp(name, "int") == 0) || (strcmp(name, "int4") == 0)) { - data_type = BaseFunctionParameter::DataType::INT; - } else if (strcmp(name, "varchar") == 0) { - data_type = BaseFunctionParameter::DataType::VARCHAR; - } else if (strcmp(name, "int8") == 0) { - data_type = BaseFunctionParameter::DataType::BIGINT; - } else if (strcmp(name, "int2") == 0) { - data_type = BaseFunctionParameter::DataType::SMALLINT; - } else if ((strcmp(name, "double") == 0) || (strcmp(name, "float8") == 0)) { - data_type = BaseFunctionParameter::DataType::DOUBLE; - } else if ((strcmp(name, "real") == 0) || (strcmp(name, "float4") == 0)) { - data_type = BaseFunctionParameter::DataType::FLOAT; - } else if (strcmp(name, "text") == 0) { - data_type = BaseFunctionParameter::DataType::TEXT; - } else if (strcmp(name, "bpchar") == 0) { - data_type = BaseFunctionParameter::DataType::CHAR; - } else if (strcmp(name, "tinyint") == 0) { - data_type = BaseFunctionParameter::DataType::TINYINT; - } else if (strcmp(name, "bool") == 0) { - data_type = BaseFunctionParameter::DataType::BOOL; - } else if (strcmp(name, "date") == 0) { - data_type = BaseFunctionParameter::DataType::DATE; - } else { + auto data_type = TypeNameToDataType(name); + if (!data_type.has_value()) { PARSER_LOG_AND_THROW("ReturnTypeTransform", "ReturnType", name); } - - auto result = std::make_unique(data_type); - return result; + return std::make_unique(data_type.value()); } // Postgres.Node -> noisepage.AbstractExpression From 4a1b637ffafb5190a78463bd6c5f99b462adbe67 Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Wed, 11 Aug 2021 22:18:41 -0400 Subject: [PATCH 111/139] fix some compilation bugs that only show up in release --- src/binder/bind_node_visitor.cpp | 4 +- src/execution/compiler/udf/udf_codegen.cpp | 75 ++++++++++++++-------- 2 files changed, 50 insertions(+), 29 deletions(-) diff --git a/src/binder/bind_node_visitor.cpp b/src/binder/bind_node_visitor.cpp index 35a4608fb0..83f53f6a5a 100644 --- a/src/binder/bind_node_visitor.cpp +++ b/src/binder/bind_node_visitor.cpp @@ -1161,8 +1161,8 @@ bool BindNodeVisitor::HaveUDFVariableRef(const std::string &identifier) const { void BindNodeVisitor::AddUDFVariableReference(common::ManagedPointer expr, const std::string &table_name, const std::string &column_name) { - const execution::sql::SqlTypeId type = udf_ast_context_->GetVariableTypeFailFast(table_name); - NOISEPAGE_ASSERT(type == execution::sql::SqlTypeId::Invalid, "Must be a RECORD type"); + NOISEPAGE_ASSERT(udf_ast_context_->GetVariableTypeFailFast(table_name) == execution::sql::SqlTypeId::Invalid, + "Must be a RECORD type"); // Locate the column name in the structure const auto fields = udf_ast_context_->GetRecordTypeFailFast(table_name); diff --git a/src/execution/compiler/udf/udf_codegen.cpp b/src/execution/compiler/udf/udf_codegen.cpp index a81f3c914a..e459560d63 100644 --- a/src/execution/compiler/udf/udf_codegen.cpp +++ b/src/execution/compiler/udf/udf_codegen.cpp @@ -59,14 +59,30 @@ void UdfCodegen::GenerateUDF(ast::udf::AbstractAST *ast) { ast->Accept(this); } catalog::type_oid_t UdfCodegen::GetCatalogTypeOidFromSQLType(ast::BuiltinType::Kind type) { switch (type) { + case ast::BuiltinType::Kind::Boolean: { + return accessor_->GetTypeOidFromTypeId(sql::SqlTypeId::Boolean); + } case ast::BuiltinType::Kind::Integer: { return accessor_->GetTypeOidFromTypeId(sql::SqlTypeId::Integer); } - case ast::BuiltinType::Kind::Boolean: { - return accessor_->GetTypeOidFromTypeId(sql::SqlTypeId::Boolean); + case ast::BuiltinType::Kind::Real: { + return accessor_->GetTypeOidFromTypeId(sql::SqlTypeId::Real); + } + case ast::BuiltinType::Kind::Decimal: { + return accessor_->GetTypeOidFromTypeId(sql::SqlTypeId::Decimal); + } + case ast::BuiltinType::Kind::StringVal: { + return accessor_->GetTypeOidFromTypeId(sql::SqlTypeId::Varchar); + } + case ast::BuiltinType::Kind::Date: { + return accessor_->GetTypeOidFromTypeId(sql::SqlTypeId::Date); + } + case ast::BuiltinType::Kind::Timestamp: { + return accessor_->GetTypeOidFromTypeId(sql::SqlTypeId::Timestamp); } default: - NOISEPAGE_ASSERT(false, "Unsupported parameter type"); + NOISEPAGE_ASSERT(false, "Invalid SQL type in function call"); + return accessor_->GetTypeOidFromTypeId(sql::SqlTypeId::Invalid); } } @@ -91,28 +107,29 @@ void UdfCodegen::Visit(ast::udf::DynamicSQLStmtAST *ast) { } void UdfCodegen::Visit(ast::udf::CallExprAST *ast) { - std::vector args_ast{}; - std::vector args_ast_region_vec{}; - std::vector arg_types{}; - - // First argument to UDF is an execution context - args_ast_region_vec.push_back(GetExecutionContext()); - - // TODO(Kyle): Is this the semantics we want? The execution - // context for the entire TPL program is shared? - - // TODO(Kyle): Clean up this logic - for (auto &arg : ast->Args()) { - ast::Expr *result = EvaluateExpression(arg.get()); - args_ast.push_back(result); - args_ast_region_vec.push_back(result); - auto *builtin = result->GetType()->SafeAs(); - NOISEPAGE_ASSERT(builtin != nullptr, "Parameter must be a built-in type"); - NOISEPAGE_ASSERT(builtin->IsSqlValueType(), "Parameter must be a SQL value type"); - arg_types.push_back(GetCatalogTypeOidFromSQLType(builtin->GetKind())); - } - - const auto proc_oid = accessor_->GetProcOid(ast->Callee(), arg_types); + const auto &args = ast->Args(); + + // Evaluate all arguments to call + std::vector arguments{}; + arguments.reserve(ast->Args().size()); + std::transform(args.cbegin(), args.cend(), std::back_inserter(arguments), + [this](const std::unique_ptr &expr) { return EvaluateExpression(expr.get()); }); + + NOISEPAGE_ASSERT(std::all_of(arguments.cbegin(), arguments.cend(), + [](const ast::Expr *arg) { + auto *builtin = arg->GetType()->SafeAs(); + return builtin != nullptr && builtin->IsSqlValueType(); + }), + "Invalid argument type in function call"); + + // Get argument types + std::vector argument_types{}; + std::transform(arguments.cbegin(), arguments.cend(), std::back_inserter(argument_types), + [this](const ast::Expr *expr) { + return GetCatalogTypeOidFromSQLType(expr->GetType()->SafeAs()->GetKind()); + }); + + const auto proc_oid = accessor_->GetProcOid(ast->Callee(), argument_types); if (proc_oid == catalog::INVALID_PROC_OID) { throw BINDER_EXCEPTION(fmt::format("Invalid function call '{}'", ast->Callee()), common::ErrorCode::ERRCODE_PLPGSQL_ERROR); @@ -120,9 +137,13 @@ void UdfCodegen::Visit(ast::udf::CallExprAST *ast) { auto context = accessor_->GetFunctionContext(proc_oid); if (context->IsBuiltin()) { - ast::Expr *result = codegen_->CallBuiltin(context->GetBuiltin(), args_ast); + ast::Expr *result = codegen_->CallBuiltin(context->GetBuiltin(), arguments); SetExecutionResult(result); } else { + // NOTE(Kyle): This is an unfortunate operation because it + // requires shifting all elements in the vector, but we + // don't typically see functions with super-high arity + arguments.insert(arguments.begin(), GetExecutionContext()); auto it = SymbolTable().find(ast->Callee()); ast::Identifier ident_expr; if (it != SymbolTable().end()) { @@ -137,7 +158,7 @@ void UdfCodegen::Visit(ast::udf::CallExprAST *ast) { ident_expr = codegen_->MakeFreshIdentifier(file->Declarations().back()->Name().GetString()); SymbolTable()[file->Declarations().back()->Name().GetString()] = ident_expr; } - ast::Expr *result = codegen_->Call(ident_expr, args_ast_region_vec); + ast::Expr *result = codegen_->Call(ident_expr, arguments); SetExecutionResult(result); } } From c87f7a304226ea94fb1abbe2ac393c3fcde6829b Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Wed, 11 Aug 2021 22:48:39 -0400 Subject: [PATCH 112/139] add some integration tests with character types --- script/testing/junit/sql/udf.sql | 27 ++++++++++++-- script/testing/junit/traces/udf.test | 56 ++++++++++++++++++++++++++-- 2 files changed, 77 insertions(+), 6 deletions(-) diff --git a/script/testing/junit/sql/udf.sql b/script/testing/junit/sql/udf.sql index ca73e22e1b..2e500d0cc2 100644 --- a/script/testing/junit/sql/udf.sql +++ b/script/testing/junit/sql/udf.sql @@ -6,11 +6,12 @@ -- because all user-defined functions are implemented -- in the Postgres PL/SQL dialect, PL/pgSQL. --- Create a test table +-- Create test tables CREATE TABLE integers(x INT, y INT); +INSERT INTO integers(x, y) VALUES (1, 1), (2, 2), (3, 3); --- Insert some data -INSERT INTO integers (x, y) VALUES (1, 1), (2, 2), (3, 3); +CREATE TABLE strings(s TEXT); +INSERT INTO strings(s) VALUES ('aaa'), ('bbb'), ('ccc'); -- ---------------------------------------------------------------------------- -- return_constant() @@ -25,6 +26,16 @@ SELECT return_constant(); DROP FUNCTION return_constant(); +CREATE FUNCTION return_constant() RETURNS TEXT AS $$ \ +BEGIN \ + RETURN 'hello, functions'; \ +END \ +$$ LANGUAGE PLPGSQL; + +SELECT return_constant(); + +DROP FUNCTION return_constant(); + -- ---------------------------------------------------------------------------- -- return_input() @@ -38,6 +49,16 @@ SELECT x, return_input(x) FROM integers; DROP FUNCTION return_input(INT); +CREATE FUNCTION return_input(x TEXT) RETURNS TEXT AS $$ \ +BEGIN \ + RETURN x; \ +END \ +$$ LANGUAGE PLPGSQL; + +SELECT s, return_input(s) FROM strings; + +DROP FUNCTION return_input(TEXT); + -- ---------------------------------------------------------------------------- -- return_sum() diff --git a/script/testing/junit/traces/udf.test b/script/testing/junit/traces/udf.test index 69cc723530..86358a8569 100644 --- a/script/testing/junit/traces/udf.test +++ b/script/testing/junit/traces/udf.test @@ -23,19 +23,22 @@ statement ok statement ok --- Create a test table +-- Create test tables statement ok CREATE TABLE integers(x INT, y INT); +statement ok +INSERT INTO integers(x, y) VALUES (1, 1), (2, 2), (3, 3); + statement ok statement ok --- Insert some data +CREATE TABLE strings(s TEXT); statement ok -INSERT INTO integers (x, y) VALUES (1, 1), (2, 2), (3, 3); +INSERT INTO strings(s) VALUES ('aaa'), ('bbb'), ('ccc'); statement ok @@ -70,6 +73,27 @@ DROP FUNCTION return_constant(); statement ok +statement ok +CREATE FUNCTION return_constant() RETURNS TEXT AS $$ BEGIN RETURN 'hello, functions'; END $$ LANGUAGE PLPGSQL; + +statement ok + + +query T rowsort +SELECT return_constant(); +---- +hello, functions + + +statement ok + + +statement ok +DROP FUNCTION return_constant(); + +statement ok + + statement ok -- ---------------------------------------------------------------------------- @@ -105,6 +129,32 @@ DROP FUNCTION return_input(INT); statement ok +statement ok +CREATE FUNCTION return_input(x TEXT) RETURNS TEXT AS $$ BEGIN RETURN x; END $$ LANGUAGE PLPGSQL; + +statement ok + + +query TT rowsort +SELECT s, return_input(s) FROM strings; +---- +aaa +aaa +bbb +bbb +ccc +ccc + + +statement ok + + +statement ok +DROP FUNCTION return_input(TEXT); + +statement ok + + statement ok -- ---------------------------------------------------------------------------- From 49faf61c0841fc7f3dd1dd29e6779f32f751e9e7 Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Thu, 12 Aug 2021 15:02:37 -0400 Subject: [PATCH 113/139] fighting weird function name collision bug that only manifests in CI... --- script/testing/junit/sql/udf.sql | 6 +++--- script/testing/junit/traces/udf.test | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/script/testing/junit/sql/udf.sql b/script/testing/junit/sql/udf.sql index 2e500d0cc2..21c632624c 100644 --- a/script/testing/junit/sql/udf.sql +++ b/script/testing/junit/sql/udf.sql @@ -26,15 +26,15 @@ SELECT return_constant(); DROP FUNCTION return_constant(); -CREATE FUNCTION return_constant() RETURNS TEXT AS $$ \ +CREATE FUNCTION return_constant_str() RETURNS TEXT AS $$ \ BEGIN \ RETURN 'hello, functions'; \ END \ $$ LANGUAGE PLPGSQL; -SELECT return_constant(); +SELECT return_constant_str(); -DROP FUNCTION return_constant(); +DROP FUNCTION return_constant_str(); -- ---------------------------------------------------------------------------- -- return_input() diff --git a/script/testing/junit/traces/udf.test b/script/testing/junit/traces/udf.test index 86358a8569..58542e2691 100644 --- a/script/testing/junit/traces/udf.test +++ b/script/testing/junit/traces/udf.test @@ -74,13 +74,13 @@ statement ok statement ok -CREATE FUNCTION return_constant() RETURNS TEXT AS $$ BEGIN RETURN 'hello, functions'; END $$ LANGUAGE PLPGSQL; +CREATE FUNCTION return_constant_str() RETURNS TEXT AS $$ BEGIN RETURN 'hello, functions'; END $$ LANGUAGE PLPGSQL; statement ok query T rowsort -SELECT return_constant(); +SELECT return_constant_str(); ---- hello, functions @@ -89,7 +89,7 @@ statement ok statement ok -DROP FUNCTION return_constant(); +DROP FUNCTION return_constant_str(); statement ok From d77b17a14b2bdb9f4d1ab25a474d7cdb0ffc2342 Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Thu, 12 Aug 2021 23:14:26 -0400 Subject: [PATCH 114/139] fix memory leak from bad merge resolution --- src/catalog/postgres/pg_proc_impl.cpp | 1 + src/parser/udf/plpgsql_parser.cpp | 17 +++++++++++++---- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/src/catalog/postgres/pg_proc_impl.cpp b/src/catalog/postgres/pg_proc_impl.cpp index ef1957541a..3ea3e42850 100644 --- a/src/catalog/postgres/pg_proc_impl.cpp +++ b/src/catalog/postgres/pg_proc_impl.cpp @@ -180,6 +180,7 @@ bool PgProcImpl::CreateProcedure(const common::ManagedPointer PLpgSQLParser::Parse( const std::vector ¶m_names, const std::vector ¶m_types, const std::string &func_body) { - auto result = PLpgSQLParseResult{pg_query_parse_plpgsql(func_body.c_str())}; - if ((*result).error != nullptr) { - throw PARSER_EXCEPTION(fmt::format("PL/pgSQL parser : {}", (*result).error->message)); + auto* ctx = pg_query_parse_init(); + auto result = pg_query_parse_plpgsql(func_body.c_str()); + + if (result.error != nullptr) { + pg_query_parse_finish(ctx); + const auto message = fmt::format("PL/pgSQL parser : {}", result.error->message); + pg_query_free_plpgsql_parse_result(result); + throw PARSER_EXCEPTION(message); } // The result is a list, we need to wrap it - const auto ast_json_str = fmt::format("{{ \"{}\" : {} }}", K_FUNCTION_LIST, (*result).plpgsql_funcs); + const auto ast_json_str = fmt::format("{{ \"{}\" : {} }}", K_FUNCTION_LIST, result.plpgsql_funcs); + + // Now finished with the raw parse result from pg_query + pg_query_parse_finish(ctx); + pg_query_free_plpgsql_parse_result(result); const nlohmann::json ast_json = nlohmann::json::parse(ast_json_str); const auto function_list = ast_json[K_FUNCTION_LIST]; From d5f194b0c795cda8b8f67f4a656b2a4ec45b2d60 Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Fri, 13 Aug 2021 10:24:51 -0400 Subject: [PATCH 115/139] fix memory leak in libpg_query --- src/parser/udf/plpgsql_parser.cpp | 17 ++++------------- .../libpg_query/src/pg_query_parse_plpgsql.c | 1 + 2 files changed, 5 insertions(+), 13 deletions(-) diff --git a/src/parser/udf/plpgsql_parser.cpp b/src/parser/udf/plpgsql_parser.cpp index deb1480108..a0e394df41 100644 --- a/src/parser/udf/plpgsql_parser.cpp +++ b/src/parser/udf/plpgsql_parser.cpp @@ -77,22 +77,13 @@ static constexpr const char DECL_TYPE_ID_RECORD[] = "record"; std::unique_ptr PLpgSQLParser::Parse( const std::vector ¶m_names, const std::vector ¶m_types, const std::string &func_body) { - auto* ctx = pg_query_parse_init(); - auto result = pg_query_parse_plpgsql(func_body.c_str()); - - if (result.error != nullptr) { - pg_query_parse_finish(ctx); - const auto message = fmt::format("PL/pgSQL parser : {}", result.error->message); - pg_query_free_plpgsql_parse_result(result); - throw PARSER_EXCEPTION(message); + PLpgSQLParseResult result{pg_query_parse_plpgsql(func_body.c_str())}; + if ((*result).error != nullptr) { + throw PARSER_EXCEPTION(fmt::format("PL/pgSQL parser : {}", (*result).error->message)); } // The result is a list, we need to wrap it - const auto ast_json_str = fmt::format("{{ \"{}\" : {} }}", K_FUNCTION_LIST, result.plpgsql_funcs); - - // Now finished with the raw parse result from pg_query - pg_query_parse_finish(ctx); - pg_query_free_plpgsql_parse_result(result); + const auto ast_json_str = fmt::format("{{ \"{}\" : {} }}", K_FUNCTION_LIST, (*result).plpgsql_funcs); const nlohmann::json ast_json = nlohmann::json::parse(ast_json_str); const auto function_list = ast_json[K_FUNCTION_LIST]; diff --git a/third_party/libpg_query/src/pg_query_parse_plpgsql.c b/third_party/libpg_query/src/pg_query_parse_plpgsql.c index ba102157eb..af6773e04b 100644 --- a/third_party/libpg_query/src/pg_query_parse_plpgsql.c +++ b/third_party/libpg_query/src/pg_query_parse_plpgsql.c @@ -439,6 +439,7 @@ PgQueryPlpgsqlParseResult pg_query_parse_plpgsql(const char* input) result.plpgsql_funcs[strlen(result.plpgsql_funcs) - 2] = '\n'; result.plpgsql_funcs[strlen(result.plpgsql_funcs) - 1] = ']'; + free(parse_result.stderr_buffer); pg_query_exit_memory_context(ctx); return result; From de6e801937377bdb67f885064b27f894a01290cd Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Fri, 13 Aug 2021 17:41:33 -0400 Subject: [PATCH 116/139] remove old tpl tests, remove flag from debug compilation --- CMakeLists.txt | 2 +- sample_tpl/agg-lambda.tpl | 81 --------------------------------------- sample_tpl/lambda2.tpl | 9 ----- sample_tpl/tpl_tests.txt | 1 - 4 files changed, 1 insertion(+), 92 deletions(-) delete mode 100644 sample_tpl/agg-lambda.tpl delete mode 100644 sample_tpl/lambda2.tpl diff --git a/CMakeLists.txt b/CMakeLists.txt index 961e2ac79c..6f5b754caa 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -235,7 +235,7 @@ set(NOISEPAGE_INCLUDE_DIRECTORIES "") # Add compilation flags to NOISEPAGE_COMPILE_OPTIONS based on the current CMAKE_BUILD_TYPE. string(TOUPPER ${CMAKE_BUILD_TYPE} CMAKE_BUILD_TYPE) if ("${CMAKE_BUILD_TYPE}" STREQUAL "DEBUG") - list(APPEND NOISEPAGE_COMPILE_OPTIONS "-ggdb" "-O0" "-fno-omit-frame-pointer" "-fno-optimize-sibling-calls" "-Wfatal-errors") + list(APPEND NOISEPAGE_COMPILE_OPTIONS "-ggdb" "-O0" "-fno-omit-frame-pointer" "-fno-optimize-sibling-calls") elseif ("${CMAKE_BUILD_TYPE}" STREQUAL "FASTDEBUG") list(APPEND NOISEPAGE_COMPILE_OPTIONS "-ggdb" "-O1" "-fno-omit-frame-pointer" "-fno-optimize-sibling-calls") elseif ("${CMAKE_BUILD_TYPE}" STREQUAL "RELEASE") diff --git a/sample_tpl/agg-lambda.tpl b/sample_tpl/agg-lambda.tpl deleted file mode 100644 index e1c5a6df30..0000000000 --- a/sample_tpl/agg-lambda.tpl +++ /dev/null @@ -1,81 +0,0 @@ -// Expected output: 10 -// SQL: SELECT col_b, count(col_a) FROM test_1 GROUP BY col_b - -struct State { - table: AggregationHashTable - count: int32 -} - -struct OutputStruct { - out0: Integer -} - -struct Agg { - key: Integer - count: CountStarAggregate -} - -fun setUpState(execCtx: *ExecutionContext, state: *State) -> nil { - state.count = 0 - @aggHTInit(&state.table, execCtx, @sizeOf(Agg)) -} - -fun tearDownState(state: *State) -> nil { - @aggHTFree(&state.table) -} - -fun keyCheck(agg: *Agg, vpi: *VectorProjectionIterator) -> bool { - var key = @vpiGetInt(vpi, 1) - return @sqlToBool(key == agg.key) -} - -fun constructAgg(agg: *Agg, vpi: *VectorProjectionIterator) -> nil { - agg.key = @vpiGetInt(vpi, 1) - @aggInit(&agg.count) -} - -fun updateAgg(agg: *Agg, vpi: *VectorProjectionIterator) -> nil { - var input = @vpiGetInt(vpi, 0) - @aggAdvance(&agg.count, &input) -} - -fun pipeline_1(execCtx: *ExecutionContext, state: *State, lam : lambda [(Integer)->nil] ) -> nil { - var ht = &state.table - var tvi: TableVectorIterator - var table_oid = @testCatalogLookup(execCtx, "test_1", "") - var col_oids: [2]uint32 - col_oids[0] = @testCatalogLookup(execCtx, "test_1", "cola") - col_oids[1] = @testCatalogLookup(execCtx, "test_1", "colb") - for (@tableIterInit(&tvi, execCtx, table_oid, col_oids); @tableIterAdvance(&tvi); ) { - var vec = @tableIterGetVPI(&tvi) - for (; @vpiHasNext(vec); @vpiAdvance(vec)) { - var output_row: OutputStruct - output_row.out0 = @vpiGetIntNull(vec, 0) - lam(output_row.out0) - } - } - @tableIterClose(&tvi) -} - -fun execQuery(execCtx: *ExecutionContext, qs: *State, lam : lambda [(Integer)->nil] ) -> nil { - pipeline_1(execCtx, qs, lam) -} - -fun main(execCtx: *ExecutionContext) -> int32 { - var count : Integer - count = @intToSql(0) - var lam = lambda [count] (x : Integer) -> nil { - count = count + 1 - } - var state: State - - setUpState(execCtx, &state) - execQuery(execCtx, &state, lam) - tearDownState(&state) - - var ret = state.count - if(count > 0) { - return 1 - } - return 0 -} diff --git a/sample_tpl/lambda2.tpl b/sample_tpl/lambda2.tpl deleted file mode 100644 index 3446cc80d5..0000000000 --- a/sample_tpl/lambda2.tpl +++ /dev/null @@ -1,9 +0,0 @@ -// Expected output: 2 - -fun addOne(x: int32) -> int32 { - return x + 1 -} - -fun main() -> int32 { - return addOne(1) -} diff --git a/sample_tpl/tpl_tests.txt b/sample_tpl/tpl_tests.txt index b54b45cae3..daf92456bb 100644 --- a/sample_tpl/tpl_tests.txt +++ b/sample_tpl/tpl_tests.txt @@ -68,7 +68,6 @@ types/timestamps.tpl,false,0 ################################################################################ agg.tpl,true,10 -#agg-lambda.tpl,true,10 TODO(Kyle): Requires lambdas #agg-vec.tpl,true,10 doesn't work on prashanth's branch #agg-vec-filter.tpl,true,10 doesn't work on prashanth's branch cte_scan_temp_table_insert.tpl,true,11 From 8dc0cff11d3bc820aa0b0d51d35d8b445982061c Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Fri, 13 Aug 2021 18:31:41 -0400 Subject: [PATCH 117/139] remove old comment --- src/execution/compiler/udf/udf_codegen.cpp | 8 ++++---- src/execution/parsing/parser.cpp | 1 - 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/src/execution/compiler/udf/udf_codegen.cpp b/src/execution/compiler/udf/udf_codegen.cpp index e459560d63..64927812ae 100644 --- a/src/execution/compiler/udf/udf_codegen.cpp +++ b/src/execution/compiler/udf/udf_codegen.cpp @@ -102,6 +102,10 @@ void UdfCodegen::Visit(ast::udf::AbstractAST *ast) { throw NOT_IMPLEMENTED_EXCEPTION("UdfCodegen::Visit(AbstractAST*)"); } +void UdfCodegen::Visit(ast::udf::StmtAST *ast) { UNREACHABLE("Not implemented"); } + +void UdfCodegen::Visit(ast::udf::ExprAST *ast) { UNREACHABLE("Not implemented"); } + void UdfCodegen::Visit(ast::udf::DynamicSQLStmtAST *ast) { throw NOT_IMPLEMENTED_EXCEPTION("UdfCodegen::Visit(DynamicSQLStmtAST*)"); } @@ -163,10 +167,6 @@ void UdfCodegen::Visit(ast::udf::CallExprAST *ast) { } } -void UdfCodegen::Visit(ast::udf::StmtAST *ast) { UNREACHABLE("Not implemented"); } - -void UdfCodegen::Visit(ast::udf::ExprAST *ast) { UNREACHABLE("Not implemented"); } - void UdfCodegen::Visit(ast::udf::DeclStmtAST *ast) { if (ast->Name() == INTERNAL_DECL_ID) { return; diff --git a/src/execution/parsing/parser.cpp b/src/execution/parsing/parser.cpp index f8daaedb29..549cefa852 100644 --- a/src/execution/parsing/parser.cpp +++ b/src/execution/parsing/parser.cpp @@ -451,7 +451,6 @@ ast::Expr *Parser::ParseLambdaExpr() { auto *fun = ParseFunctionLitExpr()->As(); // Create declaration - // ast::FunctionDecl *decl = node_factory_->NewFunctionDecl(position, name fun); auto *lambda = node_factory_->NewLambdaExpr(position, fun, std::move(captures)); // Done From 727933890223dbb68daa1600cc76293cb9a7ec8d Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Thu, 9 Sep 2021 10:26:05 -0400 Subject: [PATCH 118/139] update installation script, add requirements.txt --- requirements.txt | 29 ++++++++++++++++++++++++ script/installation/packages.sh | 40 +-------------------------------- 2 files changed, 30 insertions(+), 39 deletions(-) create mode 100644 requirements.txt diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000..6dbc5c5462 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,29 @@ +certifi==2021.5.30 +charset-normalizer==2.0.4 +coverage==5.5 +distro==1.6.0 +idna==3.2 +importlib-metadata==4.8.1 +joblib==1.0.1 +lightgbm==3.2.1 +numpy==1.21.2 +pandas==1.1.5 +prettytable==2.2.0 +psutil==5.8.0 +psycopg2==2.9.1 +pyarrow==5.0.0 +python-dateutil==2.8.2 +pytz==2021.1 +pyzmq==22.2.1 +requests==2.26.0 +scikit-learn==0.24.2 +scipy==1.7.1 +six==1.16.0 +sklearn==0.0 +threadpoolctl==2.2.0 +torch==1.9.0 +tqdm==4.62.2 +typing-extensions==3.10.0.2 +urllib3==1.26.6 +wcwidth==0.2.5 +zipp==3.5.0 diff --git a/script/installation/packages.sh b/script/installation/packages.sh index c3f7c6c0e3..58e5c3dbbb 100755 --- a/script/installation/packages.sh +++ b/script/installation/packages.sh @@ -30,12 +30,12 @@ LINUX_BUILD_PACKAGES=(\ "llvm-8" \ "pkg-config" \ "postgresql-client" \ - "python3-pip" \ "ninja-build" "wget" \ "zlib1g-dev" \ "time" \ ) + LINUX_TEST_PACKAGES=(\ "ant" \ "ccache" \ @@ -44,27 +44,6 @@ LINUX_TEST_PACKAGES=(\ "lsof" \ ) -# Packages to be installed through pip3. -PYTHON_BUILD_PACKAGES=( -) -PYTHON_TEST_PACKAGES=(\ - "distro" \ - "lightgbm" \ - "numpy" \ - "pandas" \ - "prettytable" \ - "psutil" \ - "psycopg2" \ - "pyarrow" \ - "pyzmq" \ - "requests" \ - "sklearn" \ - "torch" \ - "tqdm" \ - "coverage" \ -) - - ## ================================================================= @@ -143,12 +122,6 @@ install() { esac } -install_pip() { - curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py - python get-pip.py - rm get-pip.py -} - install_linux() { # Update apt-get. apt-get -y update @@ -160,17 +133,6 @@ install_linux() { if [ "$INSTALL_TYPE" == "test" ] || [ "$INSTALL_TYPE" = "all" ]; then apt-get -y install $( IFS=$' '; echo "${LINUX_TEST_PACKAGES[*]}" ) fi - - if [ "$INSTALL_TYPE" == "build" ] || [ "$INSTALL_TYPE" = "all" ]; then - for pkg in "${PYTHON_BUILD_PACKAGES[@]}"; do - python3 -m pip show $pkg || python3 -m pip install $pkg - done - fi - if [ "$INSTALL_TYPE" == "test" ] || [ "$INSTALL_TYPE" = "all" ]; then - for pkg in "${PYTHON_TEST_PACKAGES[@]}"; do - python3 -m pip show $pkg || python3 -m pip install $pkg - done - fi } main "$@" From 9d5a4a40d8432e498963dc29b4d2a1ddc42e2f4b Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Thu, 23 Sep 2021 10:04:13 -0400 Subject: [PATCH 119/139] update postgresparser --- src/parser/postgresparser.cpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/parser/postgresparser.cpp b/src/parser/postgresparser.cpp index 828e23d308..ffc4ce6e2a 100644 --- a/src/parser/postgresparser.cpp +++ b/src/parser/postgresparser.cpp @@ -108,6 +108,8 @@ static std::optional TypeNameToDataType(const c data_type = BaseFunctionParameter::DataType::DOUBLE; } else if ((strcmp(name, "real") == 0) || (strcmp(name, "float4") == 0)) { data_type = BaseFunctionParameter::DataType::FLOAT; + } else if ((strcmp(name, "decimal") == 0) || strcmp(name, "numeric") == 0) { + return BaseFunctionParameter::DataType::DECIMAL; } else if (strcmp(name, "text") == 0) { data_type = BaseFunctionParameter::DataType::TEXT; } else if (strcmp(name, "bpchar") == 0) { @@ -1704,6 +1706,7 @@ std::unique_ptr PostgresParser::FunctionParameterTransform(ParseR FunctionParameter *root) { // TODO(WAN): significant code duplication, refactor out char* -> DataType char *name = (reinterpret_cast(root->arg_type_->names_->tail->data.ptr_value)->val_.str_); + std::cout << name << std::endl; auto data_type = TypeNameToDataType(name); if (!data_type.has_value()) { PARSER_LOG_AND_THROW("FunctionParameterTransform", "DataType", name); From b42be3b270448f5e18bd2b052df19aee57c19d37 Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Thu, 23 Sep 2021 11:20:09 -0400 Subject: [PATCH 120/139] issues with types --- src/execution/sql/ddl_executors.cpp | 4 +++- src/execution/sql/sql.cpp | 7 +++++-- src/include/execution/sql/sql.h | 10 ++++++++++ src/include/parser/create_function_statement.h | 7 ++----- src/parser/postgresparser.cpp | 1 - src/parser/udf/plpgsql_parser.cpp | 3 ++- 6 files changed, 22 insertions(+), 10 deletions(-) diff --git a/src/execution/sql/ddl_executors.cpp b/src/execution/sql/ddl_executors.cpp index 08c29e3bbf..b4e44aaf08 100644 --- a/src/execution/sql/ddl_executors.cpp +++ b/src/execution/sql/ddl_executors.cpp @@ -99,8 +99,10 @@ bool DDLExecutors::CreateFunctionExecutor(const common::ManagedPointerGetFunctionParameterNames().size(); i++) { + const auto raw = node->GetFunctionParameterTypes()[i]; + (void)raw; const auto name = node->GetFunctionParameterNames()[i]; - const auto type = parser::ReturnType::DataTypeToTypeId(node->GetFunctionParameterTypes()[i]); + const auto type = parser::BaseFunctionParameter::DataTypeToTypeId(node->GetFunctionParameterTypes()[i]); fn_params.emplace_back( codegen.MakeField(ast_context->GetIdentifier(name), codegen.TplType(execution::sql::GetTypeId(type)))); } diff --git a/src/execution/sql/sql.cpp b/src/execution/sql/sql.cpp index 4922c57f1d..96dac8c124 100644 --- a/src/execution/sql/sql.cpp +++ b/src/execution/sql/sql.cpp @@ -272,8 +272,12 @@ std::string SqlTypeIdToString(SqlTypeId type) { return "Integer"; case SqlTypeId::BigInt: return "BigInt"; + case SqlTypeId::Real: + return "Real"; case SqlTypeId::Double: return "Double"; + case SqlTypeId::Decimal: + return "Decimal"; case SqlTypeId::Date: return "Date"; case SqlTypeId::Timestamp: @@ -329,8 +333,7 @@ SqlTypeId SqlTypeIdFromString(const std::string &type_string) { } TypeId GetTypeId(SqlTypeId frontend_type) { - execution::sql::TypeId execution_type_id; - + TypeId execution_type_id; switch (frontend_type) { case SqlTypeId::Boolean: execution_type_id = execution::sql::TypeId::Boolean; diff --git a/src/include/execution/sql/sql.h b/src/include/execution/sql/sql.h index 5e6fecffce..346d6aca58 100644 --- a/src/include/execution/sql/sql.h +++ b/src/include/execution/sql/sql.h @@ -129,8 +129,18 @@ uint16_t GetSqlTypeIdSize(SqlTypeId type); */ std::size_t GetTypeIdAlignment(TypeId type); +/** + * Parse a SQL type ID from a string. + * @param type_string The string representation of the type name + * @return The SQL type ID + */ SqlTypeId SqlTypeIdFromString(const std::string &type_string); +/** + * Convert a SQL type ID to a human-readable string. + * @param type The SQL type ID + * @return The string representation of the type + */ std::string SqlTypeIdToString(SqlTypeId type); /** diff --git a/src/include/parser/create_function_statement.h b/src/include/parser/create_function_statement.h index 44b374e329..bef5b8e228 100644 --- a/src/include/parser/create_function_statement.h +++ b/src/include/parser/create_function_statement.h @@ -9,12 +9,9 @@ #include "expression/abstract_expression.h" #include "parser/sql_statement.h" -// TODO(WAN): this file is messy namespace noisepage::parser { /** Base function parameter. */ struct BaseFunctionParameter { - // TODO(WAN): there used to be a FuncParamMode that was never used? - /** Parameter data types. */ enum class DataType { INT, @@ -57,9 +54,9 @@ struct BaseFunctionParameter { case DataType::CHAR: return execution::sql::SqlTypeId::Invalid; case DataType::DOUBLE: - return execution::sql::SqlTypeId::Decimal; + return execution::sql::SqlTypeId::Double; case DataType::FLOAT: - return execution::sql::SqlTypeId::Decimal; + return execution::sql::SqlTypeId::Double; case DataType::DECIMAL: return execution::sql::SqlTypeId::Decimal; case DataType::VARCHAR: diff --git a/src/parser/postgresparser.cpp b/src/parser/postgresparser.cpp index ffc4ce6e2a..396392df5f 100644 --- a/src/parser/postgresparser.cpp +++ b/src/parser/postgresparser.cpp @@ -1706,7 +1706,6 @@ std::unique_ptr PostgresParser::FunctionParameterTransform(ParseR FunctionParameter *root) { // TODO(WAN): significant code duplication, refactor out char* -> DataType char *name = (reinterpret_cast(root->arg_type_->names_->tail->data.ptr_value)->val_.str_); - std::cout << name << std::endl; auto data_type = TypeNameToDataType(name); if (!data_type.has_value()) { PARSER_LOG_AND_THROW("FunctionParameterTransform", "DataType", name); diff --git a/src/parser/udf/plpgsql_parser.cpp b/src/parser/udf/plpgsql_parser.cpp index a0e394df41..dc7b59d9f5 100644 --- a/src/parser/udf/plpgsql_parser.cpp +++ b/src/parser/udf/plpgsql_parser.cpp @@ -60,6 +60,7 @@ static constexpr const char DECL_TYPE_ID_BIGINT[] = "bigint"; static constexpr const char DECL_TYPE_ID_REAL[] = "real"; static constexpr const char DECL_TYPE_ID_FLOAT[] = "float"; static constexpr const char DECL_TYPE_ID_DOUBLE[] = "double"; +static constexpr const char DECL_TYPE_ID_FLOAT8[] = "float8"; /** Arbitrary-precision floating point */ static constexpr const char DECL_TYPE_ID_NUMERIC[] = "numeric"; @@ -594,7 +595,7 @@ std::optional PLpgSQLParser::TypeNameToType(const std type = execution::sql::SqlTypeId::BigInt; } else if (type_name == DECL_TYPE_ID_REAL || type_name == DECL_TYPE_ID_FLOAT) { type = execution::sql::SqlTypeId::Real; - } else if (type_name == DECL_TYPE_ID_DOUBLE) { + } else if (type_name == DECL_TYPE_ID_DOUBLE || type_name == DECL_TYPE_ID_FLOAT8) { type = execution::sql::SqlTypeId::Double; } else if (type_name == DECL_TYPE_ID_NUMERIC || type_name == DECL_TYPE_ID_DECIMAL) { type = execution::sql::SqlTypeId::Decimal; From ddb1a39527fc6f4b2feaa5c3975251b9a0d31e7c Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Thu, 23 Sep 2021 21:17:58 -0400 Subject: [PATCH 121/139] updates and bugfixes in udf processing, able to compile and invoke procbench fn01 --- .gitignore | 4 +- Makefile | 17 ++++++++ script/testing/junit/sql/udf.sql | 18 ++++++++ script/testing/junit/traces/udf.test | 42 +++++++++++++++++++ src/binder/bind_node_visitor.cpp | 4 +- src/execution/compiler/udf/udf_codegen.cpp | 23 ++++++++-- src/execution/sema/sema_expr.cpp | 3 +- .../execution/compiler/udf/udf_codegen.h | 4 +- .../expression/column_value_expression.h | 9 ++-- src/include/parser/udf/variable_ref.h | 3 ++ 10 files changed, 114 insertions(+), 13 deletions(-) create mode 100644 Makefile diff --git a/.gitignore b/.gitignore index ec48ea7f41..f09148e5ba 100644 --- a/.gitignore +++ b/.gitignore @@ -48,8 +48,8 @@ config/ configure config-h.in autom4te.cache -*Makefile.in -*Makefile +build/*Makefile.in +build/*Makefile libtool aclocal.m4 config.log diff --git a/Makefile b/Makefile new file mode 100644 index 0000000000..10d071272d --- /dev/null +++ b/Makefile @@ -0,0 +1,17 @@ +# Makefile +# Some scripting shortcuts. + +# Run all regression checks in all execution modes +check-regress: check-regress-interpreted check-regress-compiled + +# Run all regression tests in interpreted mode +check-regress-interpreted: + cd build && PYTHONPATH=.. python -m script.testing.junit --build-type=debug --query-mode=simple + +# Run all regression tests in compiled mode +check-regress-compiled: + PYTHONPATH=.. python -m script.testing.junit --build-type=debug --query-mode=simple -a 'compiled_query_execution=True' -a 'bytecode_handlers_path=./bytecode_handlers_ir.bc' + +# Re-generate the trace file for UDF regression tests +generate-regress-udf: + cd script/testing/junit/ && ant generate-trace -Dpath=sql/udf.sql -Ddb-url=jdbc:postgresql://localhost/test -Ddb-user=postgres -Ddb-password=password -Doutput-name=udf.test diff --git a/script/testing/junit/sql/udf.sql b/script/testing/junit/sql/udf.sql index 21c632624c..2661f8598c 100644 --- a/script/testing/junit/sql/udf.sql +++ b/script/testing/junit/sql/udf.sql @@ -399,3 +399,21 @@ DROP FUNCTION proc_call_callee(); DROP FUNCTION proc_call_ret(); DROP FUNCTION proc_call_assign(); DROP FUNCTION proc_call_select(); + +-- ---------------------------------------------------------------------------- +-- proc_predicate() + +CREATE FUNCTION proc_predicate(threshold INT) RETURNS INT AS $$ \ +DECLARE \ + c INT; \ +BEGIN \ + SELECT COUNT(x) FROM integers WHERE x > threshold INTO c; \ + RETURN c; \ +END \ +$$ LANGUAGE PLPGSQL; + +SELECT proc_predicate(0); +SELECT proc_predicate(1); +SELECT proc_predicate(2); + +DROP FUNCTION proc_predicate(); diff --git a/script/testing/junit/traces/udf.test b/script/testing/junit/traces/udf.test index 58542e2691..9587b8c35d 100644 --- a/script/testing/junit/traces/udf.test +++ b/script/testing/junit/traces/udf.test @@ -777,3 +777,45 @@ statement ok DROP FUNCTION proc_call_select(); statement ok + + +statement ok +-- ---------------------------------------------------------------------------- + +statement ok +-- proc_predicate() + +statement ok + + +statement ok +CREATE FUNCTION proc_predicate(threshold INT) RETURNS INT AS $$ DECLARE c INT; BEGIN SELECT COUNT(x) FROM integers WHERE x > threshold INTO c; RETURN c; END $$ LANGUAGE PLPGSQL; + +statement ok + + +query I rowsort +SELECT proc_predicate(0); +---- +3 + + +query I rowsort +SELECT proc_predicate(1); +---- +2 + + +query I rowsort +SELECT proc_predicate(2); +---- +1 + + +statement ok + + +statement error +DROP FUNCTION proc_predicate(); + +statement ok diff --git a/src/binder/bind_node_visitor.cpp b/src/binder/bind_node_visitor.cpp index 83f53f6a5a..efa8a7bbd5 100644 --- a/src/binder/bind_node_visitor.cpp +++ b/src/binder/bind_node_visitor.cpp @@ -770,8 +770,8 @@ void BindNodeVisitor::Visit(common::ManagedPointer for (auto i = 0UL; i < expr->GetChildrenSize(); ++i) { auto child = expr->GetChild(i); if (child->GetExpressionType() == parser::ExpressionType::COLUMN_VALUE) { - auto index = child.CastManagedPointerTo()->GetParamIdx(); - if (index >= 0) { + const auto index = child.CastManagedPointerTo()->GetParamIdx(); + if (index > parser::ColumnValueExpression::INVALID_PARAM_INDEX) { // replace with PVE std::unique_ptr pve = std::make_unique(index); pve->SetReturnValueType(child->GetReturnValueType()); diff --git a/src/execution/compiler/udf/udf_codegen.cpp b/src/execution/compiler/udf/udf_codegen.cpp index 64927812ae..b733d77904 100644 --- a/src/execution/compiler/udf/udf_codegen.cpp +++ b/src/execution/compiler/udf/udf_codegen.cpp @@ -393,7 +393,7 @@ void UdfCodegen::Visit(ast::udf::ForSStmtAST *ast) { const auto variable_refs = BindQueryAndGetVariableRefs(ast->Query()); // Optimize the embedded query - auto optimize_result = OptimizeEmbeddedQuery(ast->Query()); + auto optimize_result = OptimizeEmbeddedQuery(ast->Query(), variable_refs); auto plan = optimize_result->GetPlanNode(); // Start construction of the lambda expression @@ -576,7 +576,7 @@ void UdfCodegen::Visit(ast::udf::SQLStmtAST *ast) { const auto variable_refs = BindQueryAndGetVariableRefs(ast->Query()); // Optimize the query and generate get a reference to the plan - auto optimize_result = OptimizeEmbeddedQuery(ast->Query()); + auto optimize_result = OptimizeEmbeddedQuery(ast->Query(), variable_refs); auto plan = optimize_result->GetPlanNode(); // Construct a lambda that writes the output of the query @@ -905,12 +905,23 @@ std::vector UdfCodegen::BindQueryAndGetVariableRefs(pa return visitor.BindAndGetUDFVariableRefs(common::ManagedPointer{query}, common::ManagedPointer{udf_ast_context_}); } -std::unique_ptr UdfCodegen::OptimizeEmbeddedQuery(parser::ParseResult *parsed_query) { +std::unique_ptr UdfCodegen::OptimizeEmbeddedQuery( + parser::ParseResult *parsed_query, const std::vector &variable_refs) { + // For each variable reference, we provide a dummy ConstantValueExpression + std::vector parameters{}; + parameters.reserve(variable_refs.size()); + std::transform(variable_refs.cbegin(), variable_refs.cend(), std::back_inserter(parameters), + [](const parser::udf::VariableRef &v) -> parser::ConstantValueExpression { + return parser::ConstantValueExpression{sql::SqlTypeId::Integer, sql::Integer{0}}; + }); + + // Optimize the query optimizer::StatsStorage stats{}; const std::uint64_t optimizer_timeout = 1000000; return trafficcop::TrafficCopUtil::Optimize( accessor_->GetTxn(), common::ManagedPointer(accessor_), common::ManagedPointer(parsed_query), db_oid_, - common::ManagedPointer(&stats), std::make_unique(), optimizer_timeout, nullptr); + common::ManagedPointer(&stats), std::make_unique(), optimizer_timeout, + common::ManagedPointer{¶meters}); } // Static @@ -933,6 +944,10 @@ ast::Builtin UdfCodegen::AddParamBuiltinForParameterType(sql::SqlTypeId paramete return ast::Builtin::AddParamBigInt; case sql::SqlTypeId::Decimal: return ast::Builtin::AddParamDouble; + case sql::SqlTypeId::Real: + return ast::Builtin::AddParamReal; + case sql::SqlTypeId::Double: + return ast::Builtin::AddParamDouble; case sql::SqlTypeId::Date: return ast::Builtin::AddParamDate; case sql::SqlTypeId::Timestamp: diff --git a/src/execution/sema/sema_expr.cpp b/src/execution/sema/sema_expr.cpp index 7448e90597..66f99453e8 100644 --- a/src/execution/sema/sema_expr.cpp +++ b/src/execution/sema/sema_expr.cpp @@ -301,7 +301,8 @@ void Sema::VisitIdentifierExpr(ast::IdentifierExpr *node) { } void Sema::VisitImplicitCastExpr(ast::ImplicitCastExpr *node) { - throw std::runtime_error("Should never perform semantic checking on implicit cast expressions"); + // TODO(Kyle): Why did we throw here before? + Visit(node->Input()); } void Sema::VisitIndexExpr(ast::IndexExpr *node) { diff --git a/src/include/execution/compiler/udf/udf_codegen.h b/src/include/execution/compiler/udf/udf_codegen.h index 847f57fdcc..c3fe150958 100644 --- a/src/include/execution/compiler/udf/udf_codegen.h +++ b/src/include/execution/compiler/udf/udf_codegen.h @@ -403,9 +403,11 @@ class UdfCodegen : ast::udf::ASTNodeVisitor { /** * Run the optimizer on an embedded SQL query. * @param parsed_query The result of parsing the query + * @param variable_refs The vector of variable references within query * @return The optimized result */ - std::unique_ptr OptimizeEmbeddedQuery(parser::ParseResult *parsed_query); + std::unique_ptr OptimizeEmbeddedQuery( + parser::ParseResult *parsed_query, const std::vector &variable_refs); /** * Determine if the function described by the given metdata is a diff --git a/src/include/parser/expression/column_value_expression.h b/src/include/parser/expression/column_value_expression.h index 67ceb1c124..9429399479 100644 --- a/src/include/parser/expression/column_value_expression.h +++ b/src/include/parser/expression/column_value_expression.h @@ -36,6 +36,9 @@ class ColumnValueExpression : public AbstractExpression { friend class noisepage::TpccPlanTest; public: + /** Denotes an invalid parameter index */ + static constexpr const std::int32_t INVALID_PARAM_INDEX{-1}; + /** * This constructor is called only in postgresparser, setting the column name, * and optionally setting the table name and alias. @@ -146,10 +149,10 @@ class ColumnValueExpression : public AbstractExpression { /** @return column oid */ catalog::col_oid_t GetColumnOid() const { return column_oid_; } - /** @return parameter index */ + /** @return The parameter index */ std::int32_t GetParamIdx() const { return param_idx_; } - /** @brief set the parameter index */ + /** @brief Set the parameter index */ void SetParamIdx(const std::size_t param_idx) { param_idx_ = static_cast(param_idx); } /** @@ -241,7 +244,7 @@ class ColumnValueExpression : public AbstractExpression { catalog::col_oid_t column_oid_ = catalog::INVALID_COLUMN_OID; /** parameter index */ - std::int32_t param_idx_{-1}; + std::int32_t param_idx_{INVALID_PARAM_INDEX}; }; DEFINE_JSON_HEADER_DECLARATIONS(ColumnValueExpression); diff --git a/src/include/parser/udf/variable_ref.h b/src/include/parser/udf/variable_ref.h index bcde6b1451..dcd3434e33 100644 --- a/src/include/parser/udf/variable_ref.h +++ b/src/include/parser/udf/variable_ref.h @@ -56,6 +56,9 @@ class VariableRef { /** @return The index of the variable reference */ std::size_t Index() const { return index_; } + /** @return A string representation of the variable reference */ + std::string ToString() const { return fmt::format("{} {} {}", table_name_.c_str(), column_name_.c_str(), index_); } + private: /** The type of this variable reference */ const VariableRefType type_; From dceb1280567c4d40f28b4f66db32843ddd8f2d7e Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Fri, 24 Sep 2021 07:23:27 -0400 Subject: [PATCH 122/139] working on fixing function call argument type resolution --- script/testing/junit/sql/udf.sql | 2 + src/execution/compiler/udf/udf_codegen.cpp | 151 +++++++++++------- .../execution/compiler/udf/udf_codegen.h | 17 ++ 3 files changed, 113 insertions(+), 57 deletions(-) diff --git a/script/testing/junit/sql/udf.sql b/script/testing/junit/sql/udf.sql index 2661f8598c..62c62be8a9 100644 --- a/script/testing/junit/sql/udf.sql +++ b/script/testing/junit/sql/udf.sql @@ -417,3 +417,5 @@ SELECT proc_predicate(1); SELECT proc_predicate(2); DROP FUNCTION proc_predicate(); + +CREATE FUNCTION foo() RETURNS INT AS $$ DECLARE x INT := 1; y INT := 2; z INT := 3; BEGIN RETURN x * y + z; END $$ LANGUAGE PLPGSQL; \ No newline at end of file diff --git a/src/execution/compiler/udf/udf_codegen.cpp b/src/execution/compiler/udf/udf_codegen.cpp index b733d77904..e4d6b3bb13 100644 --- a/src/execution/compiler/udf/udf_codegen.cpp +++ b/src/execution/compiler/udf/udf_codegen.cpp @@ -110,63 +110,6 @@ void UdfCodegen::Visit(ast::udf::DynamicSQLStmtAST *ast) { throw NOT_IMPLEMENTED_EXCEPTION("UdfCodegen::Visit(DynamicSQLStmtAST*)"); } -void UdfCodegen::Visit(ast::udf::CallExprAST *ast) { - const auto &args = ast->Args(); - - // Evaluate all arguments to call - std::vector arguments{}; - arguments.reserve(ast->Args().size()); - std::transform(args.cbegin(), args.cend(), std::back_inserter(arguments), - [this](const std::unique_ptr &expr) { return EvaluateExpression(expr.get()); }); - - NOISEPAGE_ASSERT(std::all_of(arguments.cbegin(), arguments.cend(), - [](const ast::Expr *arg) { - auto *builtin = arg->GetType()->SafeAs(); - return builtin != nullptr && builtin->IsSqlValueType(); - }), - "Invalid argument type in function call"); - - // Get argument types - std::vector argument_types{}; - std::transform(arguments.cbegin(), arguments.cend(), std::back_inserter(argument_types), - [this](const ast::Expr *expr) { - return GetCatalogTypeOidFromSQLType(expr->GetType()->SafeAs()->GetKind()); - }); - - const auto proc_oid = accessor_->GetProcOid(ast->Callee(), argument_types); - if (proc_oid == catalog::INVALID_PROC_OID) { - throw BINDER_EXCEPTION(fmt::format("Invalid function call '{}'", ast->Callee()), - common::ErrorCode::ERRCODE_PLPGSQL_ERROR); - } - - auto context = accessor_->GetFunctionContext(proc_oid); - if (context->IsBuiltin()) { - ast::Expr *result = codegen_->CallBuiltin(context->GetBuiltin(), arguments); - SetExecutionResult(result); - } else { - // NOTE(Kyle): This is an unfortunate operation because it - // requires shifting all elements in the vector, but we - // don't typically see functions with super-high arity - arguments.insert(arguments.begin(), GetExecutionContext()); - auto it = SymbolTable().find(ast->Callee()); - ast::Identifier ident_expr; - if (it != SymbolTable().end()) { - ident_expr = it->second; - } else { - auto file = reinterpret_cast( - ast::AstClone::Clone(context->GetFile(), codegen_->GetAstContext()->GetNodeFactory(), - context->GetASTContext(), codegen_->GetAstContext().Get())); - for (auto decl : file->Declarations()) { - aux_decls_.push_back(decl); - } - ident_expr = codegen_->MakeFreshIdentifier(file->Declarations().back()->Name().GetString()); - SymbolTable()[file->Declarations().back()->Name().GetString()] = ident_expr; - } - ast::Expr *result = codegen_->Call(ident_expr, arguments); - SetExecutionResult(result); - } -} - void UdfCodegen::Visit(ast::udf::DeclStmtAST *ast) { if (ast->Name() == INTERNAL_DECL_ID) { return; @@ -374,6 +317,100 @@ void UdfCodegen::Visit(ast::udf::MemberExprAST *ast) { SetExecutionResult(access); } +/* ---------------------------------------------------------------------------- + Code Generation: Function Calls +---------------------------------------------------------------------------- */ + +void UdfCodegen::Visit(ast::udf::CallExprAST *ast) { + const auto &args = ast->Args(); + + // Evaluate all arguments to call + std::vector arguments{}; + arguments.reserve(ast->Args().size()); + std::transform(args.cbegin(), args.cend(), std::back_inserter(arguments), + [this](const std::unique_ptr &expr) { return EvaluateExpression(expr.get()); }); + + // Each argument must be one of: + // - A full-evaluated expression + // - An identifier expression + + NOISEPAGE_ASSERT(std::all_of(arguments.cbegin(), arguments.cend(), + [](const ast::Expr *arg) { + return CallArgumentIsValid(arg); + }), + "Invalid argument type in function call"); + + // Get argument types + std::vector argument_types{}; + std::transform(arguments.cbegin(), arguments.cend(), std::back_inserter(argument_types), + [this](const ast::Expr *expr) { + return GetCatalogTypeOidFromSQLType(expr->GetType()->SafeAs()->GetKind()); + }); + + const auto proc_oid = accessor_->GetProcOid(ast->Callee(), argument_types); + if (proc_oid == catalog::INVALID_PROC_OID) { + throw BINDER_EXCEPTION(fmt::format("Invalid function call '{}'", ast->Callee()), + common::ErrorCode::ERRCODE_PLPGSQL_ERROR); + } + + auto context = accessor_->GetFunctionContext(proc_oid); + if (context->IsBuiltin()) { + ast::Expr *result = codegen_->CallBuiltin(context->GetBuiltin(), arguments); + SetExecutionResult(result); + } else { + // NOTE(Kyle): This is an unfortunate operation because it + // requires shifting all elements in the vector, but we + // don't typically see functions with super-high arity + arguments.insert(arguments.begin(), GetExecutionContext()); + auto it = SymbolTable().find(ast->Callee()); + ast::Identifier ident_expr; + if (it != SymbolTable().end()) { + ident_expr = it->second; + } else { + auto file = reinterpret_cast( + ast::AstClone::Clone(context->GetFile(), codegen_->GetAstContext()->GetNodeFactory(), + context->GetASTContext(), codegen_->GetAstContext().Get())); + for (auto decl : file->Declarations()) { + aux_decls_.push_back(decl); + } + ident_expr = codegen_->MakeFreshIdentifier(file->Declarations().back()->Name().GetString()); + SymbolTable()[file->Declarations().back()->Name().GetString()] = ident_expr; + } + ast::Expr *result = codegen_->Call(ident_expr, arguments); + SetExecutionResult(result); + } +} + +ast::Type* UdfCodegen::ResolveType(const ast::Expr* expr) const { + switch (expr->GetKind()) { + case ast::AstNode::Kind::LitExpr: + return ResolveTypeForLiteralExpression(expr); + case ast::AstNode::Kind::BinaryOpExpr: { + return ResolveTypeForBinaryExpression(expr); + case ast::AstNode::Kind::IdentifierExpr: + return ResolveTypeForIdentifierExpression(expr); + default: + UNREACHABLE("Function call argument type cannot be resolved"); + } +} + +ast::Type* ResolveTypeForLiteralExpression(const ast::Expr* expr) const { + NOISEPAGE_ASSERT(expr->IsLitExpr(), "Broken precondition."); + return expr->GetType(); +} + +ast::Type* ResolveTypeForBinaryExpression(const ast::Expr* expr) const { + NOISEPAGE_ASSERT(expr->IsBinaryOpEx(), "Broken precondition"); + const auto* binary = expr->SafeAs(); + const ast::Type* left = ResolveType(binary->Left()); + const ast::Type* right = ResolveType(binary->Right()); +} + +ast::Type* ResolveTypeForIdentifierExpression(const ast::Expr* expr) const { + NOISEPAGE_ASSERT(expr->IsIdentifierExpr(), "Broken precondition."); + const Identifier name = expr->GetName(); +} + /* ---------------------------------------------------------------------------- Code Generation: Integer-Variant For-Loops ---------------------------------------------------------------------------- */ diff --git a/src/include/execution/compiler/udf/udf_codegen.h b/src/include/execution/compiler/udf/udf_codegen.h index c3fe150958..4cc11ebf8f 100644 --- a/src/include/execution/compiler/udf/udf_codegen.h +++ b/src/include/execution/compiler/udf/udf_codegen.h @@ -237,6 +237,23 @@ class UdfCodegen : ast::udf::ASTNodeVisitor { static const char *GetReturnParamString(); private: + /* -------------------------------------------------------------------------- + Code Generation: Function Calls + -------------------------------------------------------------------------- */ + + /** + * Resolve the type of an expression. + * @param expr The expression + * @return The resolved type + */ + ast::Type* ResolveType(const ast::Expr* expr) const; + + ast::Type* ResolveTypeForLiteralExpression(const ast::Expr* expr) const; + + ast::Type* ResolveTypeForBinaryExpression(const ast::Expr* expr) const; + + ast::Type* ResolveTypeForIdentifierExpression(const ast::Expr* expr) const; + /* -------------------------------------------------------------------------- Code Generation: For-S Loops -------------------------------------------------------------------------- */ From c9181e1c175cf456eaaf1c90160a8d4beb8ac3c1 Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Fri, 24 Sep 2021 07:35:13 -0400 Subject: [PATCH 123/139] wip --- src/execution/compiler/udf/udf_codegen.cpp | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/execution/compiler/udf/udf_codegen.cpp b/src/execution/compiler/udf/udf_codegen.cpp index e4d6b3bb13..a5493dc2d6 100644 --- a/src/execution/compiler/udf/udf_codegen.cpp +++ b/src/execution/compiler/udf/udf_codegen.cpp @@ -404,11 +404,16 @@ ast::Type* ResolveTypeForBinaryExpression(const ast::Expr* expr) const { const auto* binary = expr->SafeAs(); const ast::Type* left = ResolveType(binary->Left()); const ast::Type* right = ResolveType(binary->Right()); + switch (binary->Op()) { + default: + break; + } + UNREACHABLE("Binary operation not supported"); } ast::Type* ResolveTypeForIdentifierExpression(const ast::Expr* expr) const { NOISEPAGE_ASSERT(expr->IsIdentifierExpr(), "Broken precondition."); - const Identifier name = expr->GetName(); + return GetVariableType(expr->GetName().GetString()); } /* ---------------------------------------------------------------------------- From b4ea2c0c0b9a78fa027bfbe974708855ae9f4eb2 Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Fri, 24 Sep 2021 10:47:31 -0400 Subject: [PATCH 124/139] type resolution for call expressions, almost there --- Makefile | 3 + script/testing/junit/sql/udf.sql | 37 +++++- script/testing/junit/traces/udf.test | 61 +++++++++- script/testing/util/db_server.py | 1 - src/execution/compiler/udf/udf_codegen.cpp | 115 ++++++++++-------- .../execution/compiler/udf/udf_codegen.h | 32 ++++- src/network/noisepage_server.cpp | 3 +- 7 files changed, 191 insertions(+), 61 deletions(-) diff --git a/Makefile b/Makefile index 10d071272d..145e25be92 100644 --- a/Makefile +++ b/Makefile @@ -12,6 +12,9 @@ check-regress-interpreted: check-regress-compiled: PYTHONPATH=.. python -m script.testing.junit --build-type=debug --query-mode=simple -a 'compiled_query_execution=True' -a 'bytecode_handlers_path=./bytecode_handlers_ir.bc' +check-regress-udf: + cd build && PYTHONPATH=.. python -m script.testing.junit --build-type=debug --query-mode=simple --tracefile-test=udf.test + # Re-generate the trace file for UDF regression tests generate-regress-udf: cd script/testing/junit/ && ant generate-trace -Dpath=sql/udf.sql -Ddb-url=jdbc:postgresql://localhost/test -Ddb-user=postgres -Ddb-password=password -Doutput-name=udf.test diff --git a/script/testing/junit/sql/udf.sql b/script/testing/junit/sql/udf.sql index 62c62be8a9..e00af9b50a 100644 --- a/script/testing/junit/sql/udf.sql +++ b/script/testing/junit/sql/udf.sql @@ -416,6 +416,39 @@ SELECT proc_predicate(0); SELECT proc_predicate(1); SELECT proc_predicate(2); -DROP FUNCTION proc_predicate(); +DROP FUNCTION proc_predicate(INT); -CREATE FUNCTION foo() RETURNS INT AS $$ DECLARE x INT := 1; y INT := 2; z INT := 3; BEGIN RETURN x * y + z; END $$ LANGUAGE PLPGSQL; \ No newline at end of file +-- ---------------------------------------------------------------------------- +-- proc_call_args() + +-- Argument to call can be an expression +CREATE FUNCTION proc_call_args() RETURNS INT AS $$ \ +DECLARE \ + x INT := 1; \ + y INT := 2; \ + z INT := 3; \ +BEGIN \ + RETURN ABS(x * y + z); \ +END \ +$$ LANGUAGE PLPGSQL; + +SELECT proc_call_args(); + +DROP FUNCTION proc_call_args(); + +-- Argument to call can be an identifier +CREATE FUNCTION proc_call_args() RETURNS INT AS $$ \ +DECLARE \ + x INT := 1; \ + y INT := 2; \ + z INT := 3; \ + r INT; \ +BEGIN \ + r = x * y + z; \ + RETURN ABS(r); \ +END \ +$$ LANGUAGE PLPGSQL; + +SELECT proc_call_args(); + +DROP FUNCTION proc_call_args(); diff --git a/script/testing/junit/traces/udf.test b/script/testing/junit/traces/udf.test index 9587b8c35d..8429dc5467 100644 --- a/script/testing/junit/traces/udf.test +++ b/script/testing/junit/traces/udf.test @@ -815,7 +815,64 @@ SELECT proc_predicate(2); statement ok -statement error -DROP FUNCTION proc_predicate(); +statement ok +DROP FUNCTION proc_predicate(INT); + +statement ok + + +statement ok +-- ---------------------------------------------------------------------------- + +statement ok +-- proc_call_args() + +statement ok + + +statement ok +-- Argument to call can be an expression + +statement ok +CREATE FUNCTION proc_call_args() RETURNS INT AS $$ DECLARE x INT := 1; y INT := 2; z INT := 3; BEGIN RETURN ABS(x * y + z); END $$ LANGUAGE PLPGSQL; + +statement ok + + +query I rowsort +SELECT proc_call_args(); +---- +5 + + +statement ok + + +statement ok +DROP FUNCTION proc_call_args(); + +statement ok + + +statement ok +-- Argument to call can be an identifier + +statement ok +CREATE FUNCTION proc_call_args() RETURNS INT AS $$ DECLARE x INT := 1; y INT := 2; z INT := 3; r INT; BEGIN r = x * y + z; RETURN ABS(r); END $$ LANGUAGE PLPGSQL; + +statement ok + + +query I rowsort +SELECT proc_call_args(); +---- +5 + + +statement ok + + +statement ok +DROP FUNCTION proc_call_args(); statement ok diff --git a/script/testing/util/db_server.py b/script/testing/util/db_server.py index 0e3809c9f9..88538a600a 100644 --- a/script/testing/util/db_server.py +++ b/script/testing/util/db_server.py @@ -164,7 +164,6 @@ def stop_db(self, is_dry_run=False): finally: unix_socket = os.path.join("/tmp/", f".s.PGSQL.{self.db_port}") if os.path.exists(unix_socket): - os.remove(unix_socket) LOG.info(f"Removing: {unix_socket}") self.print_db_logs() exit_code = self.db_process.returncode diff --git a/src/execution/compiler/udf/udf_codegen.cpp b/src/execution/compiler/udf/udf_codegen.cpp index a5493dc2d6..d8235811fd 100644 --- a/src/execution/compiler/udf/udf_codegen.cpp +++ b/src/execution/compiler/udf/udf_codegen.cpp @@ -13,6 +13,7 @@ #include "execution/compiler/if.h" #include "execution/compiler/loop.h" #include "execution/exec/execution_settings.h" +#include "execution/parsing/token.h" #include "execution/vm/bytecode_function_info.h" #include "optimizer/cost_model/trivial_cost_model.h" #include "optimizer/statistics/stats_storage.h" @@ -57,28 +58,32 @@ const char *UdfCodegen::GetReturnParamString() { return "return_val"; } void UdfCodegen::GenerateUDF(ast::udf::AbstractAST *ast) { ast->Accept(this); } -catalog::type_oid_t UdfCodegen::GetCatalogTypeOidFromSQLType(ast::BuiltinType::Kind type) { +catalog::type_oid_t UdfCodegen::GetCatalogTypeOidFromSQLType(sql::SqlTypeId type) { + return accessor_->GetTypeOidFromTypeId(type); +} + +catalog::type_oid_t UdfCodegen::GetCatalogTypeFromBuiltinKind(ast::BuiltinType::Kind type) { switch (type) { case ast::BuiltinType::Kind::Boolean: { - return accessor_->GetTypeOidFromTypeId(sql::SqlTypeId::Boolean); + return GetCatalogTypeOidFromSQLType(sql::SqlTypeId::Boolean); } case ast::BuiltinType::Kind::Integer: { - return accessor_->GetTypeOidFromTypeId(sql::SqlTypeId::Integer); + return GetCatalogTypeOidFromSQLType(sql::SqlTypeId::Integer); } case ast::BuiltinType::Kind::Real: { - return accessor_->GetTypeOidFromTypeId(sql::SqlTypeId::Real); + return GetCatalogTypeOidFromSQLType(sql::SqlTypeId::Real); } case ast::BuiltinType::Kind::Decimal: { - return accessor_->GetTypeOidFromTypeId(sql::SqlTypeId::Decimal); + return GetCatalogTypeOidFromSQLType(sql::SqlTypeId::Decimal); } case ast::BuiltinType::Kind::StringVal: { - return accessor_->GetTypeOidFromTypeId(sql::SqlTypeId::Varchar); + return GetCatalogTypeOidFromSQLType(sql::SqlTypeId::Varchar); } case ast::BuiltinType::Kind::Date: { - return accessor_->GetTypeOidFromTypeId(sql::SqlTypeId::Date); + return GetCatalogTypeOidFromSQLType(sql::SqlTypeId::Date); } case ast::BuiltinType::Kind::Timestamp: { - return accessor_->GetTypeOidFromTypeId(sql::SqlTypeId::Timestamp); + return GetCatalogTypeOidFromSQLType(sql::SqlTypeId::Timestamp); } default: NOISEPAGE_ASSERT(false, "Invalid SQL type in function call"); @@ -324,28 +329,17 @@ void UdfCodegen::Visit(ast::udf::MemberExprAST *ast) { void UdfCodegen::Visit(ast::udf::CallExprAST *ast) { const auto &args = ast->Args(); - // Evaluate all arguments to call + // Generate code to evaluate call arguments std::vector arguments{}; arguments.reserve(ast->Args().size()); std::transform(args.cbegin(), args.cend(), std::back_inserter(arguments), [this](const std::unique_ptr &expr) { return EvaluateExpression(expr.get()); }); - // Each argument must be one of: - // - A full-evaluated expression - // - An identifier expression - - NOISEPAGE_ASSERT(std::all_of(arguments.cbegin(), arguments.cend(), - [](const ast::Expr *arg) { - return CallArgumentIsValid(arg); - }), - "Invalid argument type in function call"); - - // Get argument types std::vector argument_types{}; - std::transform(arguments.cbegin(), arguments.cend(), std::back_inserter(argument_types), - [this](const ast::Expr *expr) { - return GetCatalogTypeOidFromSQLType(expr->GetType()->SafeAs()->GetKind()); - }); + argument_types.reserve(arguments.size()); + std::transform( + arguments.cbegin(), arguments.cend(), std::back_inserter(argument_types), + [this](const ast::Expr *expr) -> catalog::type_oid_t { return GetCatalogTypeOidFromSQLType(ResolveType(expr)); }); const auto proc_oid = accessor_->GetProcOid(ast->Callee(), argument_types); if (proc_oid == catalog::INVALID_PROC_OID) { @@ -381,39 +375,62 @@ void UdfCodegen::Visit(ast::udf::CallExprAST *ast) { } } -ast::Type* UdfCodegen::ResolveType(const ast::Expr* expr) const { +sql::SqlTypeId UdfCodegen::ResolveType(const ast::Expr *expr) const { switch (expr->GetKind()) { - case ast::AstNode::Kind::LitExpr: - return ResolveTypeForLiteralExpression(expr); - case ast::AstNode::Kind::BinaryOpExpr: { - return ResolveTypeForBinaryExpression(expr); - case ast::AstNode::Kind::IdentifierExpr: - return ResolveTypeForIdentifierExpression(expr); - default: - UNREACHABLE("Function call argument type cannot be resolved"); + case ast::AstNode::Kind::LitExpr: + return ResolveTypeForLiteralExpression(expr->SafeAs()); + case ast::AstNode::Kind::BinaryOpExpr: + return ResolveTypeForBinaryExpression(expr->SafeAs()); + case ast::AstNode::Kind::IdentifierExpr: + return ResolveTypeForIdentifierExpression(expr->SafeAs()); + default: + UNREACHABLE("Function call argument type cannot be resolved"); } } -ast::Type* ResolveTypeForLiteralExpression(const ast::Expr* expr) const { +sql::SqlTypeId UdfCodegen::ResolveTypeForLiteralExpression(const ast::LitExpr *expr) const { NOISEPAGE_ASSERT(expr->IsLitExpr(), "Broken precondition."); - return expr->GetType(); -} - -ast::Type* ResolveTypeForBinaryExpression(const ast::Expr* expr) const { - NOISEPAGE_ASSERT(expr->IsBinaryOpEx(), "Broken precondition"); - const auto* binary = expr->SafeAs(); - const ast::Type* left = ResolveType(binary->Left()); - const ast::Type* right = ResolveType(binary->Right()); - switch (binary->Op()) { - default: - break; - } - UNREACHABLE("Binary operation not supported"); + // TODO(Kyle): What to do about the ambiguity here? + // e.g. a literal might be a float vs double + switch (expr->GetLiteralKind()) { + case ast::LitExpr::LitKind::Boolean: + return sql::SqlTypeId::Boolean; + case ast::LitExpr::LitKind::Float: + return sql::SqlTypeId::Double; + case ast::LitExpr::LitKind::Int: + return sql::SqlTypeId::Integer; + case ast::LitExpr::LitKind::String: + return sql::SqlTypeId::Varchar; + default: + UNREACHABLE("Invalid type"); + } +} + +sql::SqlTypeId UdfCodegen::ResolveTypeForBinaryExpression(const ast::BinaryOpExpr *expr) const { + NOISEPAGE_ASSERT(expr->IsBinaryOpExpr(), "Broken precondition"); + const auto *binary = expr->SafeAs(); + sql::SqlTypeId left = ResolveType(binary->Left()); + sql::SqlTypeId right = ResolveType(binary->Right()); + switch (binary->Op()) { + // Basic arithmetic operators + case parsing::Token::Type::PLUS: + case parsing::Token::Type::MINUS: + case parsing::Token::Type::STAR: + case parsing::Token::Type::SLASH: + if (left == right) { + return left; + } + UNREACHABLE("Implicit conversions not supported"); + default: + break; + } + UNREACHABLE("Binary operation not supported"); } -ast::Type* ResolveTypeForIdentifierExpression(const ast::Expr* expr) const { +sql::SqlTypeId UdfCodegen::ResolveTypeForIdentifierExpression(const ast::IdentifierExpr *expr) const { NOISEPAGE_ASSERT(expr->IsIdentifierExpr(), "Broken precondition."); - return GetVariableType(expr->GetName().GetString()); + // Just lookup the type for the variable with which it was declared + return GetVariableType(expr->Name().GetString()); } /* ---------------------------------------------------------------------------- diff --git a/src/include/execution/compiler/udf/udf_codegen.h b/src/include/execution/compiler/udf/udf_codegen.h index 4cc11ebf8f..3d6bdf199d 100644 --- a/src/include/execution/compiler/udf/udf_codegen.h +++ b/src/include/execution/compiler/udf/udf_codegen.h @@ -246,13 +246,28 @@ class UdfCodegen : ast::udf::ASTNodeVisitor { * @param expr The expression * @return The resolved type */ - ast::Type* ResolveType(const ast::Expr* expr) const; + sql::SqlTypeId ResolveType(const ast::Expr *expr) const; - ast::Type* ResolveTypeForLiteralExpression(const ast::Expr* expr) const; + /** + * Resolve the type of a literal expression in a function call argument. + * @param expr The literal expression + * @return The resolved type of the literal expression + */ + sql::SqlTypeId ResolveTypeForLiteralExpression(const ast::LitExpr *expr) const; - ast::Type* ResolveTypeForBinaryExpression(const ast::Expr* expr) const; + /** + * Resolve the type of a binary expression in a function call argument. + * @param expr The binary expression + * @return The resolved type of the binary expression + */ + sql::SqlTypeId ResolveTypeForBinaryExpression(const ast::BinaryOpExpr *expr) const; - ast::Type* ResolveTypeForIdentifierExpression(const ast::Expr* expr) const; + /** + * Resolve the type of an identifier expression in a function call argument. + * @param expr The identifier expression + * @return The resolved type of the identifier expression + */ + sql::SqlTypeId ResolveTypeForIdentifierExpression(const ast::IdentifierExpr *expr) const; /* -------------------------------------------------------------------------- Code Generation: For-S Loops @@ -386,7 +401,14 @@ class UdfCodegen : ast::udf::ASTNodeVisitor { * @param type The SQL type of interest * @return The corresponding catalog type */ - catalog::type_oid_t GetCatalogTypeOidFromSQLType(execution::ast::BuiltinType::Kind type); + catalog::type_oid_t GetCatalogTypeOidFromSQLType(execution::sql::SqlTypeId type); + + /** + * Translate a builtin type Kind to its corresponding catalog type. + * @param type The builtin type of interst + * @return The corresponding catalog type + */ + catalog::type_oid_t GetCatalogTypeFromBuiltinKind(execution::ast::BuiltinType::Kind type); /** @return A mutable reference to the symbol table */ std::unordered_map &SymbolTable() { return symbol_table_; } diff --git a/src/network/noisepage_server.cpp b/src/network/noisepage_server.cpp index de83c1dbd1..b99d456a00 100644 --- a/src/network/noisepage_server.cpp +++ b/src/network/noisepage_server.cpp @@ -139,8 +139,7 @@ void TerrierServer::RunServer() { // Register the network socket. RegisterSocket(); - // Register the Unix domain socket. - RegisterSocket(); + // TODO(Kyle): Removed UNIX domain socket. // Register the ConnectionDispatcherTask. This handles connections to the sockets created above. dispatcher_task_ = thread_registry_->RegisterDedicatedThread( From f65453abe5f434ad05b4bbe2584ee5889fb01108 Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Fri, 24 Sep 2021 16:44:14 -0400 Subject: [PATCH 125/139] updates for procedures, promotion --- script/testing/junit/sql/udf.sql | 103 ++++++++++ script/testing/junit/traces/udf.test | 177 ++++++++++++++++++ src/catalog/catalog_accessor.cpp | 5 +- src/execution/compiler/udf/udf_codegen.cpp | 19 +- src/execution/sql/sql.cpp | 3 + .../parser/create_function_statement.h | 7 +- src/parser/udf/plpgsql_parser.cpp | 9 +- 7 files changed, 317 insertions(+), 6 deletions(-) diff --git a/script/testing/junit/sql/udf.sql b/script/testing/junit/sql/udf.sql index e00af9b50a..f0210ea664 100644 --- a/script/testing/junit/sql/udf.sql +++ b/script/testing/junit/sql/udf.sql @@ -452,3 +452,106 @@ $$ LANGUAGE PLPGSQL; SELECT proc_call_args(); DROP FUNCTION proc_call_args(); + +-- ---------------------------------------------------------------------------- +-- proc_promotion() + +-- Able to (silently) promote REAL to DOUBLE PRECISION +CREATE FUNCTION proc_promotion() RETURNS REAL AS $$ \ +DECLARE \ + x INT := 1; \ + y REAL := 1.0; \ + t REAL; \ +BEGIN \ + t = x * y; \ + RETURN FLOOR(t); \ +END \ +$$ LANGUAGE PLPGSQL; + +SELECT proc_promotion(); +DROP FUNCTION proc_promotion(); + +-- Able to (silently) promote FLOAT to DOUBLE PRECISION +CREATE FUNCTION proc_promotion() RETURNS FLOAT AS $$ \ +DECLARE \ + x INT := 1; \ + y FLOAT := 1.0; \ + t FLOAT; \ +BEGIN \ + t = x * y; \ + RETURN FLOOR(t); \ +END \ +$$ LANGUAGE PLPGSQL; + +SELECT proc_promotion(); +DROP FUNCTION proc_promotion(); + +-- Promotion does not affect correct operation of DOUBLE PRECISION +CREATE FUNCTION proc_promotion() RETURNS DOUBLE PRECISION AS $$ \ +DECLARE \ + x INT := 1; \ + y DOUBLE PRECISION := 1.0; \ + t DOUBLE PRECISION; \ +BEGIN \ + t = x * y; \ + RETURN FLOOR(t); \ +END \ +$$ LANGUAGE PLPGSQL; + +SELECT proc_promotion(); +DROP FUNCTION proc_promotion(); + +-- Promotion does not affect correct operation of FLOAT8 +CREATE FUNCTION proc_promotion() RETURNS DOUBLE PRECISION AS $$ \ +DECLARE \ + x INT := 1; \ + y DOUBLE PRECISION := 1.0; \ + t DOUBLE PRECISION; \ +BEGIN \ + t = x * y; \ + RETURN FLOOR(t); \ +END \ +$$ LANGUAGE PLPGSQL; + +SELECT proc_promotion(); +DROP FUNCTION proc_promotion(); + +-- Promotion works as expected with UDF arguments +CREATE FUNCTION proc_promotion(x FLOAT) RETURNS FLOAT AS $$ \ +BEGIN \ + RETURN x; \ +END \ +$$ LANGUAGE PLPGSQL; + +SELECT proc_promotion(1337.0); +DROP FUNCTION proc_promotion(FLOAT); + +-- Promotion works as expected with UDF arguments +CREATE FUNCTION proc_promotion(x REAL) RETURNS REAL AS $$ \ +BEGIN \ + RETURN x; \ +END \ +$$ LANGUAGE PLPGSQL; + +SELECT proc_promotion(1337.0); +DROP FUNCTION proc_promotion(REAL); + +-- Promotion works as expected with UDF arguments +CREATE FUNCTION proc_promotion(x DOUBLE PRECISION) RETURNS DOUBLE PRECISION AS $$ \ +BEGIN \ + RETURN x; \ +END \ +$$ LANGUAGE PLPGSQL; + +SELECT proc_promotion(1337.0); +DROP FUNCTION proc_promotion(DOUBLE PRECISION); + +-- Promotion works as expected with UDF arguments +CREATE FUNCTION proc_promotion(x FLOAT8) RETURNS FLOAT8 AS $$ \ +BEGIN \ + RETURN x; \ +END \ +$$ LANGUAGE PLPGSQL; + +SELECT proc_promotion(1337.0); +DROP FUNCTION proc_promotion(FLOAT8); diff --git a/script/testing/junit/traces/udf.test b/script/testing/junit/traces/udf.test index 8429dc5467..27454163e4 100644 --- a/script/testing/junit/traces/udf.test +++ b/script/testing/junit/traces/udf.test @@ -876,3 +876,180 @@ statement ok DROP FUNCTION proc_call_args(); statement ok + + +statement ok +-- ---------------------------------------------------------------------------- + +statement ok +-- proc_promotion() + +statement ok + + +statement ok +-- Able to (silently) promote REAL to DOUBLE PRECISION + +statement ok +CREATE FUNCTION proc_promotion() RETURNS REAL AS $$ DECLARE x INT := 1; y REAL := 1.0; t REAL; BEGIN t = x * y; RETURN FLOOR(t); END $$ LANGUAGE PLPGSQL; + +statement ok + + +query rowsort +SELECT proc_promotion(); +---- +1 + + +statement ok +DROP FUNCTION proc_promotion(); + +statement ok + + +statement ok +-- Able to (silently) promote FLOAT to DOUBLE PRECISION + +statement ok +CREATE FUNCTION proc_promotion() RETURNS FLOAT AS $$ DECLARE x INT := 1; y FLOAT := 1.0; t FLOAT; BEGIN t = x * y; RETURN FLOOR(t); END $$ LANGUAGE PLPGSQL; + +statement ok + + +query R rowsort +SELECT proc_promotion(); +---- +1 + + +statement ok +DROP FUNCTION proc_promotion(); + +statement ok + + +statement ok +-- Promotion does not affect correct operation of DOUBLE PRECISION + +statement ok +CREATE FUNCTION proc_promotion() RETURNS DOUBLE PRECISION AS $$ DECLARE x INT := 1; y DOUBLE PRECISION := 1.0; t DOUBLE PRECISION; BEGIN t = x * y; RETURN FLOOR(t); END $$ LANGUAGE PLPGSQL; + +statement ok + + +query R rowsort +SELECT proc_promotion(); +---- +1 + + +statement ok +DROP FUNCTION proc_promotion(); + +statement ok + + +statement ok +-- Promotion does not affect correct operation of FLOAT8 + +statement ok +CREATE FUNCTION proc_promotion() RETURNS DOUBLE PRECISION AS $$ DECLARE x INT := 1; y DOUBLE PRECISION := 1.0; t DOUBLE PRECISION; BEGIN t = x * y; RETURN FLOOR(t); END $$ LANGUAGE PLPGSQL; + +statement ok + + +query R rowsort +SELECT proc_promotion(); +---- +1 + + +statement ok +DROP FUNCTION proc_promotion(); + +statement ok + + +statement ok +-- Promotion works as expected with UDF arguments + +statement ok +CREATE FUNCTION proc_promotion(x FLOAT) RETURNS FLOAT AS $$ BEGIN RETURN x; END $$ LANGUAGE PLPGSQL; + +statement ok + + +query R rowsort +SELECT proc_promotion(1337.0); +---- +1337 + + +statement ok +DROP FUNCTION proc_promotion(FLOAT); + +statement ok + + +statement ok +-- Promotion works as expected with UDF arguments + +statement ok +CREATE FUNCTION proc_promotion(x REAL) RETURNS REAL AS $$ BEGIN RETURN x; END $$ LANGUAGE PLPGSQL; + +statement ok + + +query rowsort +SELECT proc_promotion(1337.0); +---- +1337 + + +statement ok +DROP FUNCTION proc_promotion(REAL); + +statement ok + + +statement ok +-- Promotion works as expected with UDF arguments + +statement ok +CREATE FUNCTION proc_promotion(x DOUBLE PRECISION) RETURNS DOUBLE PRECISION AS $$ BEGIN RETURN x; END $$ LANGUAGE PLPGSQL; + +statement ok + + +query R rowsort +SELECT proc_promotion(1337.0); +---- +1337 + + +statement ok +DROP FUNCTION proc_promotion(DOUBLE PRECISION); + +statement ok + + +statement ok +-- Promotion works as expected with UDF arguments + +statement ok +CREATE FUNCTION proc_promotion(x FLOAT8) RETURNS FLOAT8 AS $$ BEGIN RETURN x; END $$ LANGUAGE PLPGSQL; + +statement ok + + +query R rowsort +SELECT proc_promotion(1337.0); +---- +1337 + + +statement ok +DROP FUNCTION proc_promotion(FLOAT8); + +statement ok diff --git a/src/catalog/catalog_accessor.cpp b/src/catalog/catalog_accessor.cpp index 855197f2af..92191beb79 100644 --- a/src/catalog/catalog_accessor.cpp +++ b/src/catalog/catalog_accessor.cpp @@ -266,7 +266,10 @@ type_oid_t CatalogAccessor::TypeNameToType(const std::string &type_name) { } else if (type_name == "bool") { type = GetTypeOidFromTypeId(execution::sql::SqlTypeId::Boolean); } else if (type_name == "float4") { - type = GetTypeOidFromTypeId(execution::sql::SqlTypeId::Real); + // NOTE(Kyle): The "regular" SQL frontend always promotes + // FLOAT / REAL to DOUBLE PRECISION / FLOAT8, so we do the + // same here to remain consistent + type = GetTypeOidFromTypeId(execution::sql::SqlTypeId::Double); } else if (type_name == "float8") { type = GetTypeOidFromTypeId(execution::sql::SqlTypeId::Double); } else if (type_name == "numeric") { diff --git a/src/execution/compiler/udf/udf_codegen.cpp b/src/execution/compiler/udf/udf_codegen.cpp index d8235811fd..b2b1063595 100644 --- a/src/execution/compiler/udf/udf_codegen.cpp +++ b/src/execution/compiler/udf/udf_codegen.cpp @@ -406,6 +406,17 @@ sql::SqlTypeId UdfCodegen::ResolveTypeForLiteralExpression(const ast::LitExpr *e } } +/** @return `true` if the given type is an integral type */ +static bool IsIntegral(sql::SqlTypeId type) { + return type == sql::SqlTypeId::TinyInt || type == sql::SqlTypeId::SmallInt || type == sql::SqlTypeId::Integer || + type == sql::SqlTypeId::BigInt; +} + +/** @return `true` if the given type is a floating-point type */ +static bool IsFloatingPoint(sql::SqlTypeId type) { + return type == sql::SqlTypeId::Real || type == sql::SqlTypeId::Double; +} + sql::SqlTypeId UdfCodegen::ResolveTypeForBinaryExpression(const ast::BinaryOpExpr *expr) const { NOISEPAGE_ASSERT(expr->IsBinaryOpExpr(), "Broken precondition"); const auto *binary = expr->SafeAs(); @@ -420,7 +431,13 @@ sql::SqlTypeId UdfCodegen::ResolveTypeForBinaryExpression(const ast::BinaryOpExp if (left == right) { return left; } - UNREACHABLE("Implicit conversions not supported"); + if (IsFloatingPoint(left) && IsIntegral(right)) { + return left; + } + if (IsIntegral(left) && IsFloatingPoint(right)) { + return right; + } + UNREACHABLE("Unsupported types for arithmetic operations"); default: break; } diff --git a/src/execution/sql/sql.cpp b/src/execution/sql/sql.cpp index 96dac8c124..7637bfd1a8 100644 --- a/src/execution/sql/sql.cpp +++ b/src/execution/sql/sql.cpp @@ -350,6 +350,9 @@ TypeId GetTypeId(SqlTypeId frontend_type) { case SqlTypeId::BigInt: execution_type_id = execution::sql::TypeId::BigInt; break; + case SqlTypeId::Real: + execution_type_id = execution::sql::TypeId::Float; + break; case SqlTypeId::Double: execution_type_id = execution::sql::TypeId::Double; break; diff --git a/src/include/parser/create_function_statement.h b/src/include/parser/create_function_statement.h index bef5b8e228..be620b71c9 100644 --- a/src/include/parser/create_function_statement.h +++ b/src/include/parser/create_function_statement.h @@ -53,9 +53,12 @@ struct BaseFunctionParameter { return execution::sql::SqlTypeId::BigInt; case DataType::CHAR: return execution::sql::SqlTypeId::Invalid; - case DataType::DOUBLE: - return execution::sql::SqlTypeId::Double; case DataType::FLOAT: + // NOTE(Kyle): The "regular" SQL frontend automatically + // promotes FLOAT / REAL to DOUBLE PRECISION / FLOAT8; + // we do the same here to remain consistent + return execution::sql::SqlTypeId::Double; + case DataType::DOUBLE: return execution::sql::SqlTypeId::Double; case DataType::DECIMAL: return execution::sql::SqlTypeId::Decimal; diff --git a/src/parser/udf/plpgsql_parser.cpp b/src/parser/udf/plpgsql_parser.cpp index dc7b59d9f5..0c78a2fc42 100644 --- a/src/parser/udf/plpgsql_parser.cpp +++ b/src/parser/udf/plpgsql_parser.cpp @@ -59,7 +59,7 @@ static constexpr const char DECL_TYPE_ID_BIGINT[] = "bigint"; /** Variable-precision floating point */ static constexpr const char DECL_TYPE_ID_REAL[] = "real"; static constexpr const char DECL_TYPE_ID_FLOAT[] = "float"; -static constexpr const char DECL_TYPE_ID_DOUBLE[] = "double"; +static constexpr const char DECL_TYPE_ID_DOUBLE[] = "double precision"; static constexpr const char DECL_TYPE_ID_FLOAT8[] = "float8"; /** Arbitrary-precision floating point */ @@ -586,6 +586,7 @@ bool PLpgSQLParser::HasEnclosingQuery(ParseResult *parse_result) { std::optional PLpgSQLParser::TypeNameToType(const std::string &type_name) { // TODO(Kyle): This is awkward control flow because we // model RECORD types with the SqlTypeId::Invalid type + execution::sql::SqlTypeId type; if (type_name == DECL_TYPE_ID_SMALLINT) { type = execution::sql::SqlTypeId::SmallInt; @@ -594,7 +595,11 @@ std::optional PLpgSQLParser::TypeNameToType(const std } else if (type_name == DECL_TYPE_ID_BIGINT) { type = execution::sql::SqlTypeId::BigInt; } else if (type_name == DECL_TYPE_ID_REAL || type_name == DECL_TYPE_ID_FLOAT) { - type = execution::sql::SqlTypeId::Real; + // NOTE(Kyle): We perform a sneaky trick here: the "normal" + // SQL frontend automatically promotes all floating-point + // types to DOUBLE PRECISION (FLOAT8); we do the same thing + // here to remain consistent across the entire system. + type = execution::sql::SqlTypeId::Double; } else if (type_name == DECL_TYPE_ID_DOUBLE || type_name == DECL_TYPE_ID_FLOAT8) { type = execution::sql::SqlTypeId::Double; } else if (type_name == DECL_TYPE_ID_NUMERIC || type_name == DECL_TYPE_ID_DECIMAL) { From 6a08cf464d030ce35e110aecbe41ab2c1ce08114 Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Fri, 24 Sep 2021 21:09:39 -0400 Subject: [PATCH 126/139] going to punt on cast() for now --- script/testing/junit/sql/udf.sql | 30 ++++++++++ script/testing/junit/traces/udf.test | 60 +++++++++++++++++++ .../parser/expression/type_cast_expression.h | 1 + 3 files changed, 91 insertions(+) diff --git a/script/testing/junit/sql/udf.sql b/script/testing/junit/sql/udf.sql index f0210ea664..92b43117a8 100644 --- a/script/testing/junit/sql/udf.sql +++ b/script/testing/junit/sql/udf.sql @@ -555,3 +555,33 @@ $$ LANGUAGE PLPGSQL; SELECT proc_promotion(1337.0); DROP FUNCTION proc_promotion(FLOAT8); + +-- ---------------------------------------------------------------------------- +-- proc_cast() + +-- CAST works in assignment expression +CREATE FUNCTION proc_cast() RETURNS FLOAT AS $$ \ +DECLARE \ + x FLOAT; \ +BEGIN \ + x = CAST(1 AS FLOAT); \ + RETURN x; \ +END \ +$$ LANGUAGE PLPGSQL; + +SELECT proc_cast(); +DROP FUNCTION proc_cast(); + +-- TODO(Kyle): this is a great example of a function that +-- we can't currently compile because we only resort to a +-- full handoff to the SQL execution infrastructure in the +-- case of assignment expressions. For everything else, in +-- this case a RETURN statement, we don't yet have the +-- ability to defer this to the SQL engine, and we also can't +-- handle it in the "builtin" manner, so we just fail. + +-- CREATE FUNCTION proc_cast() RETURNS FLOAT AS $$ \ +-- BEGIN \ +-- RETURN CAST(1 AS FLOAT); \ +-- END \ +-- $$ LANGUAGE PLPGSQL; diff --git a/script/testing/junit/traces/udf.test b/script/testing/junit/traces/udf.test index 27454163e4..e7e8a4f83f 100644 --- a/script/testing/junit/traces/udf.test +++ b/script/testing/junit/traces/udf.test @@ -1053,3 +1053,63 @@ statement ok DROP FUNCTION proc_promotion(FLOAT8); statement ok + + +statement ok +-- ---------------------------------------------------------------------------- + +statement ok +-- proc_cast() + +statement ok + + +statement ok +-- CAST works in assignment expression + +statement ok +CREATE FUNCTION proc_cast() RETURNS FLOAT AS $$ DECLARE x FLOAT; BEGIN x = CAST(1 AS FLOAT); RETURN x; END $$ LANGUAGE PLPGSQL; + +statement ok + + +query R rowsort +SELECT proc_cast(); +---- +1 + + +statement ok +DROP FUNCTION proc_cast(); + +statement ok + + +statement ok +-- TODO(Kyle): this is a great example of a function that + +statement ok +-- we can't currently compile because we only resort to a + +statement ok +-- full handoff to the SQL execution infrastructure in the + +statement ok +-- case of assignment expressions. For everything else, in + +statement ok +-- this case a RETURN statement, we don't yet have the + +statement ok +-- ability to defer this to the SQL engine, and we also can't + +statement ok +-- handle it in the "builtin" manner, so we just fail. + +statement ok + + +statement ok +-- CREATE FUNCTION proc_cast() RETURNS FLOAT AS $$ -- BEGIN -- RETURN CAST(1 AS FLOAT); -- END -- $$ LANGUAGE PLPGSQL; + +statement ok diff --git a/src/include/parser/expression/type_cast_expression.h b/src/include/parser/expression/type_cast_expression.h index 13260ecab4..27f3637b5a 100644 --- a/src/include/parser/expression/type_cast_expression.h +++ b/src/include/parser/expression/type_cast_expression.h @@ -36,6 +36,7 @@ class TypeCastExpression : public AbstractExpression { * @returns copy of this */ std::unique_ptr Copy() const override; + /** * Creates a copy of the current AbstractExpression with new children implanted. * The children should not be owned by any other AbstractExpression. From a0ee099101d4ea1bc32f40b9ddeedb3190860ab7 Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Fri, 24 Sep 2021 22:12:32 -0400 Subject: [PATCH 127/139] add support for RANDOM --- script/testing/junit/traces/udf.test | 15 +++++++++++++++ src/catalog/postgres/pg_proc_impl.cpp | 2 ++ src/execution/sema/sema_builtin.cpp | 11 +++++++++++ src/execution/sql/functions/system_functions.cpp | 5 +++++ src/execution/vm/bytecode_generator.cpp | 9 +++++++++ src/execution/vm/vm.cpp | 6 ++++++ src/include/execution/ast/builtins.h | 1 + src/include/execution/sema/sema.h | 1 + .../execution/sql/functions/system_functions.h | 10 +++++++++- src/include/execution/vm/bytecode_generator.h | 1 + src/include/execution/vm/bytecode_handlers.h | 4 ++++ src/include/execution/vm/bytecodes.h | 1 + 12 files changed, 65 insertions(+), 1 deletion(-) diff --git a/script/testing/junit/traces/udf.test b/script/testing/junit/traces/udf.test index e7e8a4f83f..16089db81e 100644 --- a/script/testing/junit/traces/udf.test +++ b/script/testing/junit/traces/udf.test @@ -1113,3 +1113,18 @@ statement ok -- CREATE FUNCTION proc_cast() RETURNS FLOAT AS $$ -- BEGIN -- RETURN CAST(1 AS FLOAT); -- END -- $$ LANGUAGE PLPGSQL; statement ok + + +statement ok +-- ---------------------------------------------------------------------------- + +statement ok +-- proc_random() + +statement ok + + +statement ok +CREATE FUNCTION proc_random() RETURNS FLOAT AS $$ DECLARE x FLOAT; BEGIN x = (SELECT RANDOM()); RETURN x; END $$ LANGUAGE PLPGSQL; + +statement ok diff --git a/src/catalog/postgres/pg_proc_impl.cpp b/src/catalog/postgres/pg_proc_impl.cpp index 3ea3e42850..52b0e87deb 100644 --- a/src/catalog/postgres/pg_proc_impl.cpp +++ b/src/catalog/postgres/pg_proc_impl.cpp @@ -529,6 +529,7 @@ void PgProcImpl::BootstrapProcs(const common::ManagedPointernext_oid_++}, "nprunnersemitint", PgLanguage::INTERNAL_LANGUAGE_OID, PgNamespace::NAMESPACE_DEFAULT_NAMESPACE_OID, INVALID_TYPE_OID, @@ -643,6 +644,7 @@ void PgProcImpl::BootstrapProcContexts(const common::ManagedPointerSetType(ast::BuiltinType::Get(GetContext(), ast::BuiltinType::Nil)); } +void Sema::CheckBuiltinRandomCall(ast::CallExpr *call, ast::Builtin builtin) { + if (!CheckArgCount(call, 0)) { + return; + } + call->SetType(ast::BuiltinType::Get(GetContext(), ast::BuiltinType::Kind::Real)); +} + void Sema::CheckBuiltinStringCall(ast::CallExpr *call, ast::Builtin builtin) { ast::BuiltinType::Kind sql_type; @@ -4136,6 +4143,10 @@ void Sema::CheckBuiltinCall(ast::CallExpr *call) { CheckBuiltinTestCatalogIndexLookup(call); break; } + case ast::Builtin::Random: { + CheckBuiltinRandomCall(call, builtin); + break; + } default: UNREACHABLE("Unhandled builtin!"); } diff --git a/src/execution/sql/functions/system_functions.cpp b/src/execution/sql/functions/system_functions.cpp index 4d023ce669..087ae649c9 100644 --- a/src/execution/sql/functions/system_functions.cpp +++ b/src/execution/sql/functions/system_functions.cpp @@ -10,4 +10,9 @@ void SystemFunctions::Version(UNUSED_ATTRIBUTE exec::ExecutionContext *ctx, Stri *result = StringVal(version); } +void SystemFunctions::Random(Real *result) { + // TODO(Kyle): Actually generate a random value + *result = Real(1.0); +} + } // namespace noisepage::execution::sql diff --git a/src/execution/vm/bytecode_generator.cpp b/src/execution/vm/bytecode_generator.cpp index 2099cc1f2a..06062bb371 100644 --- a/src/execution/vm/bytecode_generator.cpp +++ b/src/execution/vm/bytecode_generator.cpp @@ -774,6 +774,11 @@ void BytecodeGenerator::VisitBuiltinDateFunctionCall(ast::CallExpr *call, ast::B GetExecutionResult()->SetDestination(dest); } +void BytecodeGenerator::VisitBuiltinRandomFunctionCall(ast::CallExpr *call, ast::Builtin builtin) { + LocalVar ret = GetExecutionResult()->GetOrCreateDestination(call->GetType()); + GetEmitter()->Emit(Bytecode::Random, ret); +} + void BytecodeGenerator::VisitBuiltinTableIterCall(ast::CallExpr *call, ast::Builtin builtin) { // The first argument to all calls is a pointer to the TVI LocalVar iter = VisitExpressionForRValue(call->Arguments()[0]); @@ -2861,6 +2866,10 @@ void BytecodeGenerator::VisitBuiltinCallExpr(ast::CallExpr *call) { VisitBuiltinDateFunctionCall(call, builtin); break; } + case ast::Builtin::Random: { + VisitBuiltinRandomFunctionCall(call, builtin); + break; + } case ast::Builtin::RegisterThreadWithMetricsManager: { LocalVar exec_ctx = VisitExpressionForRValue(call->Arguments()[0]); GetEmitter()->Emit(Bytecode::RegisterThreadWithMetricsManager, exec_ctx); diff --git a/src/execution/vm/vm.cpp b/src/execution/vm/vm.cpp index 02a59e42ed..8136a1aee7 100644 --- a/src/execution/vm/vm.cpp +++ b/src/execution/vm/vm.cpp @@ -2793,6 +2793,12 @@ void VM::Interpret(const uint8_t *ip, Frame *frame) { // NOLINT DISPATCH_NEXT(); } + OP(Random) : { + auto *result = frame->LocalAt(READ_LOCAL_ID()); + OpRandom(result); + DISPATCH_NEXT(); + } + OP(InitCap) : { auto *result = frame->LocalAt(READ_LOCAL_ID()); auto *exec_ctx = frame->LocalAt(READ_LOCAL_ID()); diff --git a/src/include/execution/ast/builtins.h b/src/include/execution/ast/builtins.h index 64474fd877..65034a1472 100644 --- a/src/include/execution/ast/builtins.h +++ b/src/include/execution/ast/builtins.h @@ -34,6 +34,7 @@ namespace noisepage::execution::ast { /* SQL Functions */ \ F(Like, like) \ F(DatePart, datePart) \ + F(Random, random) \ \ /* Thread State Container */ \ F(ExecutionContextAddRowsAffected, execCtxAddRowsAffected) \ diff --git a/src/include/execution/sema/sema.h b/src/include/execution/sema/sema.h index c07121fed9..31c36598f9 100644 --- a/src/include/execution/sema/sema.h +++ b/src/include/execution/sema/sema.h @@ -166,6 +166,7 @@ class Sema : public ast::AstVisitor { void CheckBuiltinParamCall(ast::CallExpr *call, ast::Builtin builtin); void CheckBuiltinCteScanCall(ast::CallExpr *call, ast::Builtin builtin); void CheckBuiltinStringCall(ast::CallExpr *call, ast::Builtin builtin); + void CheckBuiltinRandomCall(ast::CallExpr *call, ast::Builtin builtin); void CheckBuiltinReplicationCall(ast::CallExpr *call, ast::Builtin builtin); diff --git a/src/include/execution/sql/functions/system_functions.h b/src/include/execution/sql/functions/system_functions.h index 7db0646d20..eee51a4dac 100644 --- a/src/include/execution/sql/functions/system_functions.h +++ b/src/include/execution/sql/functions/system_functions.h @@ -19,9 +19,17 @@ class EXPORT SystemFunctions { SystemFunctions() = delete; /** - * Gets the version of the database + * Get the version of the database. + * @param ctx The execution context + * @param result The out parameter that receives version string */ static void Version(exec::ExecutionContext *ctx, StringVal *result); + + /** + * Generate a random floating point value on [0.0, 1.0). + * @param result The out parameter that receives the result + */ + static void Random(Real *result); }; } // namespace noisepage::execution::sql diff --git a/src/include/execution/vm/bytecode_generator.h b/src/include/execution/vm/bytecode_generator.h index 910f9c357a..a9fc19227c 100644 --- a/src/include/execution/vm/bytecode_generator.h +++ b/src/include/execution/vm/bytecode_generator.h @@ -88,6 +88,7 @@ class BytecodeGenerator final : public ast::AstVisitor { void VisitNullValueCall(ast::CallExpr *call, ast::Builtin builtin); void VisitSqlStringLikeCall(ast::CallExpr *call); void VisitBuiltinDateFunctionCall(ast::CallExpr *call, ast::Builtin builtin); + void VisitBuiltinRandomFunctionCall(ast::CallExpr *call, ast::Builtin builtin); void VisitBuiltinTableIterCall(ast::CallExpr *call, ast::Builtin builtin); void VisitBuiltinTableIterParallelCall(ast::CallExpr *call); void VisitBuiltinVPICall(ast::CallExpr *call, ast::Builtin builtin); diff --git a/src/include/execution/vm/bytecode_handlers.h b/src/include/execution/vm/bytecode_handlers.h index ea47131ae5..ecbccaf767 100644 --- a/src/include/execution/vm/bytecode_handlers.h +++ b/src/include/execution/vm/bytecode_handlers.h @@ -1903,6 +1903,10 @@ VM_OP_WARM void OpVersion(noisepage::execution::exec::ExecutionContext *ctx, noisepage::execution::sql::SystemFunctions::Version(ctx, result); } +VM_OP_WARM void OpRandom(noisepage::execution::sql::Real *result) { + noisepage::execution::sql::SystemFunctions::Random(result); +} + VM_OP_WARM void OpInitCap(noisepage::execution::sql::StringVal *result, noisepage::execution::exec::ExecutionContext *ctx, const noisepage::execution::sql::StringVal *str) { diff --git a/src/include/execution/vm/bytecodes.h b/src/include/execution/vm/bytecodes.h index 73f3b6b444..359103b4f2 100644 --- a/src/include/execution/vm/bytecodes.h +++ b/src/include/execution/vm/bytecodes.h @@ -772,6 +772,7 @@ namespace noisepage::execution::vm { \ /* Miscellaneous functions. */ \ F(Version, OperandType::Local, OperandType::Local) \ + F(Random, OperandType::Local) \ \ /* Parameter support. */ \ F(GetParamBool, OperandType::Local, OperandType::Local, OperandType::Local) \ From f130ff2ffaf29cd9e209a00e7928592d04a60188 Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Sat, 25 Sep 2021 10:09:07 -0400 Subject: [PATCH 128/139] actual pseudorandom support for RANDOM --- src/execution/sql/functions/system_functions.cpp | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/execution/sql/functions/system_functions.cpp b/src/execution/sql/functions/system_functions.cpp index 087ae649c9..3e71dc4071 100644 --- a/src/execution/sql/functions/system_functions.cpp +++ b/src/execution/sql/functions/system_functions.cpp @@ -1,5 +1,7 @@ #include "execution/sql/functions/system_functions.h" +#include + #include "common/version.h" #include "execution/exec/execution_context.h" @@ -11,8 +13,11 @@ void SystemFunctions::Version(UNUSED_ATTRIBUTE exec::ExecutionContext *ctx, Stri } void SystemFunctions::Random(Real *result) { - // TODO(Kyle): Actually generate a random value - *result = Real(1.0); + // TODO(Kyle): Static locals are kind of gross, where + // should state for this type of one-off thing live? + static std::mt19937 generator{std::random_device{}()}; // NOLINT + static std::uniform_real_distribution<> distribution{0, 1}; + *result = Real(distribution(generator)); } } // namespace noisepage::execution::sql From 51310f50fb47cbe1b0b4183286d4f427eaef7788 Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Sat, 25 Sep 2021 10:29:25 -0400 Subject: [PATCH 129/139] compiling and invoking up through function 5 From c52fb132e801188e2d03cc01555513f5fa24d904 Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Sat, 25 Sep 2021 22:13:09 -0400 Subject: [PATCH 130/139] add support for resolving functions with NULL actual parameters --- script/testing/junit/sql/udf.sql | 35 +++++++++ script/testing/junit/traces/udf.test | 34 ++++++++- src/catalog/catalog_accessor.cpp | 2 + src/catalog/database_catalog.cpp | 46 +++++++++++ src/include/catalog/catalog_accessor.h | 2 +- src/include/catalog/database_catalog.h | 20 +++++ test/catalog/catalog_test.cpp | 101 ++++++++++++++++++++++--- 7 files changed, 228 insertions(+), 12 deletions(-) diff --git a/script/testing/junit/sql/udf.sql b/script/testing/junit/sql/udf.sql index 92b43117a8..d811bd7865 100644 --- a/script/testing/junit/sql/udf.sql +++ b/script/testing/junit/sql/udf.sql @@ -585,3 +585,38 @@ DROP FUNCTION proc_cast(); -- RETURN CAST(1 AS FLOAT); \ -- END \ -- $$ LANGUAGE PLPGSQL; + +-- ---------------------------------------------------------------------------- +-- proc_is_null() + +CREATE FUNCTION proc_is_null(x INT) RETURNS INT AS $$ \ +DECLARE \ + r INT; \ +BEGIN \ + IF x IS NULL THEN \ + r = 1; \ + ELSE \ + r = 2; \ + END IF; \ + RETURN r; \ +END \ +$$ LANGUAGE PLPGSQL; + +SELECT proc_is_null(1); +DROP FUNCTION proc_is_null(INT); + +CREATE FUNCTION proc_is_null(x INT) RETURNS INT AS $$ \ +DECLARE \ + r INT; \ +BEGIN \ + IF x IS NOT NULL THEN \ + r = 1; \ + ELSE \ + r = 2; \ + END IF; \ + RETURN r; \ +END \ +$$ LANGUAGE PLPGSQL; + +SELECT proc_is_null(1); +DROP FUNCTION proc_is_null(INT); diff --git a/script/testing/junit/traces/udf.test b/script/testing/junit/traces/udf.test index 16089db81e..0d865dc6b7 100644 --- a/script/testing/junit/traces/udf.test +++ b/script/testing/junit/traces/udf.test @@ -1119,12 +1119,42 @@ statement ok -- ---------------------------------------------------------------------------- statement ok --- proc_random() +-- proc_is_null() statement ok statement ok -CREATE FUNCTION proc_random() RETURNS FLOAT AS $$ DECLARE x FLOAT; BEGIN x = (SELECT RANDOM()); RETURN x; END $$ LANGUAGE PLPGSQL; +CREATE FUNCTION proc_is_null(x INT) RETURNS INT AS $$ DECLARE r INT; BEGIN IF x IS NULL THEN r = 1; ELSE r = 2; END IF; RETURN r; END $$ LANGUAGE PLPGSQL; + +statement ok + + +query I rowsort +SELECT proc_is_null(1); +---- +2 + + +statement ok +DROP FUNCTION proc_is_null(INT); + +statement ok + + +statement ok +CREATE FUNCTION proc_is_null(x INT) RETURNS INT AS $$ DECLARE r INT; BEGIN IF x IS NOT NULL THEN r = 1; ELSE r = 2; END IF; RETURN r; END $$ LANGUAGE PLPGSQL; + +statement ok + + +query I rowsort +SELECT proc_is_null(1); +---- +1 + + +statement ok +DROP FUNCTION proc_is_null(INT); statement ok diff --git a/src/catalog/catalog_accessor.cpp b/src/catalog/catalog_accessor.cpp index 92191beb79..80c998504c 100644 --- a/src/catalog/catalog_accessor.cpp +++ b/src/catalog/catalog_accessor.cpp @@ -209,6 +209,8 @@ proc_oid_t CatalogAccessor::GetProcOid(const std::string &procname, const std::v return GetProcOid(procname, types); } +proc_oid_t GetProcOid(const std::string &procname, const std::vector &arg_types); + proc_oid_t CatalogAccessor::GetProcOid(const std::string &procname, const std::vector &arg_types) { proc_oid_t ret; for (auto ns_oid : search_path_) { diff --git a/src/catalog/database_catalog.cpp b/src/catalog/database_catalog.cpp index c3c6c4e68f..99f9c2640c 100644 --- a/src/catalog/database_catalog.cpp +++ b/src/catalog/database_catalog.cpp @@ -468,6 +468,26 @@ bool DatabaseCatalog::DropProcedure(const common::ManagedPointer txn, namespace_oid_t procns, const std::string &procname, const std::vector &arg_types) { + // Handle the case where an untyped NULL is passed as an argument to the function; + // in this case, we enumerate all possible combinations of types for the NULL argument + if (ContainsUntypedNull(arg_types)) { + // TODO(Kyle): This is a brittle hack + for (int8_t type_value = static_cast(execution::sql::SqlTypeId::Boolean); + type_value <= static_cast(execution::sql::SqlTypeId::Varbinary); ++type_value) { + const execution::sql::SqlTypeId type = static_cast(type_value); + const std::vector swapped = ReplaceFirstUntypedNullWith(arg_types, type); + // Recursively invoke this function; there may be further untyped NULLs + const proc_oid_t oid = GetProcOid(txn, procns, procname, swapped); + if (oid != INVALID_PROC_OID) { + return oid; + } + } + + // All of the potential types were swapped and no match was found + return INVALID_PROC_OID; + } + + // Base case: all of the types are fully specified return pg_proc_.GetProcOid(txn, common::ManagedPointer(this), procns, procname, arg_types); } @@ -477,6 +497,32 @@ bool DatabaseCatalog::SetClassPointer(const common::ManagedPointer &arg_types) { + const type_oid_t null_oid = GetTypeOidForType(execution::sql::SqlTypeId::Invalid); + return std::any_of(arg_types.cbegin(), arg_types.cend(), [null_oid](const type_oid_t t) { return t == null_oid; }); +} + +std::vector DatabaseCatalog::ReplaceFirstUntypedNullWith(const std::vector &arg_types, + execution::sql::SqlTypeId type) { + NOISEPAGE_ASSERT(ContainsUntypedNull(arg_types), "Broken precondition"); + const type_oid_t null_oid = GetTypeOidForType(execution::sql::SqlTypeId::Invalid); + auto it = std::find(arg_types.cbegin(), arg_types.cend(), null_oid); + NOISEPAGE_ASSERT(it != arg_types.cend(), "Broken invariant"); + const std::size_t index = std::distance(arg_types.cbegin(), it); + + // Manually construct the modified vector + std::vector modified{}; + modified.reserve(arg_types.size()); + for (std::size_t i = 0; i < arg_types.size(); ++i) { + if (i == index) { + modified.push_back(GetTypeOidForType(type)); + } else { + modified.push_back(arg_types.at(i)); + } + } + return modified; +} + // Template instantiations. #define DEFINE_SET_CLASS_POINTER(ClassOid, Ptr) \ diff --git a/src/include/catalog/catalog_accessor.h b/src/include/catalog/catalog_accessor.h index 0bd5eed790..21b48c75f6 100644 --- a/src/include/catalog/catalog_accessor.h +++ b/src/include/catalog/catalog_accessor.h @@ -355,7 +355,7 @@ class EXPORT CatalogAccessor { /** * Gets the OID of a procedure from pg_proc given a requested name and resolved argument types. * This lookup will return the first one found through a sequential scan through - * the current search path + * the current search path. * @param procname name of the proc to lookup * @param arg_types vector of types of arguments of procedure to look up * @return The OID of the resolved procedure if found, else `INVALID_PROC_OID` diff --git a/src/include/catalog/database_catalog.h b/src/include/catalog/database_catalog.h index 715fc44172..0ff877e3b0 100644 --- a/src/include/catalog/database_catalog.h +++ b/src/include/catalog/database_catalog.h @@ -349,5 +349,25 @@ class DatabaseCatalog { template bool SetClassPointer(common::ManagedPointer txn, ClassOid oid, const Ptr *pointer, col_oid_t class_col); + + /* -------------------------------------------------------------------------- + Function Lookup + -------------------------------------------------------------------------- */ + + /** + * Determine if the vector of argument types contains an untyped NULL. + * @param arg_types The vector of argument types + * @return `true` if the vector contains an untyped NULL type, `false` otherwise + */ + bool ContainsUntypedNull(const std::vector &arg_types); + + /** + * Swap the first untyped NULL argument type in `arg_types` with `type`. + * @param arg_types The vector of argument types that is mutated + * @param type The type that is swapped in for the untyped NULL + * @return The modified vector + */ + std::vector ReplaceFirstUntypedNullWith(const std::vector &arg_types, + execution::sql::SqlTypeId type); }; } // namespace noisepage::catalog diff --git a/test/catalog/catalog_test.cpp b/test/catalog/catalog_test.cpp index 3674bae72b..da874f597b 100644 --- a/test/catalog/catalog_test.cpp +++ b/test/catalog/catalog_test.cpp @@ -115,11 +115,11 @@ TEST_F(CatalogTests, LanguageTest) { txn_manager_->Abort(txn); } -TEST_F(CatalogTests, ProcTest) { +/** User-defined function */ +TEST_F(CatalogTests, ProcTest0) { auto txn = txn_manager_->BeginTransaction(); auto accessor = catalog_->GetAccessor(common::ManagedPointer(txn), db_, DISABLED); - // Check visibility to me VerifyCatalogTables(*accessor); const auto language_oid = accessor->CreateLanguage("test_language"); @@ -129,8 +129,6 @@ TEST_F(CatalogTests, ProcTest) { txn_manager_->Commit(txn, transaction::TransactionUtil::EmptyCallback, nullptr); - /** User-defined procedure */ - // Create the procedure txn = txn_manager_->BeginTransaction(); accessor = catalog_->GetAccessor(common::ManagedPointer(txn), db_, DISABLED); @@ -142,10 +140,11 @@ TEST_F(CatalogTests, ProcTest) { accessor->GetTypeOidFromTypeId(execution::sql::SqlTypeId::SmallInt)}; const std::string src{"int sample(arg1, arg2, arg3){return 2;}"}; - auto proc_oid = accessor->CreateProcedure( + const auto proc_oid = accessor->CreateProcedure( procname, language_oid, namespace_oid, catalog::INVALID_TYPE_OID, args, arg_types, {}, {}, catalog::type_oid_t(static_cast(execution::sql::SqlTypeId::Integer)), src, false); EXPECT_NE(proc_oid, catalog::INVALID_PROC_OID); + txn_manager_->Commit(txn, transaction::TransactionUtil::EmptyCallback, nullptr); // Query the catalog for the procedure @@ -156,13 +155,30 @@ TEST_F(CatalogTests, ProcTest) { EXPECT_EQ(accessor->GetProcOid("bad_proc", arg_types), catalog::INVALID_PROC_OID); // Look for proc that we actually added - const auto found_oid = accessor->GetProcOid(procname, arg_types); - EXPECT_EQ(found_oid, proc_oid); - EXPECT_TRUE(accessor->DropProcedure(found_oid)); + EXPECT_EQ(proc_oid, accessor->GetProcOid(procname, arg_types)); + EXPECT_TRUE(accessor->DropProcedure(proc_oid)); + + txn_manager_->Commit(txn, transaction::TransactionUtil::EmptyCallback, nullptr); +} - /** Builting procedure */ +/** Builtin procedure */ +TEST_F(CatalogTests, ProcTest1) { + auto txn = txn_manager_->BeginTransaction(); + auto accessor = catalog_->GetAccessor(common::ManagedPointer(txn), db_, DISABLED); + + VerifyCatalogTables(*accessor); + + const auto language_oid = accessor->CreateLanguage("test_language"); + const auto namespace_oid = accessor->GetDefaultNamespace(); + EXPECT_NE(language_oid, catalog::INVALID_LANGUAGE_OID); + EXPECT_NE(namespace_oid, catalog::INVALID_NAMESPACE_OID); + + txn_manager_->Commit(txn, transaction::TransactionUtil::EmptyCallback, nullptr); // The procedure should already exist + txn = txn_manager_->BeginTransaction(); + accessor = catalog_->GetAccessor(common::ManagedPointer(txn), db_, DISABLED); + const auto sin_oid = accessor->GetProcOid("sin", {accessor->GetTypeOidFromTypeId(execution::sql::SqlTypeId::Double)}); EXPECT_NE(sin_oid, catalog::INVALID_PROC_OID); @@ -180,6 +196,73 @@ TEST_F(CatalogTests, ProcTest) { txn_manager_->Commit(txn, transaction::TransactionUtil::EmptyCallback, nullptr); } +/** Untyped NULL arguments */ +TEST_F(CatalogTests, ProcTest2) { + auto txn = txn_manager_->BeginTransaction(); + auto accessor = catalog_->GetAccessor(common::ManagedPointer(txn), db_, DISABLED); + + VerifyCatalogTables(*accessor); + + const auto language_oid = accessor->CreateLanguage("test_language"); + const auto namespace_oid = accessor->GetDefaultNamespace(); + EXPECT_NE(language_oid, catalog::INVALID_LANGUAGE_OID); + EXPECT_NE(namespace_oid, catalog::INVALID_NAMESPACE_OID); + + txn_manager_->Commit(txn, transaction::TransactionUtil::EmptyCallback, nullptr); + + // Create the procedure + txn = txn_manager_->BeginTransaction(); + accessor = catalog_->GetAccessor(common::ManagedPointer(txn), db_, DISABLED); + + const std::string procname{"foo"}; + const std::vector args{"a", "b"}; + const std::vector arg_types{accessor->GetTypeOidFromTypeId(execution::sql::SqlTypeId::Integer), + accessor->GetTypeOidFromTypeId(execution::sql::SqlTypeId::Integer)}; + const std::string src{"int foo(a, b){ return 1337; }"}; + + const auto proc_oid = accessor->CreateProcedure( + procname, language_oid, namespace_oid, catalog::INVALID_TYPE_OID, args, arg_types, {}, {}, + catalog::type_oid_t(static_cast(execution::sql::SqlTypeId::Integer)), src, false); + EXPECT_NE(proc_oid, catalog::INVALID_PROC_OID); + + txn_manager_->Commit(txn, transaction::TransactionUtil::EmptyCallback, nullptr); + + // Query the catalog for the procedure + txn = txn_manager_->BeginTransaction(); + accessor = catalog_->GetAccessor(common::ManagedPointer(txn), db_, DISABLED); + + // Look for proc that we added, with fully-specified types + EXPECT_EQ(proc_oid, accessor->GetProcOid(procname, arg_types)); + + // Look for the same proc, but with the first type unspecified + EXPECT_EQ(proc_oid, + accessor->GetProcOid(procname, {accessor->GetTypeOidFromTypeId(execution::sql::SqlTypeId::Invalid), + accessor->GetTypeOidFromTypeId(execution::sql::SqlTypeId::Integer)})); + + // Look for the same proc, but with the second type unspecified + EXPECT_EQ(proc_oid, + accessor->GetProcOid(procname, {accessor->GetTypeOidFromTypeId(execution::sql::SqlTypeId::Integer), + accessor->GetTypeOidFromTypeId(execution::sql::SqlTypeId::Invalid)})); + + // Look for the same proc, but with both types unspecified + EXPECT_EQ(proc_oid, + accessor->GetProcOid(procname, {accessor->GetTypeOidFromTypeId(execution::sql::SqlTypeId::Invalid), + accessor->GetTypeOidFromTypeId(execution::sql::SqlTypeId::Invalid)})); + + // Look for a proc with one fixed, incorrect parameter + EXPECT_EQ(catalog::INVALID_PROC_OID, + accessor->GetProcOid(procname, {accessor->GetTypeOidFromTypeId(execution::sql::SqlTypeId::Invalid), + accessor->GetTypeOidFromTypeId(execution::sql::SqlTypeId::Real)})); + + // Look for a proc with one fixed, incorrect parameter + EXPECT_EQ(catalog::INVALID_PROC_OID, + accessor->GetProcOid(procname, {accessor->GetTypeOidFromTypeId(execution::sql::SqlTypeId::Real), + accessor->GetTypeOidFromTypeId(execution::sql::SqlTypeId::Invalid)})); + + EXPECT_TRUE(accessor->DropProcedure(proc_oid)); + txn_manager_->Commit(txn, transaction::TransactionUtil::EmptyCallback, nullptr); +} + /* * Create and delete a database */ From bcecad1f32284bb7113c50fdc9cd87ee7e10b0ba Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Sun, 26 Sep 2021 08:15:55 -0400 Subject: [PATCH 131/139] refactor function call type resolution --- src/binder/bind_node_visitor.cpp | 11 ++++- src/catalog/catalog_accessor.cpp | 26 ++++++++-- src/catalog/database_catalog.cpp | 57 ++++++++++++++-------- src/execution/compiler/udf/udf_codegen.cpp | 13 ++++- src/include/catalog/catalog_accessor.h | 23 ++++++++- src/include/catalog/database_catalog.h | 26 ++++++++-- test/catalog/catalog_test.cpp | 47 +++++++++++++----- 7 files changed, 160 insertions(+), 43 deletions(-) diff --git a/src/binder/bind_node_visitor.cpp b/src/binder/bind_node_visitor.cpp index efa8a7bbd5..c97652b467 100644 --- a/src/binder/bind_node_visitor.cpp +++ b/src/binder/bind_node_visitor.cpp @@ -823,7 +823,16 @@ void BindNodeVisitor::Visit(common::ManagedPointer e arg_types.push_back(catalog_accessor_->GetTypeOidFromTypeId(child->GetReturnValueType())); } - auto proc_oid = catalog_accessor_->GetProcOid(expr->GetFuncName(), arg_types); + // Resolve the argument types to handle the case where an untyped NULL is passed + const auto resolved_types = catalog_accessor_->ResolveProcArgumentTypes(expr->GetFuncName(), arg_types); + if (resolved_types.empty()) { + throw BINDER_EXCEPTION("Procedure not registered", common::ErrorCode::ERRCODE_UNDEFINED_FUNCTION); + } else if (resolved_types.size() > 1) { + throw BINDER_EXCEPTION("Procedure call is ambiguous", common::ErrorCode::ERRCODE_UNDEFINED_FUNCTION); + } + + // This lookup should now always succeed + auto proc_oid = catalog_accessor_->GetProcOid(expr->GetFuncName(), resolved_types.front()); if (proc_oid == catalog::INVALID_PROC_OID) { throw BINDER_EXCEPTION("Procedure not registered", common::ErrorCode::ERRCODE_UNDEFINED_FUNCTION); } diff --git a/src/catalog/catalog_accessor.cpp b/src/catalog/catalog_accessor.cpp index 80c998504c..4777b64ff7 100644 --- a/src/catalog/catalog_accessor.cpp +++ b/src/catalog/catalog_accessor.cpp @@ -209,8 +209,6 @@ proc_oid_t CatalogAccessor::GetProcOid(const std::string &procname, const std::v return GetProcOid(procname, types); } -proc_oid_t GetProcOid(const std::string &procname, const std::vector &arg_types); - proc_oid_t CatalogAccessor::GetProcOid(const std::string &procname, const std::vector &arg_types) { proc_oid_t ret; for (auto ns_oid : search_path_) { @@ -222,6 +220,26 @@ proc_oid_t CatalogAccessor::GetProcOid(const std::string &procname, const std::v return catalog::INVALID_PROC_OID; } +std::vector> CatalogAccessor::ResolveProcArgumentTypes( + const std::string &procname, const std::vector &arg_types) const { + // Transform the string type identifiers to internal type IDs + std::vector types{}; + types.reserve(arg_types.size()); + std::transform(arg_types.cbegin(), arg_types.cend(), std::back_inserter(types), + [this](const std::string &name) { return TypeNameToType(name); }); + return ResolveProcArgumentTypes(procname, arg_types); +} + +std::vector> CatalogAccessor::ResolveProcArgumentTypes( + const std::string &procname, const std::vector &arg_types) const { + std::vector> types{}; + for (auto ns_oid : search_path_) { + const auto resolved = dbc_->ResolveProcArgumentTypes(txn_, ns_oid, procname, arg_types); + types.insert(types.cend(), resolved.cbegin(), resolved.cend()); + } + return types; +} + common::ManagedPointer CatalogAccessor::GetFunctionContext(proc_oid_t proc_oid) { return dbc_->GetFunctionContext(txn_, proc_oid); } @@ -240,7 +258,7 @@ optimizer::TableStats CatalogAccessor::GetTableStatistics(table_oid_t table_oid) return dbc_->GetTableStatistics(txn_, table_oid); } -type_oid_t CatalogAccessor::GetTypeOidFromTypeId(execution::sql::SqlTypeId type) { +type_oid_t CatalogAccessor::GetTypeOidFromTypeId(execution::sql::SqlTypeId type) const { return dbc_->GetTypeOidForType(type); } @@ -257,7 +275,7 @@ void CatalogAccessor::RegisterTempTable(table_oid_t table_oid, const common::Man temp_schemas_[table_oid] = schema; } -type_oid_t CatalogAccessor::TypeNameToType(const std::string &type_name) { +type_oid_t CatalogAccessor::TypeNameToType(const std::string &type_name) const { type_oid_t type; if (type_name == "int2") { type = GetTypeOidFromTypeId(execution::sql::SqlTypeId::SmallInt); diff --git a/src/catalog/database_catalog.cpp b/src/catalog/database_catalog.cpp index 99f9c2640c..4dade0a352 100644 --- a/src/catalog/database_catalog.cpp +++ b/src/catalog/database_catalog.cpp @@ -312,7 +312,7 @@ std::vector, const Index return pg_core_.GetIndexes(txn, table); } -type_oid_t DatabaseCatalog::GetTypeOidForType(const execution::sql::SqlTypeId type) { +type_oid_t DatabaseCatalog::GetTypeOidForType(const execution::sql::SqlTypeId type) const { // TODO(WAN): WARNING! Do not change this seeing PgCoreImpl::MakeColumn and PgCoreImpl::CreateColumn. return type_oid_t(static_cast(type)); } @@ -468,42 +468,59 @@ bool DatabaseCatalog::DropProcedure(const common::ManagedPointer txn, namespace_oid_t procns, const std::string &procname, const std::vector &arg_types) { - // Handle the case where an untyped NULL is passed as an argument to the function; - // in this case, we enumerate all possible combinations of types for the NULL argument if (ContainsUntypedNull(arg_types)) { - // TODO(Kyle): This is a brittle hack - for (int8_t type_value = static_cast(execution::sql::SqlTypeId::Boolean); - type_value <= static_cast(execution::sql::SqlTypeId::Varbinary); ++type_value) { - const execution::sql::SqlTypeId type = static_cast(type_value); - const std::vector swapped = ReplaceFirstUntypedNullWith(arg_types, type); - // Recursively invoke this function; there may be further untyped NULLs - const proc_oid_t oid = GetProcOid(txn, procns, procname, swapped); - if (oid != INVALID_PROC_OID) { - return oid; - } - } - - // All of the potential types were swapped and no match was found + // NOTE(Kyle): Should this be a harder error condition (i.e. assertion failure)? return INVALID_PROC_OID; } - - // Base case: all of the types are fully specified return pg_proc_.GetProcOid(txn, common::ManagedPointer(this), procns, procname, arg_types); } +std::vector> DatabaseCatalog::ResolveProcArgumentTypes( + common::ManagedPointer txn, namespace_oid_t procns, const std::string &procname, + const std::vector &arg_types) { + std::vector> result{}; + ResolveProcArgumentTypes(txn, procns, procname, arg_types, &result); + return result; +} + +void DatabaseCatalog::ResolveProcArgumentTypes(common::ManagedPointer txn, + namespace_oid_t procns, const std::string &procname, + const std::vector &arg_types, + std::vector> *result) { + // If the provided collection of arguments does not contain + // an untyped NULL, all types are fully resolved, bottom out + if (!ContainsUntypedNull(arg_types)) { + if (pg_proc_.GetProcOid(txn, common::ManagedPointer(this), procns, procname, arg_types) != INVALID_PROC_OID) { + result->push_back(arg_types); + } + return; + } + + // Handle the case where an untyped NULL is passed as an argument to the function; + // in this case, we enumerate all possible combinations of types for the NULL argument + + // TODO(Kyle): This is a brittle hack + for (int8_t type_value = static_cast(execution::sql::SqlTypeId::Boolean); + type_value <= static_cast(execution::sql::SqlTypeId::Varbinary); ++type_value) { + const execution::sql::SqlTypeId type = static_cast(type_value); + // Recursively invoke this function; there may be further untyped NULLs + ResolveProcArgumentTypes(txn, procns, procname, ReplaceFirstUntypedNullWith(arg_types, type), result); + } +} + template bool DatabaseCatalog::SetClassPointer(const common::ManagedPointer txn, const ClassOid oid, const Ptr *const pointer, const col_oid_t class_col) { return pg_core_.SetClassPointer(txn, oid, pointer, class_col); } -bool DatabaseCatalog::ContainsUntypedNull(const std::vector &arg_types) { +bool DatabaseCatalog::ContainsUntypedNull(const std::vector &arg_types) const { const type_oid_t null_oid = GetTypeOidForType(execution::sql::SqlTypeId::Invalid); return std::any_of(arg_types.cbegin(), arg_types.cend(), [null_oid](const type_oid_t t) { return t == null_oid; }); } std::vector DatabaseCatalog::ReplaceFirstUntypedNullWith(const std::vector &arg_types, - execution::sql::SqlTypeId type) { + execution::sql::SqlTypeId type) const { NOISEPAGE_ASSERT(ContainsUntypedNull(arg_types), "Broken precondition"); const type_oid_t null_oid = GetTypeOidForType(execution::sql::SqlTypeId::Invalid); auto it = std::find(arg_types.cbegin(), arg_types.cend(), null_oid); diff --git a/src/execution/compiler/udf/udf_codegen.cpp b/src/execution/compiler/udf/udf_codegen.cpp index b2b1063595..4d455728b3 100644 --- a/src/execution/compiler/udf/udf_codegen.cpp +++ b/src/execution/compiler/udf/udf_codegen.cpp @@ -341,7 +341,18 @@ void UdfCodegen::Visit(ast::udf::CallExprAST *ast) { arguments.cbegin(), arguments.cend(), std::back_inserter(argument_types), [this](const ast::Expr *expr) -> catalog::type_oid_t { return GetCatalogTypeOidFromSQLType(ResolveType(expr)); }); - const auto proc_oid = accessor_->GetProcOid(ast->Callee(), argument_types); + // Resolve the argument types to handle the case where an untyped NULL is passed + const auto resolved_types = accessor_->ResolveProcArgumentTypes(ast->Callee(), argument_types); + if (resolved_types.empty()) { + throw BINDER_EXCEPTION(fmt::format("Procedure '{}' not registered", ast->Callee()), + common::ErrorCode::ERRCODE_PLPGSQL_ERROR); + } else if (resolved_types.size() > 1) { + throw BINDER_EXCEPTION(fmt::format("Procedure call '{}' is ambiguous", ast->Callee()), + common::ErrorCode::ERRCODE_PLPGSQL_ERROR); + } + + // This lookup should now always succeed + const auto proc_oid = accessor_->GetProcOid(ast->Callee(), resolved_types.front()); if (proc_oid == catalog::INVALID_PROC_OID) { throw BINDER_EXCEPTION(fmt::format("Invalid function call '{}'", ast->Callee()), common::ErrorCode::ERRCODE_PLPGSQL_ERROR); diff --git a/src/include/catalog/catalog_accessor.h b/src/include/catalog/catalog_accessor.h index 21b48c75f6..6b2961de81 100644 --- a/src/include/catalog/catalog_accessor.h +++ b/src/include/catalog/catalog_accessor.h @@ -2,6 +2,7 @@ #include #include +#include #include #include #include @@ -362,6 +363,24 @@ class EXPORT CatalogAccessor { */ proc_oid_t GetProcOid(const std::string &procname, const std::vector &arg_types); + /** + * Resolve procedure argument types. + * @param procname The name of the procedure + * @param arg_types A vector of the string representation of the argument types + * @return A collection of all sets of arguments for which this procedure is resolved + */ + std::vector> ResolveProcArgumentTypes(const std::string &procname, + const std::vector &arg_types) const; + + /** + * Resolve procedure argument types. + * @param procname The name of the procedure + * @param arg_types A vector of the string representation of the argument types + * @return A collection of all sets of arguments for which this procedure is resolved + */ + std::vector> ResolveProcArgumentTypes(const std::string &procname, + const std::vector &arg_types) const; + /** * Gets the proc context pointer column of proc_oid * @param proc_oid The proc_oid whose pointer column we are getting here @@ -397,7 +416,7 @@ class EXPORT CatalogAccessor { * @param type * @return type_oid of type in pg_type */ - type_oid_t GetTypeOidFromTypeId(execution::sql::SqlTypeId type); + type_oid_t GetTypeOidFromTypeId(execution::sql::SqlTypeId type) const; /** * @return BlockStore to be used for CREATE operations @@ -475,7 +494,7 @@ class EXPORT CatalogAccessor { * @param type_name The type name * @return The internal catalog type identifier for the type */ - type_oid_t TypeNameToType(const std::string &type_name); + type_oid_t TypeNameToType(const std::string &type_name) const; }; } // namespace noisepage::catalog diff --git a/src/include/catalog/database_catalog.h b/src/include/catalog/database_catalog.h index 0ff877e3b0..6bcdfdbaae 100644 --- a/src/include/catalog/database_catalog.h +++ b/src/include/catalog/database_catalog.h @@ -149,7 +149,7 @@ class DatabaseCatalog { common::ManagedPointer txn, table_oid_t table); /** @return The type_oid_t that corresponds to the internal TypeId. */ - type_oid_t GetTypeOidForType(execution::sql::SqlTypeId type); + type_oid_t GetTypeOidForType(execution::sql::SqlTypeId type) const; /** @brief Get a list of all of the constraints for the specified table. */ std::vector GetConstraints(common::ManagedPointer txn, @@ -173,12 +173,20 @@ class DatabaseCatalog { const std::string &src, bool is_aggregate); /** @brief Drop the specified procedure. @see PgProcImpl::DropProcedure */ bool DropProcedure(common::ManagedPointer txn, proc_oid_t proc); + /** @brief Get the OID of the specified procedure. @see PgProcImpl::GetProcOid */ proc_oid_t GetProcOid(common::ManagedPointer txn, namespace_oid_t procns, const std::string &procname, const std::vector &all_arg_types); + + /** @brief Resolve all combinations of argument types for the procedure */ + std::vector> ResolveProcArgumentTypes( + common::ManagedPointer txn, namespace_oid_t procns, const std::string &procname, + const std::vector &arg_types); + /** @brief Get the procedure context for the specified procedure. @see PgProcImpl::GetProcCtxPtr */ common::ManagedPointer GetFunctionContext( common::ManagedPointer txn, proc_oid_t proc_oid); + /** @brief Set the procedure context for the specified procedure. @see PgProcImpl::SetProcCtxPtr */ bool SetFunctionContext(common::ManagedPointer txn, proc_oid_t proc_oid, const execution::functions::FunctionContext *func_context); @@ -354,12 +362,24 @@ class DatabaseCatalog { Function Lookup -------------------------------------------------------------------------- */ + /** + * Recursive helper function for procedure argument type resolution. + * @param txn The transaction context + * @param procns The namespace of the procedure + * @param procname The procedure name + * @param arg_types The argument types + * @param result The vector that receives any resolved sets of arguments + */ + void ResolveProcArgumentTypes(common::ManagedPointer txn, namespace_oid_t procns, + const std::string &procname, const std::vector &arg_types, + std::vector> *result); + /** * Determine if the vector of argument types contains an untyped NULL. * @param arg_types The vector of argument types * @return `true` if the vector contains an untyped NULL type, `false` otherwise */ - bool ContainsUntypedNull(const std::vector &arg_types); + bool ContainsUntypedNull(const std::vector &arg_types) const; /** * Swap the first untyped NULL argument type in `arg_types` with `type`. @@ -368,6 +388,6 @@ class DatabaseCatalog { * @return The modified vector */ std::vector ReplaceFirstUntypedNullWith(const std::vector &arg_types, - execution::sql::SqlTypeId type); + execution::sql::SqlTypeId type) const; }; } // namespace noisepage::catalog diff --git a/test/catalog/catalog_test.cpp b/test/catalog/catalog_test.cpp index da874f597b..cdee1b7710 100644 --- a/test/catalog/catalog_test.cpp +++ b/test/catalog/catalog_test.cpp @@ -234,30 +234,53 @@ TEST_F(CatalogTests, ProcTest2) { // Look for proc that we added, with fully-specified types EXPECT_EQ(proc_oid, accessor->GetProcOid(procname, arg_types)); - // Look for the same proc, but with the first type unspecified - EXPECT_EQ(proc_oid, + // Look for the same proc, but with the first type unspecified (should fail) + EXPECT_EQ(catalog::INVALID_PROC_OID, accessor->GetProcOid(procname, {accessor->GetTypeOidFromTypeId(execution::sql::SqlTypeId::Invalid), accessor->GetTypeOidFromTypeId(execution::sql::SqlTypeId::Integer)})); - // Look for the same proc, but with the second type unspecified - EXPECT_EQ(proc_oid, + // Look for the same proc, but with the second type unspecified (should fail) + EXPECT_EQ(catalog::INVALID_PROC_OID, accessor->GetProcOid(procname, {accessor->GetTypeOidFromTypeId(execution::sql::SqlTypeId::Integer), accessor->GetTypeOidFromTypeId(execution::sql::SqlTypeId::Invalid)})); - // Look for the same proc, but with both types unspecified - EXPECT_EQ(proc_oid, + // Look for the same proc, but with both types unspecified (should fail) + EXPECT_EQ(catalog::INVALID_PROC_OID, accessor->GetProcOid(procname, {accessor->GetTypeOidFromTypeId(execution::sql::SqlTypeId::Invalid), accessor->GetTypeOidFromTypeId(execution::sql::SqlTypeId::Invalid)})); + // Look for the same proc, but with types resolved + const auto r0 = accessor->ResolveProcArgumentTypes( + procname, {accessor->GetTypeOidFromTypeId(execution::sql::SqlTypeId::Invalid), + accessor->GetTypeOidFromTypeId(execution::sql::SqlTypeId::Integer)}); + EXPECT_EQ(1, r0.size()); + EXPECT_EQ(proc_oid, accessor->GetProcOid(procname, r0.front())); + + // Look for the same proc, but with types resolved + const auto r1 = accessor->ResolveProcArgumentTypes( + procname, {accessor->GetTypeOidFromTypeId(execution::sql::SqlTypeId::Integer), + accessor->GetTypeOidFromTypeId(execution::sql::SqlTypeId::Invalid)}); + EXPECT_EQ(1, r1.size()); + EXPECT_EQ(proc_oid, accessor->GetProcOid(procname, r1.front())); + + // Look for the same proc, but with types resolved + const auto r2 = accessor->ResolveProcArgumentTypes( + procname, {accessor->GetTypeOidFromTypeId(execution::sql::SqlTypeId::Invalid), + accessor->GetTypeOidFromTypeId(execution::sql::SqlTypeId::Invalid)}); + EXPECT_EQ(1, r2.size()); + EXPECT_EQ(proc_oid, accessor->GetProcOid(procname, r2.front())); + // Look for a proc with one fixed, incorrect parameter - EXPECT_EQ(catalog::INVALID_PROC_OID, - accessor->GetProcOid(procname, {accessor->GetTypeOidFromTypeId(execution::sql::SqlTypeId::Invalid), - accessor->GetTypeOidFromTypeId(execution::sql::SqlTypeId::Real)})); + const auto r3 = + accessor->ResolveProcArgumentTypes(procname, {accessor->GetTypeOidFromTypeId(execution::sql::SqlTypeId::Invalid), + accessor->GetTypeOidFromTypeId(execution::sql::SqlTypeId::Real)}); + EXPECT_TRUE(r3.empty()); // Look for a proc with one fixed, incorrect parameter - EXPECT_EQ(catalog::INVALID_PROC_OID, - accessor->GetProcOid(procname, {accessor->GetTypeOidFromTypeId(execution::sql::SqlTypeId::Real), - accessor->GetTypeOidFromTypeId(execution::sql::SqlTypeId::Invalid)})); + const auto r4 = accessor->ResolveProcArgumentTypes( + procname, {accessor->GetTypeOidFromTypeId(execution::sql::SqlTypeId::Real), + accessor->GetTypeOidFromTypeId(execution::sql::SqlTypeId::Invalid)}); + EXPECT_TRUE(r4.empty()); EXPECT_TRUE(accessor->DropProcedure(proc_oid)); txn_manager_->Commit(txn, transaction::TransactionUtil::EmptyCallback, nullptr); From 02829ee2fc440758c7d8181d28502528b97c8264 Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Sun, 26 Sep 2021 09:28:33 -0400 Subject: [PATCH 132/139] now able to invoke UDFs with untyped NULLs that are automatically resolved to the correct type --- script/testing/junit/sql/udf.sql | 30 +++++++++++---------- script/testing/junit/traces/udf.test | 24 ++++++++++++++--- src/binder/bind_node_visitor.cpp | 17 ++++++++++-- src/catalog/catalog_accessor.cpp | 4 +++ src/catalog/database_catalog.cpp | 36 ++++++++++++++++++++++++++ src/execution/sql/sql.cpp | 2 ++ src/include/catalog/catalog_accessor.h | 13 +++++++--- src/include/catalog/database_catalog.h | 3 +++ test/catalog/catalog_test.cpp | 14 ++++++++++ 9 files changed, 122 insertions(+), 21 deletions(-) diff --git a/script/testing/junit/sql/udf.sql b/script/testing/junit/sql/udf.sql index d811bd7865..29fbe94862 100644 --- a/script/testing/junit/sql/udf.sql +++ b/script/testing/junit/sql/udf.sql @@ -603,20 +603,24 @@ END \ $$ LANGUAGE PLPGSQL; SELECT proc_is_null(1); +SELECT proc_is_null(NULL); + DROP FUNCTION proc_is_null(INT); -CREATE FUNCTION proc_is_null(x INT) RETURNS INT AS $$ \ -DECLARE \ - r INT; \ -BEGIN \ - IF x IS NOT NULL THEN \ - r = 1; \ - ELSE \ - r = 2; \ - END IF; \ - RETURN r; \ -END \ +CREATE FUNCTION proc_is_not_null(x INT) RETURNS INT AS $$ \ +DECLARE \ + r INT; \ +BEGIN \ + IF x IS NOT NULL THEN \ + r = 1; \ + ELSE \ + r = 2; \ + END IF; \ + RETURN r; \ +END \ $$ LANGUAGE PLPGSQL; -SELECT proc_is_null(1); -DROP FUNCTION proc_is_null(INT); +SELECT proc_is_not_null(1); +SELECT proc_is_not_null(NULL); + +DROP FUNCTION proc_is_not_null(INT); diff --git a/script/testing/junit/traces/udf.test b/script/testing/junit/traces/udf.test index 0d865dc6b7..5f1d473f04 100644 --- a/script/testing/junit/traces/udf.test +++ b/script/testing/junit/traces/udf.test @@ -1136,6 +1136,15 @@ SELECT proc_is_null(1); 2 +query I rowsort +SELECT proc_is_null(NULL); +---- +1 + + +statement ok + + statement ok DROP FUNCTION proc_is_null(INT); @@ -1143,18 +1152,27 @@ statement ok statement ok -CREATE FUNCTION proc_is_null(x INT) RETURNS INT AS $$ DECLARE r INT; BEGIN IF x IS NOT NULL THEN r = 1; ELSE r = 2; END IF; RETURN r; END $$ LANGUAGE PLPGSQL; +CREATE FUNCTION proc_is_not_null(x INT) RETURNS INT AS $$ DECLARE r INT; BEGIN IF x IS NOT NULL THEN r = 1; ELSE r = 2; END IF; RETURN r; END $$ LANGUAGE PLPGSQL; statement ok query I rowsort -SELECT proc_is_null(1); +SELECT proc_is_not_null(1); ---- 1 +query I rowsort +SELECT proc_is_not_null(NULL); +---- +2 + + statement ok -DROP FUNCTION proc_is_null(INT); + + +statement ok +DROP FUNCTION proc_is_not_null(INT); statement ok diff --git a/src/binder/bind_node_visitor.cpp b/src/binder/bind_node_visitor.cpp index c97652b467..42eddbb7b6 100644 --- a/src/binder/bind_node_visitor.cpp +++ b/src/binder/bind_node_visitor.cpp @@ -816,8 +816,8 @@ void BindNodeVisitor::Visit(common::ManagedPointer e BINDER_LOG_TRACE("Visiting FunctionExpression ..."); SqlNodeVisitor::Visit(expr); - std::vector arg_types; auto children = expr->GetChildren(); + std::vector arg_types{}; arg_types.reserve(children.size()); for (const auto &child : children) { arg_types.push_back(catalog_accessor_->GetTypeOidFromTypeId(child->GetReturnValueType())); @@ -837,8 +837,21 @@ void BindNodeVisitor::Visit(common::ManagedPointer e throw BINDER_EXCEPTION("Procedure not registered", common::ErrorCode::ERRCODE_UNDEFINED_FUNCTION); } - auto func_context = catalog_accessor_->GetFunctionContext(proc_oid); + // The function is now resolved; we need to perform one further substitution + // here to handle the case where a literal untyped NULL is provided as an + // argument to the function call. In this case, the execution engine has no + // way to model the untyped NULL, so we need to replace this with a typed NULL + // from the function call argument that was resolved above + for (std::size_t i = 0; i < children.size(); ++i) { + auto child = children[i]; + if (child->GetExpressionType() == parser::ExpressionType::VALUE_CONSTANT && + child->GetReturnValueType() == execution::sql::SqlTypeId::Invalid) { + auto cve = child.CastManagedPointerTo(); + cve->SetValue(catalog_accessor_->GetTypeIdFromTypeOid(resolved_types.front()[0]), execution::sql::Val(true)); + } + } + auto func_context = catalog_accessor_->GetFunctionContext(proc_oid); expr->SetProcOid(proc_oid); expr->SetReturnValueType(func_context->GetFunctionReturnType()); } diff --git a/src/catalog/catalog_accessor.cpp b/src/catalog/catalog_accessor.cpp index 4777b64ff7..a38f47ce7d 100644 --- a/src/catalog/catalog_accessor.cpp +++ b/src/catalog/catalog_accessor.cpp @@ -262,6 +262,10 @@ type_oid_t CatalogAccessor::GetTypeOidFromTypeId(execution::sql::SqlTypeId type) return dbc_->GetTypeOidForType(type); } +execution::sql::SqlTypeId CatalogAccessor::GetTypeIdFromTypeOid(type_oid_t type) const { + return dbc_->GetTypeForTypeOid(type); +} + common::ManagedPointer CatalogAccessor::GetBlockStore() const { // TODO(Matt): at some point we may decide to adjust the source (i.e. each DatabaseCatalog has one), stick it in a // pg_tablespace table, or we may eliminate the concept entirely. This works for now to allow CREATE nodes to bind a diff --git a/src/catalog/database_catalog.cpp b/src/catalog/database_catalog.cpp index 4dade0a352..6fd4aa3c33 100644 --- a/src/catalog/database_catalog.cpp +++ b/src/catalog/database_catalog.cpp @@ -317,6 +317,42 @@ type_oid_t DatabaseCatalog::GetTypeOidForType(const execution::sql::SqlTypeId ty return type_oid_t(static_cast(type)); } +execution::sql::SqlTypeId DatabaseCatalog::GetTypeForTypeOid(type_oid_t type) const { + // NOTE(Kyle): This is a disgusting hack + switch (type.UnderlyingValue()) { + case 0: + return execution::sql::SqlTypeId::Boolean; + case 1: + return execution::sql::SqlTypeId::TinyInt; + case 2: + return execution::sql::SqlTypeId::SmallInt; + case 3: + return execution::sql::SqlTypeId::Integer; + case 4: + return execution::sql::SqlTypeId::BigInt; + case 5: + return execution::sql::SqlTypeId::Real; + case 6: + return execution::sql::SqlTypeId::Double; + case 7: + return execution::sql::SqlTypeId::Decimal; + case 8: + return execution::sql::SqlTypeId::Date; + case 9: + return execution::sql::SqlTypeId::Timestamp; + case 10: + return execution::sql::SqlTypeId::Char; + case 11: + return execution::sql::SqlTypeId::Varchar; + case 12: + return execution::sql::SqlTypeId::Varbinary; + case 255: + return execution::sql::SqlTypeId::Invalid; + default: + UNREACHABLE("Impossible type_oid_t"); + } +} + void DatabaseCatalog::BootstrapTable(const common::ManagedPointer txn, const table_oid_t table_oid, const namespace_oid_t ns_oid, const std::string &name, const Schema &schema, const common::ManagedPointer table_ptr) { diff --git a/src/execution/sql/sql.cpp b/src/execution/sql/sql.cpp index 7637bfd1a8..60284b80d4 100644 --- a/src/execution/sql/sql.cpp +++ b/src/execution/sql/sql.cpp @@ -286,6 +286,8 @@ std::string SqlTypeIdToString(SqlTypeId type) { return "Varchar"; case SqlTypeId::Varbinary: return "Varbinary"; + case SqlTypeId::Invalid: + return "Invalid"; default: // All cases handled UNREACHABLE("Impossible type"); diff --git a/src/include/catalog/catalog_accessor.h b/src/include/catalog/catalog_accessor.h index 6b2961de81..6da5005664 100644 --- a/src/include/catalog/catalog_accessor.h +++ b/src/include/catalog/catalog_accessor.h @@ -412,12 +412,19 @@ class EXPORT CatalogAccessor { optimizer::TableStats GetTableStatistics(table_oid_t table_oid); /** - * Returns the type oid of the given TypeId in pg_type - * @param type - * @return type_oid of type in pg_type + * Returns the type oid of the given TypeId in pg_type. + * @param type The queried type + * @return The corresponding type_oid_t */ type_oid_t GetTypeOidFromTypeId(execution::sql::SqlTypeId type) const; + /** + * Returns the SQL type ID of the given type_oid_t. + * @param type The queried type + * @return The corresponding SQL type ID + */ + execution::sql::SqlTypeId GetTypeIdFromTypeOid(type_oid_t type) const; + /** * @return BlockStore to be used for CREATE operations */ diff --git a/src/include/catalog/database_catalog.h b/src/include/catalog/database_catalog.h index 6bcdfdbaae..dc60153fab 100644 --- a/src/include/catalog/database_catalog.h +++ b/src/include/catalog/database_catalog.h @@ -151,6 +151,9 @@ class DatabaseCatalog { /** @return The type_oid_t that corresponds to the internal TypeId. */ type_oid_t GetTypeOidForType(execution::sql::SqlTypeId type) const; + /** @return The SQL type ID that corresponds to the type_oid_t */ + execution::sql::SqlTypeId GetTypeForTypeOid(type_oid_t type) const; + /** @brief Get a list of all of the constraints for the specified table. */ std::vector GetConstraints(common::ManagedPointer txn, table_oid_t table); diff --git a/test/catalog/catalog_test.cpp b/test/catalog/catalog_test.cpp index cdee1b7710..ccf5e41c7b 100644 --- a/test/catalog/catalog_test.cpp +++ b/test/catalog/catalog_test.cpp @@ -1015,4 +1015,18 @@ TEST_F(CatalogTests, StatisticTest) { txn_manager_->Commit(txn, transaction::TransactionUtil::EmptyCallback, nullptr); } +TEST_F(CatalogTests, TypeRoundTrip) { + // Ensure that types always round-trip + auto txn = txn_manager_->BeginTransaction(); + auto accessor = catalog_->GetAccessor(common::ManagedPointer(txn), db_, DISABLED); + for (int8_t type_raw = static_cast(execution::sql::SqlTypeId::Boolean); + type_raw <= static_cast(execution::sql::SqlTypeId::Varbinary); ++type_raw) { + const execution::sql::SqlTypeId type = static_cast(type_raw); + const catalog::type_oid_t oid = accessor->GetTypeOidFromTypeId(type); + EXPECT_EQ(type, accessor->GetTypeIdFromTypeOid(oid)); + } + + txn_manager_->Commit(txn, transaction::TransactionUtil::EmptyCallback, nullptr); +} + } // namespace noisepage From e6173690335851b95cf3aa3e4459a183cd9bf2af Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Sun, 26 Sep 2021 12:31:50 -0400 Subject: [PATCH 133/139] fix argument resolution when a cast is required --- script/testing/junit/sql/udf.sql | 99 +++++++++ script/testing/junit/traces/udf.test | 201 ++++++++++++++++++ src/execution/ast/udf/udf_ast_nodes.cpp | 49 +++++ src/execution/compiler/udf/udf_codegen.cpp | 36 ++++ src/include/execution/ast/udf/node_types.h | 24 +++ src/include/execution/ast/udf/udf_ast_nodes.h | 93 +++++--- .../execution/compiler/udf/udf_codegen.h | 7 + src/parser/udf/plpgsql_parser.cpp | 3 + 8 files changed, 485 insertions(+), 27 deletions(-) create mode 100644 src/execution/ast/udf/udf_ast_nodes.cpp create mode 100644 src/include/execution/ast/udf/node_types.h diff --git a/script/testing/junit/sql/udf.sql b/script/testing/junit/sql/udf.sql index 29fbe94862..ef5b5d8a7b 100644 --- a/script/testing/junit/sql/udf.sql +++ b/script/testing/junit/sql/udf.sql @@ -121,6 +121,33 @@ SELECT x, conditional(x) FROM integers; DROP FUNCTION conditional(INT); +-- Nested conditional control flow +CREATE FUNCTION conditional(x INT, y INT) RETURNS INT AS $$ \ +BEGIN \ + IF x > 1 THEN \ + IF y > 1 THEN \ + RETURN 1; \ + ELSE \ + RETURN 2; \ + END IF; \ + ELSE \ + IF y > 1 THEN \ + RETURN 3; \ + ELSE \ + RETURN 4; \ + END IF; \ + END IF; \ + RETURN 0; \ +END \ +$$ LANGUAGE PLPGSQL; + +SELECT conditional(1, 1); +SELECT conditional(1, 2); +SELECT conditional(2, 1); +SELECT conditional(2, 2); + +DROP FUNCTION conditional(INT, INT); + -- ---------------------------------------------------------------------------- -- proc_while() @@ -624,3 +651,75 @@ SELECT proc_is_not_null(1); SELECT proc_is_not_null(NULL); DROP FUNCTION proc_is_not_null(INT); + +-- ---------------------------------------------------------------------------- +-- proc_length() + +-- Assignment of LENGTH to temporary +CREATE FUNCTION proc_length(t VARCHAR) RETURNS INT AS $$ \ +DECLARE \ + r INT; \ +BEGIN \ + r = LENGTH(t); \ + RETURN r; \ +END \ +$$ LANGUAGE PLPGSQL; + +SELECT proc_length('hello'); +DROP FUNCTION proc_length(VARCHAR); + +-- Direct RETURN of LENGTH +CREATE FUNCTION proc_length(t VARCHAR) RETURNS INT AS $$ \ +BEGIN \ + RETURN LENGTH(t); \ +END \ +$$ LANGUAGE PLPGSQL; + +SELECT proc_length('hello'); + +DROP FUNCTION proc_length(VARCHAR); + +-- Use of LENGTH() in conditional +CREATE FUNCTION proc_length(t VARCHAR) RETURNS INT AS $$ \ +BEGIN \ + IF LENGTH(t) > 1 THEN \ + RETURN 1; \ + ELSE \ + RETURN 2; \ + END IF; \ + RETURN 0; \ +END \ +$$ LANGUAGE PLPGSQL; + +SELECT proc_length('a'); +SELECT proc_length('ab'); +SELECT proc_length('abc'); + +DROP FUNCTION proc_length(VARCHAR); + +-- ---------------------------------------------------------------------------- +-- proc_substr() + +-- Able to pass all arguments through +CREATE FUNCTION proc_substr(t VARCHAR, i INT, l INT) RETURNS VARCHAR AS $$ \ +BEGIN \ + RETURN SUBSTR(t, i, l); \ +END \ +$$ LANGUAGE PLPGSQL; + +SELECT proc_substr('hello', 1, 1); +SELECT proc_substr('hello', 1, 2); + +DROP FUNCTION proc_substr(VARCHAR, INT, INT); + +-- Able to specify a literal value +CREATE FUNCTION proc_substr(t VARCHAR, i INT) RETURNS VARCHAR AS $$ \ +BEGIN \ + RETURN SUBSTR(t, i, 1); \ +END \ +$$ LANGUAGE PLPGSQL; + +SELECT proc_substr('hello', 1); +SELECT proc_substr('hello', 2); + +DROP FUNCTION proc_substr(VARCHAR, INT); diff --git a/script/testing/junit/traces/udf.test b/script/testing/junit/traces/udf.test index 5f1d473f04..9321696cb4 100644 --- a/script/testing/junit/traces/udf.test +++ b/script/testing/junit/traces/udf.test @@ -305,6 +305,48 @@ DROP FUNCTION conditional(INT); statement ok +statement ok +-- Nested conditional control flow + +statement ok +CREATE FUNCTION conditional(x INT, y INT) RETURNS INT AS $$ BEGIN IF x > 1 THEN IF y > 1 THEN RETURN 1; ELSE RETURN 2; END IF; ELSE IF y > 1 THEN RETURN 3; ELSE RETURN 4; END IF; END IF; RETURN 0; END $$ LANGUAGE PLPGSQL; + +statement ok + + +query I rowsort +SELECT conditional(1, 1); +---- +4 + + +query I rowsort +SELECT conditional(1, 2); +---- +3 + + +query I rowsort +SELECT conditional(2, 1); +---- +2 + + +query I rowsort +SELECT conditional(2, 2); +---- +1 + + +statement ok + + +statement ok +DROP FUNCTION conditional(INT, INT); + +statement ok + + statement ok -- ---------------------------------------------------------------------------- @@ -1176,3 +1218,162 @@ statement ok DROP FUNCTION proc_is_not_null(INT); statement ok + + +statement ok +-- ---------------------------------------------------------------------------- + +statement ok +-- proc_length() + +statement ok + + +statement ok +-- Assignment of LENGTH to temporary + +statement ok +CREATE FUNCTION proc_length(t VARCHAR) RETURNS INT AS $$ DECLARE r INT; BEGIN r = LENGTH(t); RETURN r; END $$ LANGUAGE PLPGSQL; + +statement ok + + +query I rowsort +SELECT proc_length('hello'); +---- +5 + + +statement ok +DROP FUNCTION proc_length(VARCHAR); + +statement ok + + +statement ok +-- Direct RETURN of LENGTH + +statement ok +CREATE FUNCTION proc_length(t VARCHAR) RETURNS INT AS $$ BEGIN RETURN LENGTH(t); END $$ LANGUAGE PLPGSQL; + +statement ok + + +query I rowsort +SELECT proc_length('hello'); +---- +5 + + +statement ok + + +statement ok +DROP FUNCTION proc_length(VARCHAR); + +statement ok + + +statement ok +-- Use of LENGTH() in conditional + +statement ok +CREATE FUNCTION proc_length(t VARCHAR) RETURNS INT AS $$ BEGIN IF LENGTH(t) > 1 THEN RETURN 1; ELSE RETURN 2; END IF; RETURN 0; END $$ LANGUAGE PLPGSQL; + +statement ok + + +query I rowsort +SELECT proc_length('a'); +---- +2 + + +query I rowsort +SELECT proc_length('ab'); +---- +1 + + +query I rowsort +SELECT proc_length('abc'); +---- +1 + + +statement ok + + +statement ok +DROP FUNCTION proc_length(VARCHAR); + +statement ok + + +statement ok +-- ---------------------------------------------------------------------------- + +statement ok +-- proc_substr() + +statement ok + + +statement ok +-- Able to pass all arguments through + +statement ok +CREATE FUNCTION proc_substr(t VARCHAR, i INT, l INT) RETURNS VARCHAR AS $$ BEGIN RETURN SUBSTR(t, i, l); END $$ LANGUAGE PLPGSQL; + +statement ok + + +query T rowsort +SELECT proc_substr('hello', 1, 1); +---- +h + + +query T rowsort +SELECT proc_substr('hello', 1, 2); +---- +he + + +statement ok + + +statement ok +DROP FUNCTION proc_substr(VARCHAR, INT, INT); + +statement ok + + +statement ok +-- Able to specify a literal value + +statement ok +CREATE FUNCTION proc_substr(t VARCHAR, i INT) RETURNS VARCHAR AS $$ BEGIN RETURN SUBSTR(t, i, 1); END $$ LANGUAGE PLPGSQL; + +statement ok + + +query T rowsort +SELECT proc_substr('hello', 1); +---- +h + + +query T rowsort +SELECT proc_substr('hello', 2); +---- +e + + +statement ok + + +statement ok +DROP FUNCTION proc_substr(VARCHAR, INT); + +statement ok diff --git a/src/execution/ast/udf/udf_ast_nodes.cpp b/src/execution/ast/udf/udf_ast_nodes.cpp new file mode 100644 index 0000000000..4e9a3c7ee9 --- /dev/null +++ b/src/execution/ast/udf/udf_ast_nodes.cpp @@ -0,0 +1,49 @@ +#include + +#include "common/macros.h" +#include "execution/ast/udf/node_types.h" + +namespace noisepage::execution::ast::udf { + +std::string NodeTypeToShortString(NodeType type) { + switch (type) { + case NodeType::VALUE_EXPR: + return "VALUE_EXPR"; + case NodeType::IS_NULL_EXPR: + return "IS_NULL_EXPR"; + case NodeType::VARIABLE_EXPR: + return "VARIABLE_EXPR"; + case NodeType::MEMBER_EXPR: + return "MEMBER_EXPR"; + case NodeType::BINARY_EXPR: + return "BINARY_EXPR"; + case NodeType::CALL_EXPR: + return "CALL_EXPR"; + case NodeType::SEQ_STMT: + return "SEQ_STMT"; + case NodeType::DECL_STMT: + return "DECL_STMT"; + case NodeType::IF_STMT: + return "IF_STMT"; + case NodeType::FORI_STMT: + return "FORI_STMT"; + case NodeType::FORS_STMT: + return "FORS_STMT"; + case NodeType::WHILE_STMT: + return "WHILE_STMT"; + case NodeType::RET_STMT: + return "RET_STMT"; + case NodeType::ASSIGN_STMT: + return "ASSIGN_STMT"; + case NodeType::SQL_STMT: + return "SQL_STMT"; + case NodeType::DYNAMIC_SQL_STMT: + return "DYNAMIC_SQL_STMT"; + case NodeType::FUNCTION: + return "FUNCTION"; + default: + NOISEPAGE_ASSERT(false, "Impossible node type"); + } +} + +} // namespace noisepage::execution::ast::udf diff --git a/src/execution/compiler/udf/udf_codegen.cpp b/src/execution/compiler/udf/udf_codegen.cpp index 4d455728b3..dac8820c39 100644 --- a/src/execution/compiler/udf/udf_codegen.cpp +++ b/src/execution/compiler/udf/udf_codegen.cpp @@ -278,6 +278,10 @@ void UdfCodegen::Visit(ast::udf::BinaryExprAST *ast) { } void UdfCodegen::Visit(ast::udf::IfStmtAST *ast) { + // TODO(Kyle): It would be nice to add support for IF .. ELSIF .. ELSE + // constructs, but the current TPL architecture does not have native + // support for code generation of this type of control flow, so I am + // going to punt on it for now. ast::Expr *condition = EvaluateExpression(ast->Condition()); If branch(fb_, condition); ast->Then()->Accept(this); @@ -360,6 +364,10 @@ void UdfCodegen::Visit(ast::udf::CallExprAST *ast) { auto context = accessor_->GetFunctionContext(proc_oid); if (context->IsBuiltin()) { + if (context->IsExecCtxRequired()) { + // If this builtin requires an execution context, provide it + arguments.insert(arguments.begin(), GetExecutionContext()); + } ast::Expr *result = codegen_->CallBuiltin(context->GetBuiltin(), arguments); SetExecutionResult(result); } else { @@ -387,6 +395,8 @@ void UdfCodegen::Visit(ast::udf::CallExprAST *ast) { } sql::SqlTypeId UdfCodegen::ResolveType(const ast::Expr *expr) const { + const auto t = expr->GetKind(); + (void)t; switch (expr->GetKind()) { case ast::AstNode::Kind::LitExpr: return ResolveTypeForLiteralExpression(expr->SafeAs()); @@ -394,6 +404,8 @@ sql::SqlTypeId UdfCodegen::ResolveType(const ast::Expr *expr) const { return ResolveTypeForBinaryExpression(expr->SafeAs()); case ast::AstNode::Kind::IdentifierExpr: return ResolveTypeForIdentifierExpression(expr->SafeAs()); + case ast::AstNode::Kind::CallExpr: + return ResolveTypeForCallExpression(expr->SafeAs()); default: UNREACHABLE("Function call argument type cannot be resolved"); } @@ -461,6 +473,30 @@ sql::SqlTypeId UdfCodegen::ResolveTypeForIdentifierExpression(const ast::Identif return GetVariableType(expr->Name().GetString()); } +sql::SqlTypeId UdfCodegen::ResolveTypeForCallExpression(const ast::CallExpr *expr) const { + const ast::Type *type = expr->GetType(); + NOISEPAGE_ASSERT(type->IsSqlValueType(), "Invalid type"); + const ast::BuiltinType *builtin = type->SafeAs(); + switch (builtin->GetKind()) { + case ast::BuiltinType::Kind::Boolean: + return sql::SqlTypeId::Boolean; + case ast::BuiltinType::Kind::Integer: + return sql::SqlTypeId::Integer; + case ast::BuiltinType::Kind::Real: + return sql::SqlTypeId::Real; + case ast::BuiltinType::Kind::Decimal: + return sql::SqlTypeId::Decimal; + case ast::BuiltinType::Kind::StringVal: + return sql::SqlTypeId::Varchar; + case ast::BuiltinType::Kind::Date: + return sql::SqlTypeId::Date; + case ast::BuiltinType::Kind::Timestamp: + return sql::SqlTypeId::Timestamp; + default: + UNREACHABLE("Invalid type"); + } +} + /* ---------------------------------------------------------------------------- Code Generation: Integer-Variant For-Loops ---------------------------------------------------------------------------- */ diff --git a/src/include/execution/ast/udf/node_types.h b/src/include/execution/ast/udf/node_types.h new file mode 100644 index 0000000000..44cf92d37d --- /dev/null +++ b/src/include/execution/ast/udf/node_types.h @@ -0,0 +1,24 @@ +namespace noisepage::execution::ast::udf { + +/** Enumerates all (instantiable) AST node types */ +enum class NodeType { + VALUE_EXPR, + IS_NULL_EXPR, + VARIABLE_EXPR, + MEMBER_EXPR, + BINARY_EXPR, + CALL_EXPR, + SEQ_STMT, + DECL_STMT, + IF_STMT, + FORI_STMT, + FORS_STMT, + WHILE_STMT, + RET_STMT, + ASSIGN_STMT, + SQL_STMT, + DYNAMIC_SQL_STMT, + FUNCTION +}; + +} // namespace noisepage::execution::ast::udf diff --git a/src/include/execution/ast/udf/udf_ast_nodes.h b/src/include/execution/ast/udf/udf_ast_nodes.h index 422e3f1d1e..7fd99d85d4 100644 --- a/src/include/execution/ast/udf/udf_ast_nodes.h +++ b/src/include/execution/ast/udf/udf_ast_nodes.h @@ -8,20 +8,33 @@ #include "parser/expression/constant_value_expression.h" #include "parser/expression_defs.h" +#include "parser/parse_result.h" +#include "execution/ast/udf/node_types.h" #include "execution/ast/udf/udf_ast_node_visitor.h" #include "execution/sql/value.h" namespace noisepage::execution::ast::udf { +/** + * Get the string representation of a node type. + * @param type The node type + * @return The string representation + */ +std::string NodeTypeToShortString(NodeType type); + /** * The AbstractAST class serves as a base class for all AST nodes. */ class AbstractAST { public: /** - * Destroy the AST node. + * Construct a new AbstractAST node instance. + * @param type The type of the node */ + explicit AbstractAST(NodeType type) : type_{type} {} + + /** Destroy the AST node. */ virtual ~AbstractAST() = default; /** @@ -29,17 +42,28 @@ class AbstractAST { * @param visitor The visitor */ virtual void Accept(ASTNodeVisitor *visitor) { visitor->Visit(this); } + + /** @return The type of the AST node */ + NodeType GetType() const { return type_; } + + private: + /** The type of the AST node */ + NodeType type_; }; /** - * The StmtAST class serves as the base class for all statement nodes. + * The ExprAST class serves as the base class for all expression nodes. */ -class StmtAST : public AbstractAST { +class ExprAST : public AbstractAST { public: /** - * Destroy the AST node. + * Construct a new ExprAST instance. + * @param type The type of the expression node */ - ~StmtAST() override = default; + explicit ExprAST(NodeType type) : AbstractAST{type} {} + + /** Destroy the AST node. */ + ~ExprAST() override = default; /** * AST visitor pattern. @@ -49,14 +73,18 @@ class StmtAST : public AbstractAST { }; /** - * The ExprAST class serves as the base class for all expression nodes. + * The StmtAST class serves as the base class for all statement nodes. */ -class ExprAST : public StmtAST { +class StmtAST : public AbstractAST { public: /** - * Destroy the AST node. + * Construct a new StmtAST instance. + * @param type The type of the statement node */ - ~ExprAST() override = default; + explicit StmtAST(NodeType type) : AbstractAST{type} {} + + /** Destroy the AST node. */ + ~StmtAST() override = default; /** * AST visitor pattern. @@ -74,7 +102,8 @@ class ValueExprAST : public ExprAST { * Construct a new ValueExprAST instance. * @param value The AbstractExpression that represents the value */ - explicit ValueExprAST(std::unique_ptr &&value) : value_(std::move(value)) {} + explicit ValueExprAST(std::unique_ptr &&value) + : ExprAST{NodeType::VALUE_EXPR}, value_(std::move(value)) {} /** * AST visitor pattern. @@ -104,7 +133,7 @@ class IsNullExprAST : public ExprAST { * @param child The child expression */ IsNullExprAST(bool is_null_check, std::unique_ptr &&child) - : is_null_check_{is_null_check}, child_{std::move(child)} {} + : ExprAST{NodeType::IS_NULL_EXPR}, is_null_check_{is_null_check}, child_{std::move(child)} {} /** * AST visitor pattern. @@ -138,7 +167,7 @@ class VariableExprAST : public ExprAST { * Construct a new VariableExprAST instance. * @param name The name of the variable */ - explicit VariableExprAST(std::string name) : name_{std::move(name)} {} + explicit VariableExprAST(std::string name) : ExprAST{NodeType::VARIABLE_EXPR}, name_{std::move(name)} {} /** * AST visitor pattern. @@ -165,7 +194,7 @@ class MemberExprAST : public ExprAST { * @param field The name of the field in the structure */ MemberExprAST(std::unique_ptr &&object, std::string field) - : object_{std::move(object)}, field_(std::move(field)) {} + : ExprAST{NodeType::MEMBER_EXPR}, object_{std::move(object)}, field_(std::move(field)) {} /** * AST visitor pattern. @@ -202,7 +231,7 @@ class BinaryExprAST : public ExprAST { * @param rhs The expression on the right-hand side of the operation */ BinaryExprAST(parser::ExpressionType op, std::unique_ptr &&lhs, std::unique_ptr &&rhs) - : op_{op}, lhs_{std::move(lhs)}, rhs_{std::move(rhs)} {} + : ExprAST{NodeType::BINARY_EXPR}, op_{op}, lhs_{std::move(lhs)}, rhs_{std::move(rhs)} {} /** * AST visitor pattern. @@ -247,7 +276,7 @@ class CallExprAST : public ExprAST { * @param args The arguments to the function call */ CallExprAST(std::string callee, std::vector> &&args) - : callee_{std::move(callee)}, args_{std::move(args)} {} + : ExprAST{NodeType::CALL_EXPR}, callee_{std::move(callee)}, args_{std::move(args)} {} /** * AST visitor pattern. @@ -281,7 +310,8 @@ class SeqStmtAST : public StmtAST { * Construct a new SeqStmtAST instance. * @param statements The collection of statements in the sequence */ - explicit SeqStmtAST(std::vector> &&statements) : statements_(std::move(statements)) {} + explicit SeqStmtAST(std::vector> &&statements) + : StmtAST{NodeType::SEQ_STMT}, statements_(std::move(statements)) {} /** * AST visitor pattern. @@ -313,7 +343,7 @@ class DeclStmtAST : public StmtAST { * @param initial The initial value in the declaration */ DeclStmtAST(std::string name, sql::SqlTypeId type, std::unique_ptr &&initial) - : name_{std::move(name)}, type_(type), initial_{std::move(initial)} {} + : StmtAST{NodeType::DECL_STMT}, name_{std::move(name)}, type_(type), initial_{std::move(initial)} {} /** * AST visitor pattern. @@ -357,7 +387,10 @@ class IfStmtAST : public StmtAST { */ IfStmtAST(std::unique_ptr &&cond_expr, std::unique_ptr &&then_stmt, std::unique_ptr &&else_stmt) - : cond_expr_{std::move(cond_expr)}, then_stmt_{std::move(then_stmt)}, else_stmt_{std::move(else_stmt)} {} + : StmtAST{NodeType::IF_STMT}, + cond_expr_{std::move(cond_expr)}, + then_stmt_{std::move(then_stmt)}, + else_stmt_{std::move(else_stmt)} {} /** * AST visitor pattern. @@ -418,7 +451,8 @@ class ForIStmtAST : public StmtAST { */ ForIStmtAST(std::string variable, std::unique_ptr lower, std::unique_ptr upper, std::unique_ptr step, std::unique_ptr body) - : variable_{std::move(variable)}, + : StmtAST{NodeType::FORI_STMT}, + variable_{std::move(variable)}, lower_{std::move(lower)}, upper_{std::move(upper)}, step_{std::move(step)}, @@ -485,7 +519,10 @@ class ForSStmtAST : public StmtAST { */ ForSStmtAST(std::vector &&variables, std::unique_ptr &&query, std::unique_ptr body) - : variables_{std::move(variables)}, query_{std::move(query)}, body_{std::move(body)} {} + : StmtAST{NodeType::FORS_STMT}, + variables_{std::move(variables)}, + query_{std::move(query)}, + body_{std::move(body)} {} /** * AST visitor pattern. @@ -530,7 +567,7 @@ class WhileStmtAST : public StmtAST { * @param body The loop body statement */ WhileStmtAST(std::unique_ptr &&condition, std::unique_ptr &&body) - : condition_{std::move(condition)}, body_{std::move(body)} {} + : StmtAST{NodeType::WHILE_STMT}, condition_{std::move(condition)}, body_{std::move(body)} {} /** * AST visitor pattern. @@ -567,7 +604,8 @@ class RetStmtAST : public StmtAST { * Construct a new RetStmtAST instance. * @param ret_expr The `return` expression */ - explicit RetStmtAST(std::unique_ptr &&ret_expr) : ret_expr_{std::move(ret_expr)} {} + explicit RetStmtAST(std::unique_ptr &&ret_expr) + : StmtAST{NodeType::RET_STMT}, ret_expr_{std::move(ret_expr)} {} /** * AST visitor pattern. @@ -589,7 +627,7 @@ class RetStmtAST : public StmtAST { /** * The AssignStmtAST class represents an assignment statement. */ -class AssignStmtAST : public ExprAST { +class AssignStmtAST : public StmtAST { public: /** * Construct a new AssignStmtAST instance. @@ -597,7 +635,7 @@ class AssignStmtAST : public ExprAST { * @param src The expression that represents the source of the assignment */ AssignStmtAST(std::unique_ptr &&dst, std::unique_ptr &&src) - : dst_{std::move(dst)}, src_{std::move(src)} {} + : StmtAST{NodeType::ASSIGN_STMT}, dst_{std::move(dst)}, src_{std::move(src)} {} /** * AST visitor pattern. @@ -637,7 +675,7 @@ class SQLStmtAST : public StmtAST { * to which results of the query are bound */ SQLStmtAST(std::unique_ptr &&query, std::vector &&variables) - : query_{std::move(query)}, variables_{std::move(variables)} {} + : StmtAST{NodeType::SQL_STMT}, query_{std::move(query)}, variables_{std::move(variables)} {} /** * AST visitor pattern. @@ -673,7 +711,7 @@ class DynamicSQLStmtAST : public StmtAST { * @param name The name of the variable to which results are bound */ DynamicSQLStmtAST(std::unique_ptr &&query, std::string name) - : query_{std::move(query)}, name_{std::move(name)} {} + : StmtAST{NodeType::DYNAMIC_SQL_STMT}, query_{std::move(query)}, name_{std::move(name)} {} /** * AST visitor pattern. @@ -708,7 +746,8 @@ class FunctionAST : public AbstractAST { */ FunctionAST(std::unique_ptr &&body, std::vector parameter_names, std::vector parameter_types) - : body_{std::move(body)}, + : AbstractAST{NodeType::FUNCTION}, + body_{std::move(body)}, parameter_names_{std::move(parameter_names)}, parameter_types_{std::move(parameter_types)} { NOISEPAGE_ASSERT(parameter_names_.size() == parameter_types_.size(), "Parameter Name and Type Mismatch"); diff --git a/src/include/execution/compiler/udf/udf_codegen.h b/src/include/execution/compiler/udf/udf_codegen.h index 3d6bdf199d..368a78b59b 100644 --- a/src/include/execution/compiler/udf/udf_codegen.h +++ b/src/include/execution/compiler/udf/udf_codegen.h @@ -269,6 +269,13 @@ class UdfCodegen : ast::udf::ASTNodeVisitor { */ sql::SqlTypeId ResolveTypeForIdentifierExpression(const ast::IdentifierExpr *expr) const; + /** + * Resolve the type of a call expression in a function call argument. + * @param expr The call expression + * @return The resolved type of the call expression + */ + sql::SqlTypeId ResolveTypeForCallExpression(const ast::CallExpr *expr) const; + /* -------------------------------------------------------------------------- Code Generation: For-S Loops -------------------------------------------------------------------------- */ diff --git a/src/parser/udf/plpgsql_parser.cpp b/src/parser/udf/plpgsql_parser.cpp index 0c78a2fc42..9a7a30fb7f 100644 --- a/src/parser/udf/plpgsql_parser.cpp +++ b/src/parser/udf/plpgsql_parser.cpp @@ -442,6 +442,9 @@ std::optional> PLpgSQLParser::TryP } args.push_back(std::move(*argument)); } + for (const auto &a : args) { + std::cout << execution::ast::udf::NodeTypeToShortString(a->GetType()) << std::endl; + } return std::make_optional( std::make_unique(func_expr->GetFuncName(), std::move(args))); } From fb90d2c92351f8e187e5ad337f59de587bf8895d Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Sun, 26 Sep 2021 12:50:31 -0400 Subject: [PATCH 134/139] remove print --- src/parser/udf/plpgsql_parser.cpp | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/parser/udf/plpgsql_parser.cpp b/src/parser/udf/plpgsql_parser.cpp index 9a7a30fb7f..0c78a2fc42 100644 --- a/src/parser/udf/plpgsql_parser.cpp +++ b/src/parser/udf/plpgsql_parser.cpp @@ -442,9 +442,6 @@ std::optional> PLpgSQLParser::TryP } args.push_back(std::move(*argument)); } - for (const auto &a : args) { - std::cout << execution::ast::udf::NodeTypeToShortString(a->GetType()) << std::endl; - } return std::make_optional( std::make_unique(func_expr->GetFuncName(), std::move(args))); } From 5206e08842b3919a95e15a20c9c63c411fd07d79 Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Tue, 28 Sep 2021 09:33:16 -0400 Subject: [PATCH 135/139] fix issue that only manifests in release --- src/execution/ast/udf/udf_ast_nodes.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/src/execution/ast/udf/udf_ast_nodes.cpp b/src/execution/ast/udf/udf_ast_nodes.cpp index 4e9a3c7ee9..b4ac41615c 100644 --- a/src/execution/ast/udf/udf_ast_nodes.cpp +++ b/src/execution/ast/udf/udf_ast_nodes.cpp @@ -43,6 +43,7 @@ std::string NodeTypeToShortString(NodeType type) { return "FUNCTION"; default: NOISEPAGE_ASSERT(false, "Impossible node type"); + return "INVALID"; } } From 28b266ace0fc0b163bde29bc86134de6368d46bf Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Thu, 18 Nov 2021 11:59:16 -0500 Subject: [PATCH 136/139] test run with tpch runner --- CMakeLists.txt | 9 +++-- benchmark/runner/tpch_runner.cpp | 58 +++++++++++++++++++++++++------- 2 files changed, 49 insertions(+), 18 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 6f5b754caa..ede776c5c7 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -868,7 +868,7 @@ file(GLOB_RECURSE function(add_test_util_lib TYPE) string(TOLOWER ${TYPE} TYPE_LOWER) add_library(noisepage_test_util_${TYPE_LOWER} ${TYPE} ${NOISEPAGE_TEST_UTIL_SRCS}) - add_custom_command(TARGET noisepage_test_util_${TYPE_LOWER} DEPENDS gtest gtest_main gmock gmock_main) + add_custom_command(TARGET noisepage_test_util_${TYPE_LOWER} DEPENDS gtest gtest_main) target_compile_options(noisepage_test_util_${TYPE_LOWER} PRIVATE "-Werror" "-Wall") # Inject the source directory path into the translation units for test utility lib target_compile_definitions(noisepage_test_util_${TYPE_LOWER} PRIVATE NOISEPAGE_BUILD_ROOT=${CMAKE_BINARY_DIR}) @@ -878,7 +878,7 @@ function(add_test_util_lib TYPE) ${CMAKE_BINARY_DIR}/_deps/src/googletest/googletest/include/ ) target_link_libraries(noisepage_test_util_${TYPE_LOWER} PUBLIC - ${CMAKE_BINARY_DIR}/lib/libgtest.a ${CMAKE_BINARY_DIR}/lib/libgmock.a + ${CMAKE_BINARY_DIR}/lib/libgtest.a util_${TYPE_LOWER} pqxx) set_target_properties(noisepage_test_util_${TYPE_LOWER} PROPERTIES CXX_EXTENSIONS OFF UNITY_BUILD ${NOISEPAGE_UNITY_BUILD}) endfunction() @@ -917,7 +917,6 @@ function(add_noisepage_test add_executable(${TEST_NAME} ${EXCLUDE_OPTION} ${TEST_SOURCES}) target_compile_options(${TEST_NAME} PRIVATE "-Werror" "-Wall" "-fvisibility=hidden") - target_link_libraries(${TEST_NAME} PRIVATE ${CMAKE_BINARY_DIR}/lib/libgmock_main.a) if (${NOISEPAGE_ENABLE_SHARED}) target_link_libraries(${TEST_NAME} PRIVATE noisepage_test_util_shared) else () @@ -1119,10 +1118,10 @@ file(GLOB_RECURSE ) add_library(noisepage_benchmark_util STATIC ${NOISEPAGE_BENCHMARK_UTIL_SRCS}) -add_custom_command(TARGET noisepage_benchmark_util DEPENDS gtest gtest_main gmock gmock_main) +add_custom_command(TARGET noisepage_benchmark_util DEPENDS gtest gtest_main) target_compile_options(noisepage_benchmark_util PRIVATE "-Werror" "-Wall") target_include_directories(noisepage_benchmark_util PUBLIC ${PROJECT_SOURCE_DIR}/benchmark/include) -target_link_libraries(noisepage_benchmark_util PUBLIC ${CMAKE_BINARY_DIR}/lib/libgmock_main.a noisepage_test_util_static benchmark) +target_link_libraries(noisepage_benchmark_util PUBLIC noisepage_test_util_static benchmark) set_target_properties(noisepage_benchmark_util PROPERTIES CXX_EXTENSIONS OFF) set(NOISEPAGE_BENCHMARKS "") diff --git a/benchmark/runner/tpch_runner.cpp b/benchmark/runner/tpch_runner.cpp index 2e9e6bb239..680a10a8bb 100644 --- a/benchmark/runner/tpch_runner.cpp +++ b/benchmark/runner/tpch_runner.cpp @@ -7,25 +7,51 @@ #include "test_util/fs_util.h" #include "test_util/tpch/workload.h" +/** + * The local paths to the data directories. + * https://github.com/malin1993ml/tpl_tables and "bash gen_tpch.sh 0.1". + */ +static constexpr const char TPCH_TABLE_ROOT[] = "/home/turing/dev/tpl-tables/tables/"; +static constexpr const char SSB_TABLE_ROOT[] = "/home/turing/dev/tpl-tables/tables/"; +static constexpr const char TPCH_DATABASE_NAME[] = "tpch_runner_db"; + namespace noisepage::runner { + +/** + * TPCHRunner runs TPCH benchmarks. + */ class TPCHRunner : public benchmark::Fixture { public: - const int8_t total_num_threads_ = 4; // defines the number of terminals (workers threads) - const uint64_t execution_us_per_worker_ = 20000000; // Time (us) to run per terminal (worker thread) + /** Defines the number of terminals (workers threads) */ + const int8_t total_num_threads_ = 4; + + /** Time (us) to run per terminal (worker thread) */ + const uint64_t execution_us_per_worker_ = 20000000; + + /** The average intervals in microseconds */ std::vector avg_interval_us_ = {10, 20, 50, 100, 200, 500, 1000}; + + /** The execution mode for the execution engine */ const execution::vm::ExecutionMode mode_ = execution::vm::ExecutionMode::Interpret; - const bool single_test_run_ = false; + /** Flag indicating if only a single test run should be run */ + const bool single_test_run_ = true; + + /** The main database instance */ std::unique_ptr db_main_; + + /** The workload with loaded data and queries */ std::unique_ptr workload_; - // To get tpl_tables, https://github.com/malin1993ml/tpl_tables and "bash gen_tpch.sh 0.1". - const std::string tpch_table_root_ = "../../../tpl_tables/tables/"; - const std::string ssb_dir_ = "../../../SSB_Table_Generator/ssb_tables/"; - const std::string tpch_database_name_ = "tpch_runner_db"; + /** Local paths to data */ + const std::string tpch_table_root_{TPCH_TABLE_ROOT}; + const std::string ssb_dir_{SSB_TABLE_ROOT}; + const std::string tpch_database_name_{TPCH_DATABASE_NAME}; + /** The benchmark type */ tpch::Workload::BenchmarkType type_ = tpch::Workload::BenchmarkType::TPCH; + /** Setup the database instance for benchmark. */ void SetUp(const benchmark::State &state) final { auto db_main_builder = DBMain::Builder() .SetUseGC(true) @@ -38,7 +64,6 @@ class TPCHRunner : public benchmark::Fixture { .SetRecordBufferSegmentSize(1000000) .SetRecordBufferSegmentReuse(1000000) .SetBytecodeHandlersPath(common::GetBinaryArtifactPath("bytecode_handlers_ir.bc")); - db_main_ = db_main_builder.Build(); auto metrics_manager = db_main_->GetMetricsManager(); @@ -46,6 +71,7 @@ class TPCHRunner : public benchmark::Fixture { metrics_manager->EnableMetric(metrics::MetricsComponent::EXECUTION_PIPELINE); } + /** Teardown the database instance after a benchmark */ void TearDown(const benchmark::State &state) final { // free db main here so we don't need to use the loggers anymore db_main_.reset(); @@ -82,11 +108,13 @@ BENCHMARK_DEFINE_F(TPCHRunner, Runner)(benchmark::State &state) { } auto total_query_num = workload_->GetQueryNum() + 1; - for (uint32_t query_num = query_num_start; query_num < total_query_num; query_num += 4) - for (auto num_threads = num_thread_start; num_threads <= total_num_threads_; num_threads += 3) - for (uint32_t repeat = 0; repeat < repeat_num; ++repeat) + for (uint32_t query_num = query_num_start; query_num < total_query_num; query_num += 4) { + for (auto num_threads = num_thread_start; num_threads <= total_num_threads_; num_threads += 3) { + for (uint32_t repeat = 0; repeat < repeat_num; ++repeat) { for (auto avg_interval_us : avg_interval_us_) { - std::this_thread::sleep_for(std::chrono::seconds(2)); // Let GC clean up + // Let GC clean up + std::this_thread::sleep_for(std::chrono::seconds(2)); + common::WorkerPool thread_pool{static_cast(num_threads), {}}; thread_pool.Startup(); @@ -99,10 +127,14 @@ BENCHMARK_DEFINE_F(TPCHRunner, Runner)(benchmark::State &state) { thread_pool.WaitUntilAllFinished(); thread_pool.Shutdown(); } + } + } + } - // free the workload here so we don't need to use the loggers anymore + // Free the workload here so we don't need to use the loggers anymore workload_.reset(); } BENCHMARK_REGISTER_F(TPCHRunner, Runner)->Unit(benchmark::kMillisecond)->UseManualTime()->Iterations(1); + } // namespace noisepage::runner From eb4cc4b71db5cfd4e17d86057e31fd983af79359 Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Thu, 18 Nov 2021 22:41:56 -0500 Subject: [PATCH 137/139] setup for procbench runner --- benchmark/runner/procbench_runner.cpp | 124 +++++++++++++ .../execution/exec/execution_settings.h | 5 + .../test_util/procbench/procbench_query.h | 52 ++++++ test/include/test_util/procbench/workload.h | 99 ++++++++++ test/test_util/procbench/procbench_query.cpp | 166 +++++++++++++++++ test/test_util/procbench/workload.cpp | 170 ++++++++++++++++++ 6 files changed, 616 insertions(+) create mode 100644 benchmark/runner/procbench_runner.cpp create mode 100644 test/include/test_util/procbench/procbench_query.h create mode 100644 test/include/test_util/procbench/workload.h create mode 100644 test/test_util/procbench/procbench_query.cpp create mode 100644 test/test_util/procbench/workload.cpp diff --git a/benchmark/runner/procbench_runner.cpp b/benchmark/runner/procbench_runner.cpp new file mode 100644 index 0000000000..d4309fae50 --- /dev/null +++ b/benchmark/runner/procbench_runner.cpp @@ -0,0 +1,124 @@ +#include "benchmark/benchmark.h" +#include "common/scoped_timer.h" +#include "common/worker_pool.h" +#include "execution/execution_util.h" +#include "execution/vm/module.h" +#include "main/db_main.h" +#include "test_util/fs_util.h" +#include "test_util/procbench/workload.h" + +/** + * The local paths to the data directories. + * https://github.com/malin1993ml/tpl_tables and "bash gen_tpch.sh 0.1". + */ +static constexpr const char PROCBENCH_TABLE_ROOT[] = "/home/turing/dev/tpl-tables/tpcds-tables/"; +static constexpr const char PROCBENCH_DATABASE_NAME[] = "procbench_runner_db"; + +namespace noisepage::runner { + +/** + * ProcbenchRunner runs SQL ProcBench benchmarks. + */ +class ProcbenchRunner : public benchmark::Fixture { + public: + /** Defines the number of terminals (workers threads) */ + const int8_t total_num_threads_ = 1; + + /** Time (us) to run per terminal (worker thread) */ + const uint64_t execution_us_per_worker_ = 20000000; + + /** The average intervals in microseconds */ + std::vector avg_interval_us_ = {10, 20, 50, 100, 200, 500, 1000}; + + /** The execution mode for the execution engine */ + const execution::vm::ExecutionMode exec_mode_ = execution::vm::ExecutionMode::Interpret; + + /** Flag indicating if only a single test run should be run */ + const bool single_test_run_ = true; + + /** The main database instance */ + std::unique_ptr db_main_; + + /** The workload with loaded data and queries */ + std::unique_ptr workload_; + + /** Local paths to data */ + const std::string procbench_table_root_{PROCBENCH_TABLE_ROOT}; + const std::string procbench_database_name_{PROCBENCH_DATABASE_NAME}; + + /** Setup the database instance for benchmark. */ + void SetUp(const benchmark::State &state) final { + auto db_main_builder = DBMain::Builder() + .SetUseGC(true) + .SetUseCatalog(true) + .SetUseGCThread(true) + .SetUseMetrics(true) + .SetUseMetricsThread(true) + .SetBlockStoreSize(1000000) + .SetBlockStoreReuse(1000000) + .SetRecordBufferSegmentSize(1000000) + .SetRecordBufferSegmentReuse(1000000) + .SetBytecodeHandlersPath(common::GetBinaryArtifactPath("bytecode_handlers_ir.bc")); + db_main_ = db_main_builder.Build(); + + auto metrics_manager = db_main_->GetMetricsManager(); + metrics_manager->SetMetricSampleRate(metrics::MetricsComponent::EXECUTION_PIPELINE, 100); + metrics_manager->EnableMetric(metrics::MetricsComponent::EXECUTION_PIPELINE); + } + + /** Teardown the database instance after a benchmark */ + void TearDown(const benchmark::State &state) final { + // free db main here so we don't need to use the loggers anymore + db_main_.reset(); + } +}; + +// NOLINTNEXTLINE +BENCHMARK_DEFINE_F(ProcbenchRunner, Runner)(benchmark::State &state) { + // Load the ProcBench tables and compile the queries + workload_ = std::make_unique(common::ManagedPointer(db_main_), procbench_database_name_, + procbench_table_root_); + + int8_t num_thread_start; + uint32_t query_num_start, repeat_num; + if (single_test_run_) { + query_num_start = workload_->GetQueryCount(); + num_thread_start = total_num_threads_; + repeat_num = 1; + } else { + query_num_start = 1; + num_thread_start = 1; + repeat_num = 2; + } + + auto total_query_num = workload_->GetQueryCount() + 1; + for (uint32_t query_num = query_num_start; query_num < total_query_num; query_num += 4) { + for (auto num_threads = num_thread_start; num_threads <= total_num_threads_; num_threads += 3) { + for (uint32_t repeat = 0; repeat < repeat_num; ++repeat) { + for (auto avg_interval_us : avg_interval_us_) { + // Let GC clean up + std::this_thread::sleep_for(std::chrono::seconds(2)); + + common::WorkerPool thread_pool{static_cast(num_threads), {}}; + thread_pool.Startup(); + + for (int8_t i = 0; i < num_threads; i++) { + thread_pool.SubmitTask([this, i, avg_interval_us, query_num] { + workload_->Execute(i, execution_us_per_worker_, avg_interval_us, query_num, exec_mode_); + }); + } + + thread_pool.WaitUntilAllFinished(); + thread_pool.Shutdown(); + } + } + } + } + + // Free the workload here so we don't need to use the loggers anymore + workload_.reset(); +} + +BENCHMARK_REGISTER_F(ProcbenchRunner, Runner)->Unit(benchmark::kMillisecond)->UseManualTime()->Iterations(1); + +} // namespace noisepage::runner diff --git a/src/include/execution/exec/execution_settings.h b/src/include/execution/exec/execution_settings.h index a6f28a559f..84dfc0741a 100644 --- a/src/include/execution/exec/execution_settings.h +++ b/src/include/execution/exec/execution_settings.h @@ -32,6 +32,10 @@ namespace noisepage::tpch { class Workload; } // namespace noisepage::tpch +namespace noisepage::procbench { +class Workload; +} // namespace noisepage::procbench + namespace noisepage::selfdriving { namespace pilot { class PilotUtil; @@ -109,6 +113,7 @@ class EXPORT ExecutionSettings { // MiniRunners needs to set query_identifier and pipeline_operating_units_. friend class noisepage::runner::ExecutionRunners; friend class noisepage::tpch::Workload; + friend class noisepage::procbench::Workload; friend class noisepage::execution::SqlBasedTest; friend class noisepage::optimizer::IdxJoinTest_SimpleIdxJoinTest_Test; friend class noisepage::optimizer::IdxJoinTest_MultiPredicateJoin_Test; diff --git a/test/include/test_util/procbench/procbench_query.h b/test/include/test_util/procbench/procbench_query.h new file mode 100644 index 0000000000..1d8eac1c46 --- /dev/null +++ b/test/include/test_util/procbench/procbench_query.h @@ -0,0 +1,52 @@ +#pragma once + +#include +#include + +#include "catalog/catalog_accessor.h" +#include "execution/compiler/executable_query.h" + +namespace noisepage::procbench { + +/** ProcbenchQuery defines queries for SQL Procbench benchmarks. */ +class ProcbenchQuery { + public: + /// Static functions to generate executable queries for ProcBench benchmark. Query plans are hard coded. + // static std::tuple, + // std::unique_ptr> MakeExecutableQ1(const std::unique_ptr + // &accessor, + // const execution::exec::ExecutionSettings &exec_settings); + // static std::tuple, + // std::unique_ptr> MakeExecutableQ4(const std::unique_ptr + // &accessor, + // const execution::exec::ExecutionSettings &exec_settings); + // static std::tuple, + // std::unique_ptr> MakeExecutableQ5(const std::unique_ptr + // &accessor, + // const execution::exec::ExecutionSettings &exec_settings); + // static std::tuple, + // std::unique_ptr> MakeExecutableQ6(const std::unique_ptr + // &accessor, + // const execution::exec::ExecutionSettings &exec_settings); + // static std::tuple, + // std::unique_ptr> MakeExecutableQ7(const std::unique_ptr + // &accessor, + // const execution::exec::ExecutionSettings &exec_settings); + // static std::tuple, + // std::unique_ptr> MakeExecutableQ11(const std::unique_ptr + // &accessor, + // const execution::exec::ExecutionSettings &exec_settings); + // static std::tuple, + // std::unique_ptr> MakeExecutableQ16(const std::unique_ptr + // &accessor, + // const execution::exec::ExecutionSettings &exec_settings); + // static std::tuple, + // std::unique_ptr> MakeExecutableQ18(const std::unique_ptr + // &accessor, + // const execution::exec::ExecutionSettings &exec_settings); + // static std::tuple, + // std::unique_ptr> MakeExecutableQ19(const std::unique_ptr + // &accessor, + // const execution::exec::ExecutionSettings &exec_settings); +}; +} // namespace noisepage::procbench diff --git a/test/include/test_util/procbench/workload.h b/test/include/test_util/procbench/workload.h new file mode 100644 index 0000000000..16adf0ed1f --- /dev/null +++ b/test/include/test_util/procbench/workload.h @@ -0,0 +1,99 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include "catalog/catalog_accessor.h" +#include "catalog/catalog_defs.h" +#include "common/managed_pointer.h" +#include "execution/compiler/executable_query.h" +#include "execution/exec/execution_settings.h" +#include "execution/vm/module.h" + +namespace noisepage::execution::exec { +class ExecutionContext; +} + +namespace noisepage::catalog { +class Catalog; +} + +namespace noisepage::transaction { +class TransactionManager; +} + +namespace noisepage { +class DBMain; +} + +namespace noisepage::procbench { + +/** + * Class that can load the ProcBench tables, compile the + * ProcBench queries, and execute the ProcBench workload. + */ +class Workload { + public: + /** + * Construct a new Workload instance. + * @param db_main The database instance + * @param db_name The name of the database + * @param table_root The root of the table data directory + */ + Workload(common::ManagedPointer db_main, const std::string &db_name, const std::string &table_root); + + /** + * Function to invoke for a single worker thread to invoke the ProcBench queries. + * @param worker_id 1-indexed thread ID + * @param execution_us_per_worker + * @param avg_interval_us + * @param query_id The identifier for the query to invoke + * @param exec_mode The execution mode + */ + void Execute(int8_t worker_id, uint64_t execution_us_per_worker, uint64_t avg_interval_us, uint32_t query_id, + execution::vm::ExecutionMode exec_mode); + + /** @return The number of queries in the workload. */ + uint32_t GetQueryCount() { return query_and_plan_.size(); } + + private: + /** + * Load the tables for the ProcBench benchmark. + * @param exec_ctx The execution context + * @param directory The name of the directory from which tables are loaded + */ + void LoadTables(execution::exec::ExecutionContext *exec_ctx, const std::string &directory); + + /** + * Load the queries for the ProcBench benchmark. + * @param accessor The catalog accessor instance + */ + void LoadQueries(const std::unique_ptr &accessor); + + private: + /** The database server instance */ + common::ManagedPointer db_main_; + /** The block store */ + common::ManagedPointer block_store_; + /** The catalog instance */ + common::ManagedPointer catalog_; + /** The transaction manager */ + common::ManagedPointer txn_manager_; + /** The database OID */ + catalog::db_oid_t db_oid_; + /** The namespace OID */ + catalog::namespace_oid_t ns_oid_; + /** Execution settings for all executed queries */ + execution::exec::ExecutionSettings exec_settings_{}; + /** The catalog accessor */ + std::unique_ptr accessor_; + /** The collection of executable queries and associated plans */ + std::vector< + std::tuple, std::unique_ptr>> + query_and_plan_; +}; + +} // namespace noisepage::procbench diff --git a/test/test_util/procbench/procbench_query.cpp b/test/test_util/procbench/procbench_query.cpp new file mode 100644 index 0000000000..7db0e9b2b9 --- /dev/null +++ b/test/test_util/procbench/procbench_query.cpp @@ -0,0 +1,166 @@ +#include "test_util/procbench/procbench_query.h" + +#include "catalog/catalog_accessor.h" +#include "execution/compiler/compilation_context.h" +#include "execution/compiler/expression_maker.h" +#include "execution/compiler/output_schema_util.h" +#include "execution/sql/sql_def.h" +#include "planner/plannodes/aggregate_plan_node.h" +#include "planner/plannodes/hash_join_plan_node.h" +#include "planner/plannodes/nested_loop_join_plan_node.h" +#include "planner/plannodes/order_by_plan_node.h" +#include "planner/plannodes/seq_scan_plan_node.h" + +namespace noisepage::procbench { + +// std::tuple, std::unique_ptr> +// ProcbenchQuery::MakeExecutableQ1(const std::unique_ptr &accessor, +// const execution::exec::ExecutionSettings &exec_settings) { +// execution::compiler::test::ExpressionMaker expr_maker; +// auto table_oid = accessor->GetTableOid("lineitem"); +// const auto &l_schema = accessor->GetSchema(table_oid); +// // Scan the table +// std::unique_ptr l_seq_scan; +// execution::compiler::test::OutputSchemaHelper l_seq_scan_out{0, &expr_maker}; +// { +// // Read all needed columns +// auto l_returnflag = expr_maker.CVE(l_schema.GetColumn("l_returnflag").Oid(), execution::sql::SqlTypeId::Varchar); +// auto l_linestatus = expr_maker.CVE(l_schema.GetColumn("l_linestatus").Oid(), execution::sql::SqlTypeId::Varchar); +// auto l_extendedprice = +// expr_maker.CVE(l_schema.GetColumn("l_extendedprice").Oid(), execution::sql::SqlTypeId::Double); +// auto l_discount = expr_maker.CVE(l_schema.GetColumn("l_discount").Oid(), execution::sql::SqlTypeId::Double); +// auto l_tax = expr_maker.CVE(l_schema.GetColumn("l_tax").Oid(), execution::sql::SqlTypeId::Double); +// auto l_quantity = expr_maker.CVE(l_schema.GetColumn("l_quantity").Oid(), execution::sql::SqlTypeId::Double); +// auto l_shipdate = expr_maker.CVE(l_schema.GetColumn("l_shipdate").Oid(), execution::sql::SqlTypeId::Date); +// std::vector col_oids = { +// l_schema.GetColumn("l_returnflag").Oid(), l_schema.GetColumn("l_linestatus").Oid(), +// l_schema.GetColumn("l_extendedprice").Oid(), l_schema.GetColumn("l_discount").Oid(), +// l_schema.GetColumn("l_tax").Oid(), l_schema.GetColumn("l_quantity").Oid(), +// l_schema.GetColumn("l_shipdate").Oid()}; +// // Make the output schema +// l_seq_scan_out.AddOutput("l_returnflag", l_returnflag); +// l_seq_scan_out.AddOutput("l_linestatus", l_linestatus); +// l_seq_scan_out.AddOutput("l_extendedprice", l_extendedprice); +// l_seq_scan_out.AddOutput("l_discount", l_discount); +// l_seq_scan_out.AddOutput("l_tax", l_tax); +// l_seq_scan_out.AddOutput("l_quantity", l_quantity); +// auto schema = l_seq_scan_out.MakeSchema(); +// // Make the predicate +// l_seq_scan_out.AddOutput("l_shipdate", l_shipdate); +// auto date_const = expr_maker.Constant(1998, 9, 2); +// auto predicate = expr_maker.ComparisonLt(l_shipdate, date_const); +// // Build +// planner::SeqScanPlanNode::Builder builder; +// l_seq_scan = builder.SetOutputSchema(std::move(schema)) +// .SetScanPredicate(predicate) +// .SetTableOid(table_oid) +// .SetColumnOids(std::move(col_oids)) +// .Build(); +// } +// // Make the aggregate +// std::unique_ptr agg; +// execution::compiler::test::OutputSchemaHelper agg_out{0, &expr_maker}; +// { +// // Read previous layer's output +// auto l_returnflag = l_seq_scan_out.GetOutput("l_returnflag"); +// auto l_linestatus = l_seq_scan_out.GetOutput("l_linestatus"); +// auto l_quantity = l_seq_scan_out.GetOutput("l_quantity"); +// auto l_extendedprice = l_seq_scan_out.GetOutput("l_extendedprice"); +// auto l_discount = l_seq_scan_out.GetOutput("l_discount"); +// auto l_tax = l_seq_scan_out.GetOutput("l_tax"); +// // Make the aggregate expressions +// auto sum_qty = expr_maker.AggSum(l_quantity); +// auto sum_base_price = expr_maker.AggSum(l_extendedprice); +// auto one_const = expr_maker.Constant(1.0f); +// auto disc_price = expr_maker.OpMul(l_extendedprice, expr_maker.OpMin(one_const, l_discount)); +// auto sum_disc_price = expr_maker.AggSum(disc_price); +// auto charge = expr_maker.OpMul(disc_price, expr_maker.OpSum(one_const, l_tax)); +// auto sum_charge = expr_maker.AggSum(charge); +// auto avg_qty = expr_maker.AggAvg(l_quantity); +// auto avg_price = expr_maker.AggAvg(l_extendedprice); +// auto avg_disc = expr_maker.AggAvg(l_discount); +// auto count_order = expr_maker.AggCount(expr_maker.Constant(1)); // Works as Count(*) +// // Add them to the helper. +// agg_out.AddGroupByTerm("l_returnflag", l_returnflag); +// agg_out.AddGroupByTerm("l_linestatus", l_linestatus); +// agg_out.AddAggTerm("sum_qty", sum_qty); +// agg_out.AddAggTerm("sum_base_price", sum_base_price); +// agg_out.AddAggTerm("sum_disc_price", sum_disc_price); +// agg_out.AddAggTerm("sum_charge", sum_charge); +// agg_out.AddAggTerm("avg_qty", avg_qty); +// agg_out.AddAggTerm("avg_price", avg_price); +// agg_out.AddAggTerm("avg_disc", avg_disc); +// agg_out.AddAggTerm("count_order", count_order); +// // Make the output schema +// agg_out.AddOutput("l_returnflag", agg_out.GetGroupByTermForOutput("l_returnflag")); +// agg_out.AddOutput("l_linestatus", agg_out.GetGroupByTermForOutput("l_linestatus")); +// agg_out.AddOutput("sum_qty", agg_out.GetAggTermForOutput("sum_qty")); +// agg_out.AddOutput("sum_base_price", agg_out.GetAggTermForOutput("sum_base_price")); +// agg_out.AddOutput("sum_disc_price", agg_out.GetAggTermForOutput("sum_disc_price")); +// agg_out.AddOutput("sum_charge", agg_out.GetAggTermForOutput("sum_charge")); +// agg_out.AddOutput("avg_qty", agg_out.GetAggTermForOutput("avg_qty")); +// agg_out.AddOutput("avg_price", agg_out.GetAggTermForOutput("avg_price")); +// agg_out.AddOutput("avg_disc", agg_out.GetAggTermForOutput("avg_disc")); +// agg_out.AddOutput("count_order", agg_out.GetAggTermForOutput("count_order")); +// auto schema = agg_out.MakeSchema(); +// // Build +// planner::AggregatePlanNode::Builder builder; +// agg = builder.SetOutputSchema(std::move(schema)) +// .AddGroupByTerm(l_returnflag) +// .AddGroupByTerm(l_linestatus) +// .AddAggregateTerm(sum_qty) +// .AddAggregateTerm(sum_base_price) +// .AddAggregateTerm(sum_disc_price) +// .AddAggregateTerm(sum_charge) +// .AddAggregateTerm(avg_qty) +// .AddAggregateTerm(avg_price) +// .AddAggregateTerm(avg_disc) +// .AddAggregateTerm(count_order) +// .AddChild(std::move(l_seq_scan)) +// .SetAggregateStrategyType(planner::AggregateStrategyType::HASH) +// .SetHavingClausePredicate(nullptr) +// .Build(); +// } + +// // Order By +// std::unique_ptr order_by; +// execution::compiler::test::OutputSchemaHelper order_by_out{0, &expr_maker}; +// { +// // Output Colums col1, col2, col1 + col2 +// auto l_returnflag = agg_out.GetOutput("l_returnflag"); +// auto l_linestatus = agg_out.GetOutput("l_linestatus"); +// auto sum_qty = agg_out.GetOutput("sum_qty"); +// auto sum_base_price = agg_out.GetOutput("sum_base_price"); +// auto sum_disc_price = agg_out.GetOutput("sum_disc_price"); +// auto sum_charge = agg_out.GetOutput("sum_charge"); +// auto avg_qty = agg_out.GetOutput("avg_qty"); +// auto avg_price = agg_out.GetOutput("avg_price"); +// auto avg_disc = agg_out.GetOutput("avg_disc"); +// auto count_order = agg_out.GetOutput("count_order"); +// order_by_out.AddOutput("l_returnflag", l_returnflag); +// order_by_out.AddOutput("l_linestatus", l_linestatus); +// order_by_out.AddOutput("sum_qty", sum_qty); +// order_by_out.AddOutput("sum_base_price", sum_base_price); +// order_by_out.AddOutput("sum_disc_price", sum_disc_price); +// order_by_out.AddOutput("sum_charge", sum_charge); +// order_by_out.AddOutput("avg_qty", avg_qty); +// order_by_out.AddOutput("avg_price", avg_price); +// order_by_out.AddOutput("avg_disc", avg_disc); +// order_by_out.AddOutput("count_order", count_order); +// auto schema = order_by_out.MakeSchema(); +// // Order By Clause +// planner::SortKey clause1{l_returnflag, optimizer::OrderByOrderingType::ASC}; +// planner::SortKey clause2{l_linestatus, optimizer::OrderByOrderingType::ASC}; +// // Build +// planner::OrderByPlanNode::Builder builder; +// order_by = builder.SetOutputSchema(std::move(schema)) +// .AddChild(std::move(agg)) +// .AddSortKey(clause1.first, clause1.second) +// .AddSortKey(clause2.first, clause2.second) +// .Build(); +// } +// auto query = execution::compiler::CompilationContext::Compile(*order_by, exec_settings, accessor.get()); +// return std::make_tuple(std::move(query), std::move(order_by)); +// } + +} // namespace noisepage::procbench diff --git a/test/test_util/procbench/workload.cpp b/test/test_util/procbench/workload.cpp new file mode 100644 index 0000000000..9fb423ea8b --- /dev/null +++ b/test/test_util/procbench/workload.cpp @@ -0,0 +1,170 @@ +#include "test_util/procbench/workload.h" + +#include +#include +#include + +#include "common/managed_pointer.h" +#include "execution/compiler/output_schema_util.h" +#include "execution/exec/execution_context_builder.h" +#include "execution/sql/value_util.h" +#include "execution/table_generator/table_generator.h" +#include "main/db_main.h" +#include "planner/plannodes/aggregate_plan_node.h" +#include "planner/plannodes/hash_join_plan_node.h" +#include "planner/plannodes/nested_loop_join_plan_node.h" +#include "planner/plannodes/order_by_plan_node.h" +#include "planner/plannodes/seq_scan_plan_node.h" +#include "test_util/procbench/procbench_query.h" +#include "test_util/ssb/star_schema_query.h" + +namespace noisepage::procbench { + +/** ProcBench table names */ +static const std::vector PROCBENCH_TABLE_NAMES{"call_center", + "catalog_page", + "catalog_returns_history", + "catalog_returns", + "catalog_sales_history", + "catalog_sales", + "customer_address", + "customer_demographics", + "customer", + "date_dim", + "household_demographics", + "income_band", + "inventory_history", + "inventory", + "item", + "promotion", + "reason", + "ship_mode", + "store_returns_history", + "store_returns", + "store_sales_history", + "store_sales", + "store", + "time_dim", + "warehouse", + "web_page", + "web_returns_history", + "web_returns", + "web_sales_history", + "web_sales", + "web_site"}; + +Workload::Workload(common::ManagedPointer db_main, const std::string &db_name, const std::string &table_root) { + // cache db main and members + db_main_ = db_main; + txn_manager_ = db_main_->GetTransactionLayer()->GetTransactionManager(); + block_store_ = db_main_->GetStorageLayer()->GetBlockStore(); + catalog_ = db_main_->GetCatalogLayer()->GetCatalog(); + txn_manager_ = db_main_->GetTransactionLayer()->GetTransactionManager(); + + auto txn = txn_manager_->BeginTransaction(); + + // Create database catalog and namespace + db_oid_ = catalog_->CreateDatabase(common::ManagedPointer(txn), db_name, true); + auto accessor = + catalog_->GetAccessor(common::ManagedPointer(txn), db_oid_, DISABLED); + ns_oid_ = accessor->GetDefaultNamespace(); + + // Enable counters and disable the parallel execution for this workload + exec_settings_.is_parallel_execution_enabled_ = false; + exec_settings_.is_counters_enabled_ = true; + + // Make the execution context + auto exec_ctx = execution::exec::ExecutionContextBuilder() + .WithDatabaseOID(db_oid_) + .WithExecutionSettings(exec_settings_) + .WithTxnContext(common::ManagedPointer{txn}) + .WithOutputSchema(execution::exec::ExecutionContext::NULL_OUTPUT_SCHEMA) + .WithOutputCallback(execution::exec::ExecutionContext::NULL_OUTPUT_CALLBACK) + .WithCatalogAccessor(common::ManagedPointer{accessor}) + .WithMetricsManager(db_main->GetMetricsManager()) + .WithReplicationManager(DISABLED) + .WithRecoveryManager(DISABLED) + .Build(); + + // Create the ProcBench database + LoadTables(exec_ctx.get(), table_root); + // Compile all queries for the benchmark + LoadQueries(accessor); + + txn_manager_->Commit(txn, transaction::TransactionUtil::EmptyCallback, nullptr); +} + +void Workload::LoadTables(execution::exec::ExecutionContext *exec_ctx, const std::string &directory) { + execution::sql::TableReader table_reader{exec_ctx, block_store_.Get(), ns_oid_}; + for (const auto &table_name : PROCBENCH_TABLE_NAMES) { + const std::string schema_path = fmt::format("{}{}.schema", directory, table_name); + const std::string data_path = fmt::format("{}{}.data", directory, table_name); + const auto num_rows = table_reader.ReadTable(schema_path, data_path); + EXECUTION_LOG_INFO("Wrote {} rows on table {}.", num_rows, table_name); + } +} + +void Workload::LoadQueries(const std::unique_ptr &accessor) { + // Executable query and plan node are stored as a tuple as the entry of vector + (void)accessor; + // query_and_plan_.emplace_back(TPCHQuery::MakeExecutableQ1(accessor, exec_settings_)); +} + +void Workload::Execute(int8_t worker_id, uint64_t execution_us_per_worker, uint64_t avg_interval_us, uint32_t query_id, + execution::vm::ExecutionMode mode) { + // Shuffle the queries randomly for each thread + const auto total_query_num = query_and_plan_.size(); + std::vector index{}; + index.resize(total_query_num); + std::iota(index.begin(), index.end(), 0); + std::shuffle(index.begin(), index.end(), std::mt19937(time(nullptr) + worker_id)); + + // Get the sleep time range distribution + std::mt19937 generator{}; + std::uniform_int_distribution distribution(avg_interval_us - avg_interval_us / 2, + avg_interval_us + avg_interval_us / 2); + + // Register to the metrics manager + db_main_->GetMetricsManager()->RegisterThread(); + uint32_t counter = 0; + uint64_t end_time = metrics::MetricsUtil::Now() + execution_us_per_worker; + while (metrics::MetricsUtil::Now() < end_time) { + // Executing all the queries on by one in round robin + auto txn = txn_manager_->BeginTransaction(); + auto accessor = + catalog_->GetAccessor(common::ManagedPointer(txn), db_oid_, DISABLED); + + auto output_schema = std::get<1>(query_and_plan_[index[counter]])->GetOutputSchema().Get(); + // Uncomment this line and change output.cpp:90 to EXECUTION_LOG_INFO to print output + // execution::exec::OutputPrinter printer(output_schema); + execution::exec::NoOpResultConsumer printer; + + auto exec_ctx = execution::exec::ExecutionContextBuilder() + .WithDatabaseOID(db_oid_) + .WithExecutionSettings(exec_settings_) + .WithTxnContext(common::ManagedPointer{txn}) + .WithOutputSchema(common::ManagedPointer{output_schema}) + .WithOutputCallback(printer) + .WithCatalogAccessor(common::ManagedPointer{accessor}) + .WithMetricsManager(db_main_->GetMetricsManager()) + .WithReplicationManager(DISABLED) + .WithRecoveryManager(DISABLED) + .Build(); + + std::get<0>(query_and_plan_[index[counter]]) + ->Run(common::ManagedPointer(exec_ctx), mode); + + // Only execute up to query_num number of queries for this thread in round-robin + counter = counter == query_id - 1 ? 0 : counter + 1; + txn_manager_->Commit(txn, transaction::TransactionUtil::EmptyCallback, nullptr); + + // Sleep to create different execution frequency patterns + auto random_sleep_time = distribution(generator); + std::this_thread::sleep_for(std::chrono::microseconds(random_sleep_time)); + } + + // Unregister from the metrics manager + db_main_->GetMetricsManager()->UnregisterThread(); +} + +} // namespace noisepage::procbench From d729006ec7d3a2455d80f41e9b6becbf4664c659 Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Fri, 19 Nov 2021 11:05:04 -0500 Subject: [PATCH 138/139] slight modifications to table and schema readers, able to load all procbench tables --- benchmark/runner/procbench_runner.cpp | 76 ++++++++++--------- test/test_util/procbench/workload.cpp | 7 +- .../table_generator/table_reader.cpp | 4 +- .../execution/table_generator/schema_reader.h | 8 +- 4 files changed, 53 insertions(+), 42 deletions(-) diff --git a/benchmark/runner/procbench_runner.cpp b/benchmark/runner/procbench_runner.cpp index d4309fae50..e21a81319d 100644 --- a/benchmark/runner/procbench_runner.cpp +++ b/benchmark/runner/procbench_runner.cpp @@ -79,41 +79,47 @@ BENCHMARK_DEFINE_F(ProcbenchRunner, Runner)(benchmark::State &state) { workload_ = std::make_unique(common::ManagedPointer(db_main_), procbench_database_name_, procbench_table_root_); - int8_t num_thread_start; - uint32_t query_num_start, repeat_num; - if (single_test_run_) { - query_num_start = workload_->GetQueryCount(); - num_thread_start = total_num_threads_; - repeat_num = 1; - } else { - query_num_start = 1; - num_thread_start = 1; - repeat_num = 2; - } - - auto total_query_num = workload_->GetQueryCount() + 1; - for (uint32_t query_num = query_num_start; query_num < total_query_num; query_num += 4) { - for (auto num_threads = num_thread_start; num_threads <= total_num_threads_; num_threads += 3) { - for (uint32_t repeat = 0; repeat < repeat_num; ++repeat) { - for (auto avg_interval_us : avg_interval_us_) { - // Let GC clean up - std::this_thread::sleep_for(std::chrono::seconds(2)); - - common::WorkerPool thread_pool{static_cast(num_threads), {}}; - thread_pool.Startup(); - - for (int8_t i = 0; i < num_threads; i++) { - thread_pool.SubmitTask([this, i, avg_interval_us, query_num] { - workload_->Execute(i, execution_us_per_worker_, avg_interval_us, query_num, exec_mode_); - }); - } - - thread_pool.WaitUntilAllFinished(); - thread_pool.Shutdown(); - } - } - } - } + // int8_t num_thread_start; + // uint32_t query_num_start, repeat_num; + // if (single_test_run_) { + // query_num_start = workload_->GetQueryCount(); + // num_thread_start = total_num_threads_; + // repeat_num = 1; + // } else { + // query_num_start = 1; + // num_thread_start = 1; + // repeat_num = 2; + // } + + // auto total_query_num = workload_->GetQueryCount() + 1; + // for (uint32_t query_num = query_num_start; query_num < total_query_num; query_num += 4) { + // for (auto num_threads = num_thread_start; num_threads <= total_num_threads_; num_threads += 3) { + // for (uint32_t repeat = 0; repeat < repeat_num; ++repeat) { + // for (auto avg_interval_us : avg_interval_us_) { + // // Let GC clean up + // std::this_thread::sleep_for(std::chrono::seconds(2)); + + // common::WorkerPool thread_pool{static_cast(num_threads), {}}; + // thread_pool.Startup(); + + // for (int8_t i = 0; i < num_threads; i++) { + // thread_pool.SubmitTask([this, i, avg_interval_us, query_num] { + // workload_->Execute(i, execution_us_per_worker_, avg_interval_us, query_num, exec_mode_); + // }); + // } + + // thread_pool.WaitUntilAllFinished(); + // thread_pool.Shutdown(); + // } + // } + // } + // } + + const auto start = std::chrono::high_resolution_clock::now(); + std::this_thread::sleep_for(std::chrono::seconds{1}); + const auto stop = std::chrono::high_resolution_clock::now(); + const auto duration = std::chrono::duration_cast(stop - start).count(); + state.SetIterationTime(duration); // Free the workload here so we don't need to use the loggers anymore workload_.reset(); diff --git a/test/test_util/procbench/workload.cpp b/test/test_util/procbench/workload.cpp index 9fb423ea8b..1ecf6dee4d 100644 --- a/test/test_util/procbench/workload.cpp +++ b/test/test_util/procbench/workload.cpp @@ -16,7 +16,6 @@ #include "planner/plannodes/order_by_plan_node.h" #include "planner/plannodes/seq_scan_plan_node.h" #include "test_util/procbench/procbench_query.h" -#include "test_util/ssb/star_schema_query.h" namespace noisepage::procbench { @@ -95,19 +94,23 @@ Workload::Workload(common::ManagedPointer db_main, const std::string &db } void Workload::LoadTables(execution::exec::ExecutionContext *exec_ctx, const std::string &directory) { + EXECUTION_LOG_INFO("Loading tables for ProcBench benchmark..."); execution::sql::TableReader table_reader{exec_ctx, block_store_.Get(), ns_oid_}; for (const auto &table_name : PROCBENCH_TABLE_NAMES) { - const std::string schema_path = fmt::format("{}{}.schema", directory, table_name); const std::string data_path = fmt::format("{}{}.data", directory, table_name); + const std::string schema_path = fmt::format("{}{}.schema", directory, table_name); const auto num_rows = table_reader.ReadTable(schema_path, data_path); EXECUTION_LOG_INFO("Wrote {} rows on table {}.", num_rows, table_name); } + EXECUTION_LOG_INFO("Done."); } void Workload::LoadQueries(const std::unique_ptr &accessor) { + EXECUTION_LOG_INFO("Loading queries for ProcBench benchmark..."); // Executable query and plan node are stored as a tuple as the entry of vector (void)accessor; // query_and_plan_.emplace_back(TPCHQuery::MakeExecutableQ1(accessor, exec_settings_)); + EXECUTION_LOG_INFO("Done."); } void Workload::Execute(int8_t worker_id, uint64_t execution_us_per_worker, uint64_t avg_interval_us, uint32_t query_id, diff --git a/util/execution/table_generator/table_reader.cpp b/util/execution/table_generator/table_reader.cpp index ca5306b097..124ce15679 100644 --- a/util/execution/table_generator/table_reader.cpp +++ b/util/execution/table_generator/table_reader.cpp @@ -138,7 +138,7 @@ void TableReader::WriteIndexEntry(IndexInfo *index_info, storage::ProjectedRow * void TableReader::WriteTableCol(storage::ProjectedRow *insert_pr, uint16_t col_offset, execution::sql::SqlTypeId type, csv::CSVField *field) { - if (*field == NULL_STRING) { + if (*field == NULL_STRING || field->is_null()) { insert_pr->SetNull(col_offset); return; } @@ -190,7 +190,7 @@ void TableReader::WriteTableCol(storage::ProjectedRow *insert_pr, uint16_t col_o break; } default: - UNREACHABLE("Unsupported type. Add it here first!!!"); + UNREACHABLE("Unsupported type."); } } diff --git a/util/include/execution/table_generator/schema_reader.h b/util/include/execution/table_generator/schema_reader.h index d28368b927..78e4fc69fd 100644 --- a/util/include/execution/table_generator/schema_reader.h +++ b/util/include/execution/table_generator/schema_reader.h @@ -104,10 +104,12 @@ class SchemaReader { */ SchemaReader() : type_names_{{"tinyint", execution::sql::SqlTypeId::TinyInt}, {"smallint", execution::sql::SqlTypeId::SmallInt}, - {"int", execution::sql::SqlTypeId::Integer}, {"bigint", execution::sql::SqlTypeId::BigInt}, - {"bool", execution::sql::SqlTypeId::Boolean}, {"real", execution::sql::SqlTypeId::Double}, + {"integer", execution::sql::SqlTypeId::Integer}, {"int", execution::sql::SqlTypeId::Integer}, + {"bigint", execution::sql::SqlTypeId::BigInt}, {"bool", execution::sql::SqlTypeId::Boolean}, + {"real", execution::sql::SqlTypeId::Double}, {"float8", execution::sql::SqlTypeId::Double}, {"decimal", execution::sql::SqlTypeId::Double}, {"varchar", execution::sql::SqlTypeId::Varchar}, - {"varlen", execution::sql::SqlTypeId::Varchar}, {"date", execution::sql::SqlTypeId::Date}} {} + {"char", execution::sql::SqlTypeId::Char}, {"varlen", execution::sql::SqlTypeId::Varchar}, + {"date", execution::sql::SqlTypeId::Date}} {} /** * Reads table metadata From 0855600b16de6daa2a3b9f4d9fdbf56d8ad9c87d Mon Sep 17 00:00:00 2001 From: turingcompl33t Date: Thu, 9 Dec 2021 18:05:52 -0500 Subject: [PATCH 139/139] got procbench running --- benchmark/runner/procbench_runner.cpp | 54 +----- .../execution/compiler/expression_maker.h | 7 + .../test_util/procbench/procbench_query.h | 41 +--- test/include/test_util/procbench/workload.h | 16 +- test/test_util/procbench/procbench_query.cpp | 179 +++--------------- test/test_util/procbench/workload.cpp | 94 +++++---- 6 files changed, 100 insertions(+), 291 deletions(-) diff --git a/benchmark/runner/procbench_runner.cpp b/benchmark/runner/procbench_runner.cpp index e21a81319d..b479b13999 100644 --- a/benchmark/runner/procbench_runner.cpp +++ b/benchmark/runner/procbench_runner.cpp @@ -21,21 +21,9 @@ namespace noisepage::runner { */ class ProcbenchRunner : public benchmark::Fixture { public: - /** Defines the number of terminals (workers threads) */ - const int8_t total_num_threads_ = 1; - - /** Time (us) to run per terminal (worker thread) */ - const uint64_t execution_us_per_worker_ = 20000000; - - /** The average intervals in microseconds */ - std::vector avg_interval_us_ = {10, 20, 50, 100, 200, 500, 1000}; - /** The execution mode for the execution engine */ const execution::vm::ExecutionMode exec_mode_ = execution::vm::ExecutionMode::Interpret; - /** Flag indicating if only a single test run should be run */ - const bool single_test_run_ = true; - /** The main database instance */ std::unique_ptr db_main_; @@ -79,46 +67,14 @@ BENCHMARK_DEFINE_F(ProcbenchRunner, Runner)(benchmark::State &state) { workload_ = std::make_unique(common::ManagedPointer(db_main_), procbench_database_name_, procbench_table_root_); - // int8_t num_thread_start; - // uint32_t query_num_start, repeat_num; - // if (single_test_run_) { - // query_num_start = workload_->GetQueryCount(); - // num_thread_start = total_num_threads_; - // repeat_num = 1; - // } else { - // query_num_start = 1; - // num_thread_start = 1; - // repeat_num = 2; - // } - - // auto total_query_num = workload_->GetQueryCount() + 1; - // for (uint32_t query_num = query_num_start; query_num < total_query_num; query_num += 4) { - // for (auto num_threads = num_thread_start; num_threads <= total_num_threads_; num_threads += 3) { - // for (uint32_t repeat = 0; repeat < repeat_num; ++repeat) { - // for (auto avg_interval_us : avg_interval_us_) { - // // Let GC clean up - // std::this_thread::sleep_for(std::chrono::seconds(2)); - - // common::WorkerPool thread_pool{static_cast(num_threads), {}}; - // thread_pool.Startup(); - - // for (int8_t i = 0; i < num_threads; i++) { - // thread_pool.SubmitTask([this, i, avg_interval_us, query_num] { - // workload_->Execute(i, execution_us_per_worker_, avg_interval_us, query_num, exec_mode_); - // }); - // } - - // thread_pool.WaitUntilAllFinished(); - // thread_pool.Shutdown(); - // } - // } - // } - // } - const auto start = std::chrono::high_resolution_clock::now(); - std::this_thread::sleep_for(std::chrono::seconds{1}); + + // Execute the workload + workload_->Execute(6, execution::vm::ExecutionMode::Interpret); + const auto stop = std::chrono::high_resolution_clock::now(); const auto duration = std::chrono::duration_cast(stop - start).count(); + state.SetIterationTime(duration); // Free the workload here so we don't need to use the loggers anymore diff --git a/test/include/execution/compiler/expression_maker.h b/test/include/execution/compiler/expression_maker.h index da629e6b71..95b7e3413d 100644 --- a/test/include/execution/compiler/expression_maker.h +++ b/test/include/execution/compiler/expression_maker.h @@ -95,6 +95,13 @@ class ExpressionMaker { return MakeManaged(std::make_unique(catalog::table_oid_t(0), column_oid, type)); } + /** + * Create a column value expression + */ + ManagedExpression CVE(catalog::table_oid_t table_oid, catalog::col_oid_t column_oid, execution::sql::SqlTypeId type) { + return MakeManaged(std::make_unique(table_oid, column_oid, type)); + } + /** * Create a derived value expression */ diff --git a/test/include/test_util/procbench/procbench_query.h b/test/include/test_util/procbench/procbench_query.h index 1d8eac1c46..f4110f2c50 100644 --- a/test/include/test_util/procbench/procbench_query.h +++ b/test/include/test_util/procbench/procbench_query.h @@ -11,42 +11,9 @@ namespace noisepage::procbench { /** ProcbenchQuery defines queries for SQL Procbench benchmarks. */ class ProcbenchQuery { public: - /// Static functions to generate executable queries for ProcBench benchmark. Query plans are hard coded. - // static std::tuple, - // std::unique_ptr> MakeExecutableQ1(const std::unique_ptr - // &accessor, - // const execution::exec::ExecutionSettings &exec_settings); - // static std::tuple, - // std::unique_ptr> MakeExecutableQ4(const std::unique_ptr - // &accessor, - // const execution::exec::ExecutionSettings &exec_settings); - // static std::tuple, - // std::unique_ptr> MakeExecutableQ5(const std::unique_ptr - // &accessor, - // const execution::exec::ExecutionSettings &exec_settings); - // static std::tuple, - // std::unique_ptr> MakeExecutableQ6(const std::unique_ptr - // &accessor, - // const execution::exec::ExecutionSettings &exec_settings); - // static std::tuple, - // std::unique_ptr> MakeExecutableQ7(const std::unique_ptr - // &accessor, - // const execution::exec::ExecutionSettings &exec_settings); - // static std::tuple, - // std::unique_ptr> MakeExecutableQ11(const std::unique_ptr - // &accessor, - // const execution::exec::ExecutionSettings &exec_settings); - // static std::tuple, - // std::unique_ptr> MakeExecutableQ16(const std::unique_ptr - // &accessor, - // const execution::exec::ExecutionSettings &exec_settings); - // static std::tuple, - // std::unique_ptr> MakeExecutableQ18(const std::unique_ptr - // &accessor, - // const execution::exec::ExecutionSettings &exec_settings); - // static std::tuple, - // std::unique_ptr> MakeExecutableQ19(const std::unique_ptr - // &accessor, - // const execution::exec::ExecutionSettings &exec_settings); + // Static functions to generate executable queries for ProcBench benchmark. Query plans are hard coded. + static std::tuple, std::unique_ptr> + MakeExecutableQ6(const std::unique_ptr &accessor, + const execution::exec::ExecutionSettings &exec_settings); }; } // namespace noisepage::procbench diff --git a/test/include/test_util/procbench/workload.h b/test/include/test_util/procbench/workload.h index 16adf0ed1f..695f742fc9 100644 --- a/test/include/test_util/procbench/workload.h +++ b/test/include/test_util/procbench/workload.h @@ -47,14 +47,9 @@ class Workload { /** * Function to invoke for a single worker thread to invoke the ProcBench queries. - * @param worker_id 1-indexed thread ID - * @param execution_us_per_worker - * @param avg_interval_us - * @param query_id The identifier for the query to invoke * @param exec_mode The execution mode */ - void Execute(int8_t worker_id, uint64_t execution_us_per_worker, uint64_t avg_interval_us, uint32_t query_id, - execution::vm::ExecutionMode exec_mode); + void Execute(std::size_t query_number, execution::vm::ExecutionMode exec_mode); /** @return The number of queries in the workload. */ uint32_t GetQueryCount() { return query_and_plan_.size(); } @@ -73,6 +68,13 @@ class Workload { */ void LoadQueries(const std::unique_ptr &accessor); + /** + * Get the index for the specified query number + * @param query_number The query number + * @return The index + */ + std::size_t QueryNumberToIndex(std::size_t query_number) const; + private: /** The database server instance */ common::ManagedPointer db_main_; @@ -94,6 +96,8 @@ class Workload { std::vector< std::tuple, std::unique_ptr>> query_and_plan_; + /** Translate a query number of corresponding index */ + std::unordered_map query_number_to_index_; }; } // namespace noisepage::procbench diff --git a/test/test_util/procbench/procbench_query.cpp b/test/test_util/procbench/procbench_query.cpp index 7db0e9b2b9..336b4cba76 100644 --- a/test/test_util/procbench/procbench_query.cpp +++ b/test/test_util/procbench/procbench_query.cpp @@ -13,154 +13,37 @@ namespace noisepage::procbench { -// std::tuple, std::unique_ptr> -// ProcbenchQuery::MakeExecutableQ1(const std::unique_ptr &accessor, -// const execution::exec::ExecutionSettings &exec_settings) { -// execution::compiler::test::ExpressionMaker expr_maker; -// auto table_oid = accessor->GetTableOid("lineitem"); -// const auto &l_schema = accessor->GetSchema(table_oid); -// // Scan the table -// std::unique_ptr l_seq_scan; -// execution::compiler::test::OutputSchemaHelper l_seq_scan_out{0, &expr_maker}; -// { -// // Read all needed columns -// auto l_returnflag = expr_maker.CVE(l_schema.GetColumn("l_returnflag").Oid(), execution::sql::SqlTypeId::Varchar); -// auto l_linestatus = expr_maker.CVE(l_schema.GetColumn("l_linestatus").Oid(), execution::sql::SqlTypeId::Varchar); -// auto l_extendedprice = -// expr_maker.CVE(l_schema.GetColumn("l_extendedprice").Oid(), execution::sql::SqlTypeId::Double); -// auto l_discount = expr_maker.CVE(l_schema.GetColumn("l_discount").Oid(), execution::sql::SqlTypeId::Double); -// auto l_tax = expr_maker.CVE(l_schema.GetColumn("l_tax").Oid(), execution::sql::SqlTypeId::Double); -// auto l_quantity = expr_maker.CVE(l_schema.GetColumn("l_quantity").Oid(), execution::sql::SqlTypeId::Double); -// auto l_shipdate = expr_maker.CVE(l_schema.GetColumn("l_shipdate").Oid(), execution::sql::SqlTypeId::Date); -// std::vector col_oids = { -// l_schema.GetColumn("l_returnflag").Oid(), l_schema.GetColumn("l_linestatus").Oid(), -// l_schema.GetColumn("l_extendedprice").Oid(), l_schema.GetColumn("l_discount").Oid(), -// l_schema.GetColumn("l_tax").Oid(), l_schema.GetColumn("l_quantity").Oid(), -// l_schema.GetColumn("l_shipdate").Oid()}; -// // Make the output schema -// l_seq_scan_out.AddOutput("l_returnflag", l_returnflag); -// l_seq_scan_out.AddOutput("l_linestatus", l_linestatus); -// l_seq_scan_out.AddOutput("l_extendedprice", l_extendedprice); -// l_seq_scan_out.AddOutput("l_discount", l_discount); -// l_seq_scan_out.AddOutput("l_tax", l_tax); -// l_seq_scan_out.AddOutput("l_quantity", l_quantity); -// auto schema = l_seq_scan_out.MakeSchema(); -// // Make the predicate -// l_seq_scan_out.AddOutput("l_shipdate", l_shipdate); -// auto date_const = expr_maker.Constant(1998, 9, 2); -// auto predicate = expr_maker.ComparisonLt(l_shipdate, date_const); -// // Build -// planner::SeqScanPlanNode::Builder builder; -// l_seq_scan = builder.SetOutputSchema(std::move(schema)) -// .SetScanPredicate(predicate) -// .SetTableOid(table_oid) -// .SetColumnOids(std::move(col_oids)) -// .Build(); -// } -// // Make the aggregate -// std::unique_ptr agg; -// execution::compiler::test::OutputSchemaHelper agg_out{0, &expr_maker}; -// { -// // Read previous layer's output -// auto l_returnflag = l_seq_scan_out.GetOutput("l_returnflag"); -// auto l_linestatus = l_seq_scan_out.GetOutput("l_linestatus"); -// auto l_quantity = l_seq_scan_out.GetOutput("l_quantity"); -// auto l_extendedprice = l_seq_scan_out.GetOutput("l_extendedprice"); -// auto l_discount = l_seq_scan_out.GetOutput("l_discount"); -// auto l_tax = l_seq_scan_out.GetOutput("l_tax"); -// // Make the aggregate expressions -// auto sum_qty = expr_maker.AggSum(l_quantity); -// auto sum_base_price = expr_maker.AggSum(l_extendedprice); -// auto one_const = expr_maker.Constant(1.0f); -// auto disc_price = expr_maker.OpMul(l_extendedprice, expr_maker.OpMin(one_const, l_discount)); -// auto sum_disc_price = expr_maker.AggSum(disc_price); -// auto charge = expr_maker.OpMul(disc_price, expr_maker.OpSum(one_const, l_tax)); -// auto sum_charge = expr_maker.AggSum(charge); -// auto avg_qty = expr_maker.AggAvg(l_quantity); -// auto avg_price = expr_maker.AggAvg(l_extendedprice); -// auto avg_disc = expr_maker.AggAvg(l_discount); -// auto count_order = expr_maker.AggCount(expr_maker.Constant(1)); // Works as Count(*) -// // Add them to the helper. -// agg_out.AddGroupByTerm("l_returnflag", l_returnflag); -// agg_out.AddGroupByTerm("l_linestatus", l_linestatus); -// agg_out.AddAggTerm("sum_qty", sum_qty); -// agg_out.AddAggTerm("sum_base_price", sum_base_price); -// agg_out.AddAggTerm("sum_disc_price", sum_disc_price); -// agg_out.AddAggTerm("sum_charge", sum_charge); -// agg_out.AddAggTerm("avg_qty", avg_qty); -// agg_out.AddAggTerm("avg_price", avg_price); -// agg_out.AddAggTerm("avg_disc", avg_disc); -// agg_out.AddAggTerm("count_order", count_order); -// // Make the output schema -// agg_out.AddOutput("l_returnflag", agg_out.GetGroupByTermForOutput("l_returnflag")); -// agg_out.AddOutput("l_linestatus", agg_out.GetGroupByTermForOutput("l_linestatus")); -// agg_out.AddOutput("sum_qty", agg_out.GetAggTermForOutput("sum_qty")); -// agg_out.AddOutput("sum_base_price", agg_out.GetAggTermForOutput("sum_base_price")); -// agg_out.AddOutput("sum_disc_price", agg_out.GetAggTermForOutput("sum_disc_price")); -// agg_out.AddOutput("sum_charge", agg_out.GetAggTermForOutput("sum_charge")); -// agg_out.AddOutput("avg_qty", agg_out.GetAggTermForOutput("avg_qty")); -// agg_out.AddOutput("avg_price", agg_out.GetAggTermForOutput("avg_price")); -// agg_out.AddOutput("avg_disc", agg_out.GetAggTermForOutput("avg_disc")); -// agg_out.AddOutput("count_order", agg_out.GetAggTermForOutput("count_order")); -// auto schema = agg_out.MakeSchema(); -// // Build -// planner::AggregatePlanNode::Builder builder; -// agg = builder.SetOutputSchema(std::move(schema)) -// .AddGroupByTerm(l_returnflag) -// .AddGroupByTerm(l_linestatus) -// .AddAggregateTerm(sum_qty) -// .AddAggregateTerm(sum_base_price) -// .AddAggregateTerm(sum_disc_price) -// .AddAggregateTerm(sum_charge) -// .AddAggregateTerm(avg_qty) -// .AddAggregateTerm(avg_price) -// .AddAggregateTerm(avg_disc) -// .AddAggregateTerm(count_order) -// .AddChild(std::move(l_seq_scan)) -// .SetAggregateStrategyType(planner::AggregateStrategyType::HASH) -// .SetHavingClausePredicate(nullptr) -// .Build(); -// } +std::tuple, std::unique_ptr> +ProcbenchQuery::MakeExecutableQ6(const std::unique_ptr &accessor, + const execution::exec::ExecutionSettings &exec_settings) { + execution::compiler::test::ExpressionMaker expr_maker; + const auto web_sales_history_oid = accessor->GetTableOid("web_sales_history"); + const auto &web_sales_history_schema = accessor->GetSchema(web_sales_history_oid); -// // Order By -// std::unique_ptr order_by; -// execution::compiler::test::OutputSchemaHelper order_by_out{0, &expr_maker}; -// { -// // Output Colums col1, col2, col1 + col2 -// auto l_returnflag = agg_out.GetOutput("l_returnflag"); -// auto l_linestatus = agg_out.GetOutput("l_linestatus"); -// auto sum_qty = agg_out.GetOutput("sum_qty"); -// auto sum_base_price = agg_out.GetOutput("sum_base_price"); -// auto sum_disc_price = agg_out.GetOutput("sum_disc_price"); -// auto sum_charge = agg_out.GetOutput("sum_charge"); -// auto avg_qty = agg_out.GetOutput("avg_qty"); -// auto avg_price = agg_out.GetOutput("avg_price"); -// auto avg_disc = agg_out.GetOutput("avg_disc"); -// auto count_order = agg_out.GetOutput("count_order"); -// order_by_out.AddOutput("l_returnflag", l_returnflag); -// order_by_out.AddOutput("l_linestatus", l_linestatus); -// order_by_out.AddOutput("sum_qty", sum_qty); -// order_by_out.AddOutput("sum_base_price", sum_base_price); -// order_by_out.AddOutput("sum_disc_price", sum_disc_price); -// order_by_out.AddOutput("sum_charge", sum_charge); -// order_by_out.AddOutput("avg_qty", avg_qty); -// order_by_out.AddOutput("avg_price", avg_price); -// order_by_out.AddOutput("avg_disc", avg_disc); -// order_by_out.AddOutput("count_order", count_order); -// auto schema = order_by_out.MakeSchema(); -// // Order By Clause -// planner::SortKey clause1{l_returnflag, optimizer::OrderByOrderingType::ASC}; -// planner::SortKey clause2{l_linestatus, optimizer::OrderByOrderingType::ASC}; -// // Build -// planner::OrderByPlanNode::Builder builder; -// order_by = builder.SetOutputSchema(std::move(schema)) -// .AddChild(std::move(agg)) -// .AddSortKey(clause1.first, clause1.second) -// .AddSortKey(clause2.first, clause2.second) -// .Build(); -// } -// auto query = execution::compiler::CompilationContext::Compile(*order_by, exec_settings, accessor.get()); -// return std::make_tuple(std::move(query), std::move(order_by)); -// } + // Scan the table + std::unique_ptr seq_scan; + execution::compiler::test::OutputSchemaHelper seq_scan_out{0, &expr_maker}; + { + // Read all needed columns + auto ws_sold_date = + expr_maker.CVE(web_sales_history_oid, web_sales_history_schema.GetColumn("ws_sold_date_sk").Oid(), + execution::sql::SqlTypeId::Integer); + std::vector col_oids = {web_sales_history_schema.GetColumn("ws_sold_date_sk").Oid()}; + + // Make the output schema + seq_scan_out.AddOutput("ws_sold_date", ws_sold_date); + auto schema = seq_scan_out.MakeSchema(); + + // Build + planner::SeqScanPlanNode::Builder builder; + seq_scan = builder.SetOutputSchema(std::move(schema)) + .SetScanPredicate(nullptr) + .SetTableOid(web_sales_history_oid) + .SetColumnOids(std::move(col_oids)) + .Build(); + } + auto query = execution::compiler::CompilationContext::Compile(*seq_scan, exec_settings, accessor.get()); + return std::make_tuple(std::move(query), std::move(seq_scan)); +} } // namespace noisepage::procbench diff --git a/test/test_util/procbench/workload.cpp b/test/test_util/procbench/workload.cpp index 1ecf6dee4d..8d7d900461 100644 --- a/test/test_util/procbench/workload.cpp +++ b/test/test_util/procbench/workload.cpp @@ -19,6 +19,9 @@ namespace noisepage::procbench { +/** Query identifiers */ +static constexpr const std::size_t Q6_ID = 6; + /** ProcBench table names */ static const std::vector PROCBENCH_TABLE_NAMES{"call_center", "catalog_page", @@ -107,67 +110,56 @@ void Workload::LoadTables(execution::exec::ExecutionContext *exec_ctx, const std void Workload::LoadQueries(const std::unique_ptr &accessor) { EXECUTION_LOG_INFO("Loading queries for ProcBench benchmark..."); + // Executable query and plan node are stored as a tuple as the entry of vector - (void)accessor; - // query_and_plan_.emplace_back(TPCHQuery::MakeExecutableQ1(accessor, exec_settings_)); + query_and_plan_.emplace_back(ProcbenchQuery::MakeExecutableQ6(accessor, exec_settings_)); + query_number_to_index_[Q6_ID] = query_and_plan_.size() - 1; + EXECUTION_LOG_INFO("Done."); } -void Workload::Execute(int8_t worker_id, uint64_t execution_us_per_worker, uint64_t avg_interval_us, uint32_t query_id, - execution::vm::ExecutionMode mode) { - // Shuffle the queries randomly for each thread - const auto total_query_num = query_and_plan_.size(); - std::vector index{}; - index.resize(total_query_num); - std::iota(index.begin(), index.end(), 0); - std::shuffle(index.begin(), index.end(), std::mt19937(time(nullptr) + worker_id)); - - // Get the sleep time range distribution - std::mt19937 generator{}; - std::uniform_int_distribution distribution(avg_interval_us - avg_interval_us / 2, - avg_interval_us + avg_interval_us / 2); +void Workload::Execute(std::size_t query_number, execution::vm::ExecutionMode mode) { + // The total number of queries to be executed + const std::size_t query_index = QueryNumberToIndex(query_number); // Register to the metrics manager db_main_->GetMetricsManager()->RegisterThread(); - uint32_t counter = 0; - uint64_t end_time = metrics::MetricsUtil::Now() + execution_us_per_worker; - while (metrics::MetricsUtil::Now() < end_time) { - // Executing all the queries on by one in round robin - auto txn = txn_manager_->BeginTransaction(); - auto accessor = - catalog_->GetAccessor(common::ManagedPointer(txn), db_oid_, DISABLED); - - auto output_schema = std::get<1>(query_and_plan_[index[counter]])->GetOutputSchema().Get(); - // Uncomment this line and change output.cpp:90 to EXECUTION_LOG_INFO to print output - // execution::exec::OutputPrinter printer(output_schema); - execution::exec::NoOpResultConsumer printer; - - auto exec_ctx = execution::exec::ExecutionContextBuilder() - .WithDatabaseOID(db_oid_) - .WithExecutionSettings(exec_settings_) - .WithTxnContext(common::ManagedPointer{txn}) - .WithOutputSchema(common::ManagedPointer{output_schema}) - .WithOutputCallback(printer) - .WithCatalogAccessor(common::ManagedPointer{accessor}) - .WithMetricsManager(db_main_->GetMetricsManager()) - .WithReplicationManager(DISABLED) - .WithRecoveryManager(DISABLED) - .Build(); - - std::get<0>(query_and_plan_[index[counter]]) - ->Run(common::ManagedPointer(exec_ctx), mode); - - // Only execute up to query_num number of queries for this thread in round-robin - counter = counter == query_id - 1 ? 0 : counter + 1; - txn_manager_->Commit(txn, transaction::TransactionUtil::EmptyCallback, nullptr); - - // Sleep to create different execution frequency patterns - auto random_sleep_time = distribution(generator); - std::this_thread::sleep_for(std::chrono::microseconds(random_sleep_time)); - } + + // Execute the selected query + auto txn = txn_manager_->BeginTransaction(); + auto accessor = + catalog_->GetAccessor(common::ManagedPointer(txn), db_oid_, DISABLED); + + // Get the output schema for the query + auto *output_schema = std::get<1>(query_and_plan_.at(query_index))->GetOutputSchema().Get(); + + // Construct an execution context for the query + execution::exec::NoOpResultConsumer printer; + auto exec_ctx = execution::exec::ExecutionContextBuilder() + .WithDatabaseOID(db_oid_) + .WithExecutionSettings(exec_settings_) + .WithTxnContext(common::ManagedPointer{txn}) + .WithOutputSchema(common::ManagedPointer{output_schema}) + .WithOutputCallback(printer) + .WithCatalogAccessor(common::ManagedPointer{accessor}) + .WithMetricsManager(db_main_->GetMetricsManager()) + .WithReplicationManager(DISABLED) + .WithRecoveryManager(DISABLED) + .Build(); + + // Execute the query + std::cout << "Executing...\n"; + std::get<0>(query_and_plan_.at(query_index)) + ->Run(common::ManagedPointer(exec_ctx), mode); + txn_manager_->Commit(txn, transaction::TransactionUtil::EmptyCallback, nullptr); + std::cout << "Done.\n"; // Unregister from the metrics manager db_main_->GetMetricsManager()->UnregisterThread(); } +std::size_t Workload::QueryNumberToIndex(std::size_t query_number) const { + return query_number_to_index_.at(query_number); +} + } // namespace noisepage::procbench