diff --git a/src/scripting/FunctionOverride.cpp b/src/scripting/FunctionOverride.cpp index bcd4e13b..e911b371 100644 --- a/src/scripting/FunctionOverride.cpp +++ b/src/scripting/FunctionOverride.cpp @@ -3,13 +3,12 @@ #include "FunctionOverride.h" #include "Scripting.h" #include "Utils.h" -#include +#include #include #include static FunctionOverride* s_pOverride = nullptr; -static RTTILocator s_inkGameControllerType("gameuiWidgetGameController"); using TRunPureScriptFunction = bool (*)(RED4ext::CBaseFunction* apFunction, RED4ext::CScriptStack*, void*); using TCallScriptFunction = bool (*)(RED4ext::IFunction* apFunction, RED4ext::IScriptable* apContext, @@ -73,45 +72,17 @@ void FunctionOverride::Clear() // Reverse order as we want to swap from most recent to oldest change for (auto& [pFunction, pContext] : m_functions) { - // Just an added function, not an override - if (pContext.Trampoline == nullptr) - { - auto* pClassType = pFunction->parent; - auto* pArray = &pClassType->funcs; - - if (pFunction->flags.isStatic) - pClassType->staticFuncs; - - for (auto*& pItor : *pArray) - { - if (pItor == pFunction) - { - // Swap our self with the last element - pItor = *(pArray->End() - 1); - // Pop last - pArray->size -= 1; - - break; - } - } - } - else - { - auto* pRealFunction = pContext.Trampoline; + auto* pRealFunction = pContext.Trampoline; - std::array tmpBuffer; - size_t funcSize = GetFunctionSize(pRealFunction); + std::array tmpBuffer; + size_t funcSize = GetFunctionSize(pRealFunction); - std::memcpy(&tmpBuffer, pRealFunction, funcSize); - std::memcpy(pRealFunction, pFunction, funcSize); - std::memcpy(pFunction, &tmpBuffer, funcSize); - } + std::memcpy(&tmpBuffer, pRealFunction, funcSize); + std::memcpy(pRealFunction, pFunction, funcSize); + std::memcpy(pFunction, &tmpBuffer, funcSize); } m_functions.clear(); - - m_pBuffer = m_pBufferStart; - m_size = kExecutableSize; } bool FunctionOverride::HookRunPureScriptFunction(RED4ext::CClassFunction* apFunction, RED4ext::CScriptStack* apStack, RED4ext::CStackFrame* a3) @@ -136,6 +107,7 @@ bool FunctionOverride::HookRunPureScriptFunction(RED4ext::CClassFunction* apFunc auto pAllocator = TiltedPhoques::Allocator::Get(); TiltedPhoques::Allocator::Set(&s_allocator); TiltedPhoques::Vector args; + TiltedPhoques::Vector outArgs; TiltedPhoques::Allocator::Set(pAllocator); auto state = chain.pScripting->GetState(); @@ -143,8 +115,9 @@ bool FunctionOverride::HookRunPureScriptFunction(RED4ext::CClassFunction* apFunc auto pContext = apStack->GetContext(); if (!apFunction->flags.isStatic && pContext) { - const auto handle = RED4ext::Handle(pContext); - auto obj = sol::make_object(state.Get(), StrongReference(state, handle)); + const auto weak = RED4ext::WeakHandle( + *(RED4ext::WeakHandle*)&pContext->ref); + auto obj = sol::make_object(state.Get(), WeakReference(state, weak)); args.push_back(obj); } @@ -158,13 +131,16 @@ bool FunctionOverride::HookRunPureScriptFunction(RED4ext::CClassFunction* apFunc arg.value = pOffset; args.push_back(Scripting::ToLua(state, arg)); + + if (p->flags.isOut) + outArgs.push_back(arg); } RED4ext::CStackType ret; ret.value = apStack->GetResultAddr(); ret.type = apStack->GetType(); - return ExecuteChain(chain, lock, pContext, &args, &ret, apStack, a3, nullptr, 0); + return ExecuteChain(chain, lock, pContext, &args, &ret, &outArgs, apStack, a3, nullptr, 0); } if (chain.CollectGarbage) @@ -209,6 +185,7 @@ void FunctionOverride::HandleOverridenFunction(RED4ext::IScriptable* apContext, auto pAllocator = TiltedPhoques::Allocator::Get(); TiltedPhoques::Allocator::Set(&s_allocator); TiltedPhoques::Vector args; + TiltedPhoques::Vector outArgs; TiltedPhoques::Allocator::Set(pAllocator); auto state = chain.pScripting->GetState(); @@ -228,8 +205,9 @@ void FunctionOverride::HandleOverridenFunction(RED4ext::IScriptable* apContext, self.value = apFrame->context; } - const auto handle = RED4ext::Handle((RED4ext::IScriptable*)self.value); - auto obj = sol::make_object(state.Get(), StrongReference(state, handle)); + const auto ref = (RED4ext::WeakHandle*)&((RED4ext::IScriptable*)self.value)->ref; + const auto weak = RED4ext::WeakHandle(*ref); + auto obj = sol::make_object(state.Get(), WeakReference(state, weak)); args.push_back(obj); } @@ -269,6 +247,15 @@ void FunctionOverride::HandleOverridenFunction(RED4ext::IScriptable* apContext, args.push_back(Scripting::ToLua(state, arg)); + if (pArg->flags.isOut) + { + // This is an original arg, pInstance contains copy + if (apFrame->unk30) + arg.value = reinterpret_cast(apFrame->unk30); + + outArgs.push_back(arg); + } + // Release inner values if (isScriptRef) { @@ -277,8 +264,11 @@ void FunctionOverride::HandleOverridenFunction(RED4ext::IScriptable* apContext, pScriptRef->innerType->GetAllocator()->Free(pScriptRef->ref); } - pType->Destroy(pInstance); - pAllocator->Free(pInstance); + if (!pArg->flags.isOut || apFrame->unk30) + { + pType->Destroy(pInstance); + pAllocator->Free(pInstance); + } } apFrame->code++; // skip ParamEnd @@ -291,7 +281,7 @@ void FunctionOverride::HandleOverridenFunction(RED4ext::IScriptable* apContext, ret.value = apOut; } - ExecuteChain(chain, lock, apContext, &args, &ret, nullptr, apFrame, pCode, currentParam); + ExecuteChain(chain, lock, apContext, &args, &ret, &outArgs, nullptr, apFrame, pCode, currentParam); return; } @@ -307,8 +297,9 @@ void FunctionOverride::HandleOverridenFunction(RED4ext::IScriptable* apContext, bool FunctionOverride::ExecuteChain(const CallChain& aChain, std::shared_lock& aLock, RED4ext::IScriptable* apContext, TiltedPhoques::Vector* apOrigArgs, - RED4ext::CStackType* apResult, RED4ext::CScriptStack* apStack, - RED4ext::CStackFrame* apFrame, char* apCode, uint8_t aParam) + RED4ext::CStackType* apResult, TiltedPhoques::Vector* apOutArgs, + RED4ext::CScriptStack* apStack, RED4ext::CStackFrame* apFrame, + char* apCode, uint8_t aParam) { if (!aChain.Before.empty()) { @@ -329,19 +320,36 @@ bool FunctionOverride::ExecuteChain(const CallChain& aChain, std::shared_lockflags.isStatic ? 0 : 1; sol::object luaContext = pRealFunction->flags.isStatic ? sol::nil : apOrigArgs->at(0); - TiltedPhoques::Vector luaArgs(apOrigArgs->begin() + argSelfOffset, apOrigArgs->end()); + TiltedPhoques::Vector luaArgs(apOrigArgs->begin() + (pRealFunction->flags.isStatic ? 0 : 1), + apOrigArgs->end()); auto lockedState = aChain.pScripting->GetState(); auto& luaState = lockedState.Get(); - auto wrapped = WrapNextOverride(aChain, 0, luaState, luaContext, luaArgs, pRealFunction, apContext, aLock); - auto result = wrapped(as_args(luaArgs)); + auto luaWrapped = WrapNextOverride(aChain, 0, luaState, luaContext, luaArgs, pRealFunction, apContext, aLock); + auto luaResult = luaWrapped(as_args(luaArgs)); - if (result.valid() && apResult && apResult->value) + if (luaResult.valid()) { - Scripting::ToRED(result.get(), *apResult); + auto luaRetOffset = 0; + + if (apResult && apResult->value) + { + Scripting::ToRED(luaResult.get(), *apResult); + ++luaRetOffset; + } + + if (apOutArgs && !apOutArgs->empty()) + { + for (auto i = 0; i < apOutArgs->size(); ++i) + { + auto luaOutArg = luaResult.get(i + luaRetOffset); + + if (luaOutArg != sol::nil) + Scripting::ToRED(luaOutArg, apOutArgs->at(i)); + } + } } } else @@ -481,7 +489,7 @@ void FunctionOverride::Hook(Options& aOptions) const { DWORD oldProtect; VirtualProtect(pLocation, 0x40, PAGE_READWRITE, &oldProtect); - *pFirstLocation = *pSecondLocation = std::max(sizeof(RED4ext::CClassFunction), sizeof(RED4ext::CScriptedFunction)); + *pFirstLocation = *pSecondLocation = std::max(s_cMaxFunctionSize, sizeof(RED4ext::CScriptedFunction)); VirtualProtect(pLocation, 0x40, oldProtect, &oldProtect); spdlog::info("Override function allocator patched!"); @@ -537,73 +545,82 @@ void FunctionOverride::Override(const std::string& acTypeName, const std::string m_functions[pRealFunction] = {}; pEntry = &m_functions[pRealFunction]; - /* - sub rsp, 56 - mov rax, 0xDEADBEEFC0DEBAAD - mov qword ptr[rsp + 32], rax - mov rax, 0xDEADBEEFC0DEBAAD - call rax - add rsp, 56 - ret - */ - uint8_t payload[] = {0x48, 0x83, 0xEC, 0x38, 0x48, 0xB8, 0xAD, 0xBA, 0xDE, 0xC0, 0xEF, 0xBE, - 0xAD, 0xDE, 0x48, 0x89, 0x44, 0x24, 0x20, 0x48, 0xB8, 0xAD, 0xBA, 0xDE, - 0xC0, 0xEF, 0xBE, 0xAD, 0xDE, 0xFF, 0xD0, 0x48, 0x83, 0xC4, 0x38, 0xC3}; - - auto funcAddr = reinterpret_cast(&FunctionOverride::HandleOverridenFunction); - - std::memcpy(payload + 6, &pRealFunction, 8); - std::memcpy(payload + 21, &funcAddr, 8); - - using TNativeScriptFunction = void (*)(RED4ext::IScriptable*, RED4ext::CStackFrame*, void*, int64_t); - auto* pExecutablePayload = static_cast(MakeExecutable(payload, std::size(payload))); - RED4ext::CBaseFunction* pFunc; - if (pRealFunction->flags.isStatic) + if (!m_trampolines.contains(pRealFunction)) { - if (pRealFunction->flags.isNative) + /* + sub rsp, 56 + mov rax, 0xDEADBEEFC0DEBAAD + mov qword ptr[rsp + 32], rax + mov rax, 0xDEADBEEFC0DEBAAD + call rax + add rsp, 56 + ret + */ + uint8_t payload[] = {0x48, 0x83, 0xEC, 0x38, 0x48, 0xB8, 0xAD, 0xBA, 0xDE, 0xC0, 0xEF, 0xBE, + 0xAD, 0xDE, 0x48, 0x89, 0x44, 0x24, 0x20, 0x48, 0xB8, 0xAD, 0xBA, 0xDE, + 0xC0, 0xEF, 0xBE, 0xAD, 0xDE, 0xFF, 0xD0, 0x48, 0x83, 0xC4, 0x38, 0xC3}; + + auto funcAddr = reinterpret_cast(&FunctionOverride::HandleOverridenFunction); + + std::memcpy(payload + 6, &pRealFunction, 8); + std::memcpy(payload + 21, &funcAddr, 8); + + using TNativeScriptFunction = void (*)(RED4ext::IScriptable*, RED4ext::CStackFrame*, void*, int64_t); + auto* pExecutablePayload = static_cast(MakeExecutable(payload, std::size(payload))); + + if (pRealFunction->flags.isStatic) { - pFunc = RED4ext::CClassStaticFunction::Create(pClassType, acFullName.c_str(), acFullName.c_str(), - pExecutablePayload, pRealFunction->flags); - reinterpret_cast(pFunc)->parent = pRealFunction->parent; + if (pRealFunction->flags.isNative) + { + pFunc = RED4ext::CClassStaticFunction::Create(pClassType, acFullName.c_str(), acFullName.c_str(), + pExecutablePayload, pRealFunction->flags); + reinterpret_cast(pFunc)->parent = pRealFunction->parent; + } + else + { + pFunc = RED4ext::CGlobalFunction::Create(acFullName.c_str(), acFullName.c_str(), pExecutablePayload); + } } else { - pFunc = RED4ext::CGlobalFunction::Create(acFullName.c_str(), acFullName.c_str(), pExecutablePayload); + pFunc = RED4ext::CClassFunction::Create(pClassType, acFullName.c_str(), acFullName.c_str(), + pExecutablePayload, pRealFunction->flags); + reinterpret_cast(pFunc)->parent = pRealFunction->parent; } - } - else - { - pFunc = RED4ext::CClassFunction::Create(pClassType, acFullName.c_str(), acFullName.c_str(), - pExecutablePayload, pRealFunction->flags); - reinterpret_cast(pFunc)->parent = pRealFunction->parent; - } - pFunc->fullName = pRealFunction->fullName; - pFunc->shortName = pRealFunction->shortName; + pFunc->fullName = pRealFunction->fullName; + pFunc->shortName = pRealFunction->shortName; - pFunc->returnType = pRealFunction->returnType; - for (auto* p : pRealFunction->params) - { - pFunc->params.PushBack(p); - } + pFunc->returnType = pRealFunction->returnType; + for (auto* p : pRealFunction->params) + { + pFunc->params.PushBack(p); + } - for (auto* p : pRealFunction->localVars) + for (auto* p : pRealFunction->localVars) + { + pFunc->localVars.PushBack(p); + } + + pFunc->unk20 = pRealFunction->unk20; + pFunc->bytecode = pRealFunction->bytecode; + pFunc->unk48 = pRealFunction->unk48; + pFunc->unkAC = pRealFunction->unkAC; + pFunc->flags = pRealFunction->flags; + pFunc->flags.isNative = true; + + m_trampolines[pRealFunction] = pFunc; + } + else { - pFunc->localVars.PushBack(p); + pFunc = m_trampolines[pRealFunction]; } - pFunc->unk20 = pRealFunction->unk20; - pFunc->bytecode = pRealFunction->bytecode; - pFunc->unk48 = pRealFunction->unk48; - pFunc->unkAC = pRealFunction->unkAC; - pFunc->flags = pRealFunction->flags; - pFunc->flags.isNative = true; - pEntry->Trampoline = pFunc; pEntry->pScripting = m_pScripting; - pEntry->CollectGarbage = aCollectGarbage || pClassType->IsA(s_inkGameControllerType); + pEntry->CollectGarbage = aCollectGarbage; pEntry->IsEmpty = true; // Swap the content of the real function with the one we just created diff --git a/src/scripting/FunctionOverride.h b/src/scripting/FunctionOverride.h index 35c4f0c5..f68e7e8d 100644 --- a/src/scripting/FunctionOverride.h +++ b/src/scripting/FunctionOverride.h @@ -38,8 +38,8 @@ struct FunctionOverride static bool HookRunPureScriptFunction(RED4ext::CClassFunction* apFunction, RED4ext::CScriptStack* apContext, RED4ext::CStackFrame* a3); static bool ExecuteChain(const CallChain& aChain, std::shared_lock& aLock, RED4ext::IScriptable* apContext, TiltedPhoques::Vector* apArgs, - RED4ext::CStackType* apResult, RED4ext::CScriptStack* apStack, - RED4ext::CStackFrame* apFrame, char* apCode, uint8_t aParam); + RED4ext::CStackType* apResult, TiltedPhoques::Vector* apOutArgs, + RED4ext::CScriptStack* apStack, RED4ext::CStackFrame* apFrame, char* apCode, uint8_t aParam); static sol::function WrapNextOverride(const CallChain& aChain, int aStep, sol::state& aLuaState, sol::object& aLuaContext, TiltedPhoques::Vector& aLuaArgs, RED4ext::CBaseFunction* apRealFunction, RED4ext::IScriptable* apRealContext, @@ -57,7 +57,8 @@ struct FunctionOverride void* m_pBufferStart; void* m_pBuffer; size_t m_size{ kExecutableSize }; - TiltedPhoques::Map m_functions; + TiltedPhoques::Map m_functions; + TiltedPhoques::Map m_trampolines; Scripting* m_pScripting; std::shared_mutex m_lock; }; \ No newline at end of file diff --git a/src/scripting/LuaVM.cpp b/src/scripting/LuaVM.cpp index 8e812971..5e7e946e 100644 --- a/src/scripting/LuaVM.cpp +++ b/src/scripting/LuaVM.cpp @@ -25,7 +25,7 @@ void LuaVM::Update(float aDeltaTime) if (!m_initialized) { if (m_logCount.load(std::memory_order_relaxed) > 0) - PostInitialize(); + PostInitializeStage2(); return; } @@ -46,6 +46,7 @@ void LuaVM::ReloadAllMods() if (m_initialized) { m_scripting.ReloadAllMods(); + m_scripting.TriggerOnTweak(); m_scripting.TriggerOnInit(); if (CET::Get().GetOverlay().IsEnabled()) @@ -180,12 +181,17 @@ void LuaVM::RegisterTDBIDString(uint64_t aValue, uint64_t aBase, const std::stri m_tdbidDerivedLookup[aBase].insert(aValue); } -void LuaVM::PostInitialize() +void LuaVM::PostInitializeStage1() { - assert(!m_initialized); + m_scripting.PostInitializeStage1(); + m_scripting.TriggerOnTweak(); +} - m_scripting.PostInitialize(); +void LuaVM::PostInitializeStage2() +{ + assert(!m_initialized); + m_scripting.PostInitializeStage2(); m_scripting.TriggerOnInit(); if (CET::Get().GetOverlay().IsEnabled()) m_scripting.TriggerOnOverlayOpen(); diff --git a/src/scripting/LuaVM.h b/src/scripting/LuaVM.h index 8d820399..ab4ff4a9 100644 --- a/src/scripting/LuaVM.h +++ b/src/scripting/LuaVM.h @@ -11,6 +11,7 @@ using TSetMousePosition = BOOL(void*, HWND, long, long); using TTDBIDCtorDerive = TDBID*(const TDBID*, TDBID*, const char*); using TRunningStateRun = bool(uintptr_t, uintptr_t); using TSetLoadingState = uintptr_t(uintptr_t, int); +using TTweakDBLoad = uint64_t(uintptr_t, uintptr_t); struct TDBIDLookupEntry { @@ -51,7 +52,8 @@ struct LuaVM void RegisterTDBIDString(uint64_t aValue, uint64_t aBase, const std::string& acString); - void PostInitialize(); + void PostInitializeStage1(); + void PostInitializeStage2(); protected: @@ -63,6 +65,7 @@ struct LuaVM static TDBID* HookTDBIDCtorDerive(TDBID* apBase, TDBID* apThis, const char* acpName); static bool HookRunningStateRun(uintptr_t aThis, uintptr_t aApp); static uintptr_t HookSetLoadingState(uintptr_t aThis, int aState); + static uint64_t HookTweakDBLoad(uintptr_t aThis, uintptr_t aParam); private: @@ -77,6 +80,7 @@ struct LuaVM TTDBIDCtorDerive* m_realTDBIDCtorDerive{ nullptr }; TRunningStateRun* m_realRunningStateRun{ nullptr }; TSetLoadingState* m_realSetLoadingState{ nullptr }; + TTweakDBLoad* m_realTweakDBLoad{ nullptr }; std::chrono::time_point m_lastframe; diff --git a/src/scripting/LuaVM_Hooks.cpp b/src/scripting/LuaVM_Hooks.cpp index 30fc14a7..6b106fd8 100644 --- a/src/scripting/LuaVM_Hooks.cpp +++ b/src/scripting/LuaVM_Hooks.cpp @@ -141,13 +141,24 @@ uintptr_t LuaVM::HookSetLoadingState(uintptr_t aThis, int aState) { std::call_once(s_initBarrier, [&luavm]() { - luavm.PostInitialize(); + luavm.PostInitializeStage2(); }); } return luavm.m_realSetLoadingState(aThis, aState); } +uint64_t LuaVM::HookTweakDBLoad(uintptr_t aThis, uintptr_t aParam) +{ + auto& luavm = CET::Get().GetVM(); + + const auto ret = luavm.m_realTweakDBLoad(aThis, aParam); + + luavm.PostInitializeStage1(); + + return ret; +} + void LuaVM::Hook(Options& aOptions) { auto& gameImage = aOptions.GameImage; @@ -241,6 +252,23 @@ void LuaVM::Hook(Options& aOptions) } } } + + { + const mem::pattern cPattern("48 89 5C 24 10 48 89 7C 24 18 4C 89 74 24 20 55 48 8B EC 48 83 EC 70 48"); + const mem::default_scanner cScanner(cPattern); + uint8_t* pLocation = cScanner(gameImage.TextRegion).as(); + + if (pLocation) + { + if (MH_CreateHook(pLocation, &HookTweakDBLoad, reinterpret_cast(&m_realTweakDBLoad)) != MH_OK || + MH_EnableHook(pLocation) != MH_OK) + spdlog::error("Could not hook TweakDB::Load function!"); + else + { + spdlog::info("TweakDB::Load function hook complete!"); + } + } + } // Disable SetLoadingState hook temporarily and get back to log count workaround // as it introduces major breaking change for onInit handler. diff --git a/src/scripting/ScriptContext.cpp b/src/scripting/ScriptContext.cpp index 8960ad2e..86b55032 100644 --- a/src/scripting/ScriptContext.cpp +++ b/src/scripting/ScriptContext.cpp @@ -43,6 +43,8 @@ ScriptContext::ScriptContext(LuaSandbox& aLuaSandbox, const std::filesystem::pat { if(acName == "onInit") m_onInit = aCallback; + else if(acName == "onTweak") + m_onTweak = aCallback; else if(acName == "onShutdown") m_onShutdown = aCallback; else if(acName == "onUpdate") @@ -155,6 +157,13 @@ const TiltedPhoques::Vector& ScriptContext::GetBinds() const return m_vkBindInfos; } +void ScriptContext::TriggerOnTweak() const +{ + auto state = m_sandbox.GetState(); + + TryLuaFunction(m_logger, m_onTweak); +} + void ScriptContext::TriggerOnInit() const { auto state = m_sandbox.GetState(); diff --git a/src/scripting/ScriptContext.h b/src/scripting/ScriptContext.h index 8be8af3b..83d3192e 100644 --- a/src/scripting/ScriptContext.h +++ b/src/scripting/ScriptContext.h @@ -12,6 +12,7 @@ struct ScriptContext const TiltedPhoques::Vector& GetBinds() const; + void TriggerOnTweak() const; void TriggerOnInit() const; void TriggerOnUpdate(float aDeltaTime) const; void TriggerOnDraw() const; @@ -30,6 +31,7 @@ struct ScriptContext LuaSandbox& m_sandbox; size_t m_sandboxID; sol::object m_object{ }; + sol::function m_onTweak{ }; sol::function m_onInit{ }; sol::function m_onShutdown{ }; sol::function m_onUpdate{ }; diff --git a/src/scripting/ScriptStore.cpp b/src/scripting/ScriptStore.cpp index 9995e208..13b0a15c 100644 --- a/src/scripting/ScriptStore.cpp +++ b/src/scripting/ScriptStore.cpp @@ -25,10 +25,16 @@ void ScriptStore::LoadAll() auto fPath = file.path(); - if (is_symlink(fPath)) - fPath = read_symlink(fPath); - else if (is_symlink(fPath / "init.lua")) - fPath = read_symlink(fPath / "init.lua").parent_path(); + try + { + if (is_symlink(fPath)) + fPath = read_symlink(fPath); + else if (is_symlink(fPath / "init.lua")) + fPath = read_symlink(fPath / "init.lua").parent_path(); + } + catch (std::exception& e) + { + } fPath = absolute(fPath); auto fPathStr = fPath.string(); @@ -68,6 +74,12 @@ const TiltedPhoques::Vector& ScriptStore::GetBinds() const return m_vkBindInfos; } +void ScriptStore::TriggerOnTweak() const +{ + for (const auto& kvp : m_contexts) + kvp.second.TriggerOnTweak(); +} + void ScriptStore::TriggerOnInit() const { for (const auto& kvp : m_contexts) diff --git a/src/scripting/ScriptStore.h b/src/scripting/ScriptStore.h index 8265ea7c..85659a34 100644 --- a/src/scripting/ScriptStore.h +++ b/src/scripting/ScriptStore.h @@ -11,6 +11,7 @@ struct ScriptStore const TiltedPhoques::Vector& GetBinds() const; + void TriggerOnTweak() const; void TriggerOnInit() const; void TriggerOnUpdate(float aDeltaTime) const; void TriggerOnDraw() const; diff --git a/src/scripting/Scripting.cpp b/src/scripting/Scripting.cpp index 96d65052..205a389d 100644 --- a/src/scripting/Scripting.cpp +++ b/src/scripting/Scripting.cpp @@ -88,7 +88,7 @@ void Scripting::Initialize() const auto resolution = m_d3d12.GetResolution(); return {static_cast(resolution.cx), static_cast(resolution.cy)}; }; - + luaGlobal["ModArchiveExists"] = [this](const std::string& acArchiveName) -> bool { const auto cAbsPath = absolute(m_paths.ArchiveModsRoot() / acArchiveName); @@ -126,7 +126,7 @@ void Scripting::Initialize() m_store.LoadAll(); } -void Scripting::PostInitialize() +void Scripting::PostInitializeStage1() { auto lua = m_lua.Lock(); auto& luaVm = lua.Get(); @@ -135,8 +135,7 @@ void Scripting::PostInitialize() luaVm.new_usertype("__Game", sol::meta_function::construct, sol::no_constructor, - sol::meta_function::index, &Scripting::Index, - sol::meta_function::new_index, &Scripting::NewIndex); + sol::meta_function::index, &Scripting::Index); luaVm.new_usertype("__Type", sol::meta_function::construct, sol::no_constructor, @@ -389,35 +388,38 @@ void Scripting::PostInitialize() luaVm.new_usertype("LocKey", sol::constructors(), - sol::call_constructor, sol::constructors(), sol::meta_function::to_string, &gamedataLocKeyWrapper::ToString, sol::meta_function::equal_to, &gamedataLocKeyWrapper::operator==, - "hash", &gamedataLocKeyWrapper::hash); - - luaGlobal["LocKey"] = luaVm["LocKey"]; + sol::call_constructor, sol::factories([](sol::object aValue, sol::this_state aState) -> sol::object { + sol::state_view lua(aState); + gamedataLocKeyWrapper result(0); - luaGlobal["NewObject"] = [this](const std::string& acName, sol::this_environment aEnv) -> sol::object - { - auto* pRtti = RED4ext::CRTTISystem::Get(); - auto* pType = pRtti->GetType(RED4ext::FNV1a(acName.c_str())); + if (aValue != sol::nil) + { + if (aValue.get_type() == sol::type::number) + { + result.hash = aValue.as(); + } + else if (IsLuaCData(aValue)) + { + std::string str = lua["tostring"](aValue); + result.hash = std::stoull(str); + } + else if (aValue.get_type() == sol::type::string) + { + result.hash = RED4ext::FNV1a64(aValue.as()); + } + } - if (!pType) - { - const sol::environment cEnv = aEnv; - std::shared_ptr logger = cEnv["__logger"].get>(); - logger->info("Type '{}' not found.", acName); - return sol::nil; - } + return sol::object(lua, sol::in_place, std::move(result)); + }), + "hash", sol::property([](gamedataLocKeyWrapper& aThis, sol::this_state aState) -> sol::object { + sol::state_view lua(aState); + auto converted = lua.script(fmt::format("return {}ull", aThis.hash)); + return converted.get(); + })); - // Always try to return instance wrapped in Handle<> if the type allows it. - // See NewHandle() for more info. - return RTTIHelper::Get().NewHandle(pType, sol::nullopt); - }; - - luaGlobal["GetSingleton"] = [this](const std::string& acName, sol::this_environment aThisEnv) - { - return this->GetSingletonHandle(acName, aThisEnv); - }; + luaGlobal["LocKey"] = luaVm["LocKey"]; luaVm.new_usertype("GameOptions", sol::meta_function::construct, sol::no_constructor, @@ -453,6 +455,39 @@ void Scripting::PostInitialize() luaGlobal["TweakDB"] = TweakDB(m_lua.AsRef()); + m_sandbox.PostInitialize(); +} + +void Scripting::PostInitializeStage2() +{ + auto lua = m_lua.Lock(); + auto& luaVm = lua.Get(); + + sol::table luaGlobal = luaVm[m_global]; + + luaGlobal["NewObject"] = [this](const std::string& acName, sol::this_environment aEnv) -> sol::object + { + auto* pRtti = RED4ext::CRTTISystem::Get(); + auto* pType = pRtti->GetType(RED4ext::FNV1a(acName.c_str())); + + if (!pType) + { + const sol::environment cEnv = aEnv; + std::shared_ptr logger = cEnv["__logger"].get>(); + logger->info("Type '{}' not found.", acName); + return sol::nil; + } + + // Always try to return instance wrapped in Handle<> if the type allows it. + // See NewHandle() for more info. + return RTTIHelper::Get().NewHandle(pType, sol::nullopt); + }; + + luaGlobal["GetSingleton"] = [this](const std::string& acName, sol::this_environment aThisEnv) + { + return this->GetSingletonHandle(acName, aThisEnv); + }; + luaGlobal["Override"] = [this](const std::string& acTypeName, const std::string& acFullName, sol::protected_function aFunction, sol::this_environment aThisEnv) -> void { m_override.Override(acTypeName, acFullName, aFunction, aThisEnv, true); @@ -470,6 +505,12 @@ void Scripting::PostInitialize() luaGlobal["Observe"] = luaGlobal["ObserveBefore"]; + // Doesn't require RTTI, but still shouldn't be used before onInit as there is no guarantee that the required mod will be loaded before + luaGlobal["GetMod"] = [this](const std::string& acName) -> sol::object + { + return GetMod(acName); + }; + luaGlobal["GameDump"] = [this](Type* apType) { return apType ? apType->GameDump() : "Null"; @@ -490,7 +531,7 @@ void Scripting::PostInitialize() const ClassType type(m_lua.AsRef(), pType); return type.Dump(aDetailed); }; - + luaGlobal["DumpAllTypeNames"] = [this](sol::this_environment aThisEnv) { auto* pRtti = RED4ext::CRTTISystem::Get(); @@ -522,12 +563,6 @@ void Scripting::PostInitialize() }; #endif - // Doesn't require RTTI, but still shouldn't be used before onInit as there is no guarantee that the required mod will be loaded before - luaGlobal["GetMod"] = [this](const std::string& acName) -> sol::object - { - return GetMod(acName); - }; - luaVm["Game"] = this; luaGlobal["Game"] = luaVm["Game"]; @@ -543,7 +578,7 @@ void Scripting::RegisterOverrides() auto lua = m_lua.Lock(); auto& luaVm = lua.Get(); - luaVm["RegisterGlobalInputListener"] = [](StrongReference& aSelf, sol::this_environment aThisEnv) { + luaVm["RegisterGlobalInputListener"] = [](WeakReference& aSelf, sol::this_environment aThisEnv) { sol::protected_function unregisterInputListener = aSelf.Index("UnregisterInputListener", aThisEnv); sol::protected_function registerInputListener = aSelf.Index("RegisterInputListener", aThisEnv); @@ -561,6 +596,11 @@ const TiltedPhoques::Vector& Scripting::GetBinds() const return m_store.GetBinds(); } +void Scripting::TriggerOnTweak() const +{ + m_store.TriggerOnTweak(); +} + void Scripting::TriggerOnInit() const { m_store.TriggerOnInit(); @@ -644,18 +684,6 @@ sol::object Scripting::Index(const std::string& acName, sol::this_state aState, return itor->second; } - return InternalIndex(acName, aState, aEnv); -} - -sol::object Scripting::NewIndex(const std::string& acName, sol::object aParam) -{ - auto& property = m_properties[acName]; - property = std::move(aParam); - return property; -} - -sol::protected_function Scripting::InternalIndex(const std::string& acName, sol::this_state aState, sol::this_environment aEnv) -{ auto func = RTTIHelper::Get().ResolveFunction(acName); if (!func) @@ -676,7 +704,9 @@ sol::protected_function Scripting::InternalIndex(const std::string& acName, sol: return sol::nil; } - return NewIndex(acName, std::move(func)); + auto& property = m_properties[acName]; + property = std::move(func); + return property; } sol::object Scripting::GetSingletonHandle(const std::string& acName, sol::this_environment aThisEnv) diff --git a/src/scripting/Scripting.h b/src/scripting/Scripting.h index 5c341c83..a4fe05ae 100644 --- a/src/scripting/Scripting.h +++ b/src/scripting/Scripting.h @@ -15,10 +15,12 @@ struct Scripting ~Scripting() = default; void Initialize(); - void PostInitialize(); + void PostInitializeStage1(); + void PostInitializeStage2(); const TiltedPhoques::Vector& GetBinds() const; + void TriggerOnTweak() const; void TriggerOnInit() const; void TriggerOnUpdate(float aDeltaTime) const; void TriggerOnDraw() const; @@ -43,10 +45,7 @@ struct Scripting void RegisterOverrides(); sol::object Index(const std::string& acName, sol::this_state aState, sol::this_environment aEnv); - sol::object NewIndex(const std::string& acName, sol::object aParam); sol::object GetSingletonHandle(const std::string& acName, sol::this_environment aThisEnv); - sol::protected_function InternalIndex(const std::string& acName, sol::this_state aState, - sol::this_environment aEnv); private: TiltedPhoques::Lockable m_lua;