From 7751994d924468826a97ba43780f8db281188784 Mon Sep 17 00:00:00 2001 From: Andreas Fertig Date: Fri, 26 Jul 2024 10:50:08 -0400 Subject: [PATCH] Fixed support for static variable inside a coroutine. --- CodeGenerator.cpp | 5 +- CodeGenerator.h | 6 +- CoroutinesCodeGenerator.cpp | 6 +- tests/EduCoroutineStaticVarTest.cpp | 76 +++++++++ tests/EduCoroutineStaticVarTest.expect | 214 +++++++++++++++++++++++++ 5 files changed, 301 insertions(+), 6 deletions(-) create mode 100644 tests/EduCoroutineStaticVarTest.cpp create mode 100644 tests/EduCoroutineStaticVarTest.expect diff --git a/CodeGenerator.cpp b/CodeGenerator.cpp index 92dd038..026ed07 100644 --- a/CodeGenerator.cpp +++ b/CodeGenerator.cpp @@ -1262,7 +1262,7 @@ void CodeGenerator::InsertArg(const VarDecl* stmt) HandleLocalStaticNonTrivialClass(stmt); } else { - if(InsertVarDecl()) { + if(InsertVarDecl(stmt)) { const auto desugaredType = GetType(GetDesugarType(stmt->getType())); const bool isMemberPointer{isa(desugaredType.getTypePtrOrNull())}; @@ -2542,7 +2542,8 @@ void CodeGenerator::InsertArg(const ForStmt* stmt) WrapInParens( [&]() { if(const auto* init = stmt->getInit()) { - MultiStmtDeclCodeGenerator codeGenerator{mOutputFormatHelper, mLambdaStack, InsertVarDecl()}; + MultiStmtDeclCodeGenerator codeGenerator{ + mOutputFormatHelper, mLambdaStack, InsertVarDecl(nullptr)}; codeGenerator.InsertArg(init); } else { diff --git a/CodeGenerator.h b/CodeGenerator.h index 2d6dea9..8ed01d5 100644 --- a/CodeGenerator.h +++ b/CodeGenerator.h @@ -281,7 +281,7 @@ class CodeGenerator void EndLifetimeScope(); protected: - virtual bool InsertVarDecl() { return true; } + virtual bool InsertVarDecl(const VarDecl*) { return true; } virtual bool SkipSpaceAfterVarDecl() { return false; } virtual bool InsertComma() { return false; } virtual bool InsertSemi() { return true; } @@ -500,7 +500,7 @@ class MultiStmtDeclCodeGenerator final : public CodeGenerator OnceFalse mInsertComma{}; //! Insert the comma after we have generated the first \c VarDecl and we are about to //! insert another one. - bool InsertVarDecl() override { return mInsertVarDecl; } + bool InsertVarDecl(const VarDecl*) override { return mInsertVarDecl; } bool InsertComma() override { return mInsertComma; } bool InsertSemi() override { return false; } }; @@ -563,7 +563,7 @@ class CoroutinesCodeGenerator final : public CodeGenerator std::string GetFrameName() const { return mFrameName; } protected: - bool InsertVarDecl() override { return mInsertVarDecl; } + bool InsertVarDecl(const VarDecl* vd) override { return mInsertVarDecl or (vd and vd->isStaticLocal()); } bool SkipSpaceAfterVarDecl() override { return not mInsertVarDecl; } private: diff --git a/CoroutinesCodeGenerator.cpp b/CoroutinesCodeGenerator.cpp index e48685d..cebf5ff 100644 --- a/CoroutinesCodeGenerator.cpp +++ b/CoroutinesCodeGenerator.cpp @@ -229,7 +229,7 @@ class CoroutineASTTransformer : public StmtVisitor void VisitDeclRefExpr(DeclRefExpr* stmt) { if(auto* vd = dyn_cast_or_null(stmt->getDecl())) { - RETURN_IF(not vd->isLocalVarDeclOrParm() or not Contains(mVarNamePrefix, vd)); + RETURN_IF(not vd->isLocalVarDeclOrParm() or vd->isStaticLocal() or not Contains(mVarNamePrefix, vd)); auto* memberExpr = mVarNamePrefix[vd]; @@ -241,6 +241,10 @@ class CoroutineASTTransformer : public StmtVisitor { for(auto* decl : stmt->decls()) { if(auto* varDecl = dyn_cast_or_null(decl)) { + if(varDecl->isStaticLocal()) { + continue; + } + // add this point a placement-new would be appropriate for at least some cases. auto* field = AddField(mASTData, GetName(*varDecl), varDecl->getType()); diff --git a/tests/EduCoroutineStaticVarTest.cpp b/tests/EduCoroutineStaticVarTest.cpp new file mode 100644 index 0000000..1f01be4 --- /dev/null +++ b/tests/EduCoroutineStaticVarTest.cpp @@ -0,0 +1,76 @@ +// cmdline:-std=c++20 +// cmdlineinsights:-edu-show-coroutine-transformation + +#include +#include +#include // std::terminate +#include +#include +#include +#include +#include + +using namespace std::string_literals; +using namespace std::string_view_literals; + +struct Task { + struct promise_type { + Task get_return_object() noexcept { return {}; } + std::suspend_never initial_suspend() noexcept { return {}; } + std::suspend_never final_suspend() noexcept { return {}; } + void return_void() noexcept {} + void unhandled_exception() noexcept {} + }; +}; + +struct Scheduler; + +struct awaiter : std::suspend_always { + Scheduler* _sched; + + explicit awaiter(Scheduler& sched) + : _sched{&sched} + {} + void await_suspend(std::coroutine_handle<> coro) const noexcept; +}; + +struct Scheduler { + std::list> _tasks{}; + + bool schedule() + { + auto task = _tasks.front(); + _tasks.pop_front(); + + if(not task.done()) { task.resume(); } + + return not _tasks.empty(); + } + + auto suspend() { return awaiter{*this}; } +}; + +void awaiter::await_suspend(std::coroutine_handle<> coro) const noexcept +{ + _sched->_tasks.push_back(coro); +} + +Task taskA(Scheduler& sched) +{ + std::cout << "Hello, from task A\n"sv; + + co_await sched.suspend(); + + static std::string res{"a is back doing work\n"s}; + std::cout << res; +} + +int main() +{ + Scheduler scheduler{}; + + taskA(scheduler); + + while(scheduler.schedule()) {} +} + diff --git a/tests/EduCoroutineStaticVarTest.expect b/tests/EduCoroutineStaticVarTest.expect new file mode 100644 index 0000000..e3d49d1 --- /dev/null +++ b/tests/EduCoroutineStaticVarTest.expect @@ -0,0 +1,214 @@ +/************************************************************************************* + * NOTE: The coroutine transformation you've enabled is a hand coded transformation! * + * Most of it is _not_ present in the AST. What you see is an approximation. * + *************************************************************************************/ +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace std::string_literals; +using namespace std::string_view_literals; + +struct Task +{ + struct promise_type + { + inline Task get_return_object() noexcept + { + return {}; + } + + inline std::suspend_never initial_suspend() noexcept + { + return {}; + } + + inline std::suspend_never final_suspend() noexcept + { + return {}; + } + + inline void return_void() noexcept + { + } + + inline void unhandled_exception() noexcept + { + } + + // inline constexpr promise_type() noexcept = default; + }; + +}; + + +struct Scheduler; + +struct awaiter : public std::suspend_always +{ + Scheduler * _sched; + inline explicit awaiter(Scheduler & sched) + : std::suspend_always() + , _sched{&sched} + { + } + + void await_suspend(std::coroutine_handle coro) const noexcept; + +}; + + +struct Scheduler +{ + std::list, std::allocator > > _tasks = std::list, std::allocator > >{}; + inline bool schedule() + { + std::coroutine_handle task = std::coroutine_handle(this->_tasks.front()); + this->_tasks.pop_front(); + if(!task.done()) { + task.resume(); + } + + return !this->_tasks.empty(); + } + + inline awaiter suspend() + { + return awaiter{*this}; + } + + // inline ~Scheduler() noexcept = default; +}; + + +void awaiter::await_suspend(std::coroutine_handle coro) const noexcept +{ + this->_sched->_tasks.push_back(coro); +} + + +struct __taskAFrame +{ + void (*resume_fn)(__taskAFrame *); + void (*destroy_fn)(__taskAFrame *); + std::__coroutine_traits_sfinae::promise_type __promise; + int __suspend_index; + bool __initial_await_suspend_called; + Scheduler & sched; + std::suspend_never __suspend_58_6; + awaiter __suspend_62_18; + std::suspend_never __suspend_58_6_1; +}; + +Task taskA(Scheduler & sched) +{ + /* Allocate the frame including the promise */ + /* Note: The actual parameter new is __builtin_coro_size */ + __taskAFrame * __f = reinterpret_cast<__taskAFrame *>(operator new(sizeof(__taskAFrame))); + __f->__suspend_index = 0; + __f->__initial_await_suspend_called = false; + __f->sched = std::forward(sched); + + /* Construct the promise. */ + new (&__f->__promise)std::__coroutine_traits_sfinae::promise_type{}; + + /* Forward declare the resume and destroy function. */ + void __taskAResume(__taskAFrame * __f); + void __taskADestroy(__taskAFrame * __f); + + /* Assign the resume and destroy function pointers. */ + __f->resume_fn = &__taskAResume; + __f->destroy_fn = &__taskADestroy; + + /* Call the made up function with the coroutine body for initial suspend. + This function will be called subsequently by coroutine_handle<>::resume() + which calls __builtin_coro_resume(__handle_) */ + __taskAResume(__f); + + + return __f->__promise.get_return_object(); +} + +/* This function invoked by coroutine_handle<>::resume() */ +void __taskAResume(__taskAFrame * __f) +{ + try + { + /* Create a switch to get to the correct resume point */ + switch(__f->__suspend_index) { + case 0: break; + case 1: goto __resume_taskA_1; + case 2: goto __resume_taskA_2; + } + + /* co_await EduCoroutineStaticVarTest.cpp:58 */ + __f->__suspend_58_6 = __f->__promise.initial_suspend(); + if(!__f->__suspend_58_6.await_ready()) { + __f->__suspend_58_6.await_suspend(std::coroutine_handle::from_address(static_cast(__f)).operator std::coroutine_handle()); + __f->__suspend_index = 1; + __f->__initial_await_suspend_called = true; + return; + } + + __resume_taskA_1: + __f->__suspend_58_6.await_resume(); + std::operator<<(std::cout, std::operator""sv("Hello, from task A\n", 19UL)); + + /* co_await EduCoroutineStaticVarTest.cpp:62 */ + __f->__suspend_62_18 = __f->sched.suspend(); + if(!__f->__suspend_62_18.await_ready()) { + __f->__suspend_62_18.await_suspend(std::coroutine_handle::from_address(static_cast(__f)).operator std::coroutine_handle()); + __f->__suspend_index = 2; + return; + } + + __resume_taskA_2: + __f->__suspend_62_18.await_resume(); + static std::basic_string, std::allocator > res = {std::operator""s("a is back doing work\n", 21UL)}; + std::operator<<(std::cout, res); + goto __final_suspend; + } catch(...) { + if(!__f->__initial_await_suspend_called) { + throw ; + } + + __f->__promise.unhandled_exception(); + } + + __final_suspend: + + /* co_await EduCoroutineStaticVarTest.cpp:58 */ + __f->__suspend_58_6_1 = __f->__promise.final_suspend(); + if(!__f->__suspend_58_6_1.await_ready()) { + __f->__suspend_58_6_1.await_suspend(std::coroutine_handle::from_address(static_cast(__f)).operator std::coroutine_handle()); + return; + } + + __f->destroy_fn(__f); +} + +/* This function invoked by coroutine_handle<>::destroy() */ +void __taskADestroy(__taskAFrame * __f) +{ + /* destroy all variables with dtors */ + __f->~__taskAFrame(); + /* Deallocating the coroutine frame */ + /* Note: The actual argument to delete is __builtin_coro_frame with the promise as parameter */ + operator delete(static_cast(__f)); +} + + +int main() +{ + Scheduler scheduler = {{std::list, std::allocator > >{}}}; + taskA(scheduler); + while(scheduler.schedule()) { + } + + return 0; +}