diff --git a/polyhook.pplugin.in b/polyhook.pplugin.in index 7f1f560..7a5e47f 100644 --- a/polyhook.pplugin.in +++ b/polyhook.pplugin.in @@ -181,6 +181,23 @@ "type": "ptr64" } }, + { + "name": "FindOriginalAddr", + "funcName": "FindOriginalAddr", + "paramTypes": [ + { + "type": "ptr64", + "name": "pClass" + }, + { + "type": "ptr64", + "name": "pAddr" + } + ], + "retType": { + "type": "ptr64" + } + }, { "name": "UnhookAll", "funcName": "UnhookAll", diff --git a/src/callback.cpp b/src/callback.cpp index 2e19201..378dbd2 100644 --- a/src/callback.cpp +++ b/src/callback.cpp @@ -3,59 +3,54 @@ #include "polyhook2/MemProtector.hpp" template -constexpr asmjit::TypeId GetTypeIdx() noexcept { +constexpr asmjit::TypeId getTypeIdx() noexcept { return static_cast(asmjit::TypeUtils::TypeIdOfT::kTypeId); } asmjit::TypeId PLH::Callback::getTypeId(DataType type) { switch (type) { case DataType::Void: - return GetTypeIdx(); + return getTypeIdx(); case DataType::Bool: - return GetTypeIdx(); + return getTypeIdx(); case DataType::Int8: - return GetTypeIdx(); + return getTypeIdx(); case DataType::Int16: - return GetTypeIdx(); + return getTypeIdx(); case DataType::Int32: - return GetTypeIdx(); + return getTypeIdx(); case DataType::Int64: - return GetTypeIdx(); + return getTypeIdx(); case DataType::UInt8: - return GetTypeIdx(); + return getTypeIdx(); case DataType::UInt16: - return GetTypeIdx(); + return getTypeIdx(); case DataType::UInt32: - return GetTypeIdx(); + return getTypeIdx(); case DataType::UInt64: - return GetTypeIdx(); + return getTypeIdx(); case DataType::Float: - return GetTypeIdx(); + return getTypeIdx(); case DataType::Double: - return GetTypeIdx(); + return getTypeIdx(); case DataType::Pointer: case DataType::String: case DataType::WString: - case DataType::Matrix4x4: return asmjit::TypeId::kUIntPtr; - case DataType::Vector2: -#if _WIN32 - return asmjit::TypeId::kInt64; -#else - return asmjit::TypeId::kFloat64; -#endif - case DataType::Vector3: - case DataType::Vector4: -#if _WIN32 - return TypeId::kUIntPtr; -#else - return asmjit::TypeId::kFloat32x4; -#endif } return asmjit::TypeId::kVoid; } uint64_t PLH::Callback::getJitFunc(const asmjit::FuncSignature& sig, const asmjit::Arch arch, const PLH::Callback::CallbackEntry callback) { + if (m_functionPtr) { + return m_functionPtr; + } + + if (!g_jitRuntime) { + m_errorCode = "JitRuntime invalid"; + return 0; + } + /*AsmJit is smart enough to track register allocations and will forward the proper registers the right values and fixup any it dirtied earlier. This can only be done if it knows the signature, and ABI, so we give it @@ -94,7 +89,6 @@ uint64_t PLH::Callback::getJitFunc(const asmjit::FuncSignature& sig, const asmji // Create labels asmjit::Label supercede = cc.newLabel(); - asmjit::Label override = cc.newLabel(); asmjit::Label noPost = cc.newLabel(); // map argument slots to registers, following abi. @@ -109,7 +103,7 @@ uint64_t PLH::Callback::getJitFunc(const asmjit::FuncSignature& sig, const asmji } else if (asmjit::TypeUtils::isFloat(argType)) { arg = cc.newXmm(); } else { - //Log::log("Parameters wider than 64bits not supported", ErrorLevel::SEV); + m_errorCode = "Parameters wider than 64bits not supported"; return 0; } @@ -117,35 +111,6 @@ uint64_t PLH::Callback::getJitFunc(const asmjit::FuncSignature& sig, const asmji argRegisters.push_back(std::move(arg)); } -#if _WIN32 - uint32_t retSize = (uint32_t)(sizeof(uint64_t)); -#else - bool isPod = asmjit::TypeUtils::isVec128(sig.ret()); - bool isIntPod = asmjit::TypeUtils::isBetween(sig.ret(), asmjit::TypeId::kInt8x16, asmjit::TypeId::kUInt64x2); - bool isFloatPod = asmjit::TypeUtils::isBetween(sig.ret(), asmjit::TypeId::kFloat32x4, asmjit::TypeId::kFloat64x2); - uint32_t retSize = (uint32_t)(sizeof(uint64_t) * (isPod ? 2 : 1)); -#endif - - std::vector retRegisters; - retRegisters.reserve(2); - if (sig.hasRet()) { - if (asmjit::TypeUtils::isInt(sig.ret())) { - retRegisters.push_back(cc.newUIntPtr()); - } -#if !_WIN32 - /*else if (isIntPod) { - retRegisters.push_back(cc.newUIntPtr()); - retRegisters.push_back(cc.newUIntPtr()); - } else if (isFloatPod) { - retRegisters.push_back(cc.newXmm()); - retRegisters.push_back(cc.newXmm()); - }*/ -#endif - else { - retRegisters.push_back(cc.newXmm()); - } - } - // setup the stack structure to hold arguments for user callback uint32_t stackSize = (uint32_t)(sizeof(uint64_t) * sig.argCount()); asmjit::x86::Mem argsStack = cc.newStack(stackSize, 16); @@ -172,7 +137,7 @@ uint64_t PLH::Callback::getJitFunc(const asmjit::FuncSignature& sig, const asmji } else if(asmjit::TypeUtils::isFloat(argType)) { cc.movq(argsStackIdx, argRegisters.at(argIdx).as()); } else { - //Log::log("Parameters wider than 64bits not supported", ErrorLevel::SEV); + m_errorCode = "Parameters wider than 64bits not supported"; return 0; } @@ -197,6 +162,7 @@ uint64_t PLH::Callback::getJitFunc(const asmjit::FuncSignature& sig, const asmji cc.mov(argCountParam, (uint8_t)sig.argCount()); // create buffer for ret val + uint32_t retSize = (uint32_t)(sizeof(uint64_t)); asmjit::x86::Mem retStack = cc.newStack(retSize, 16); asmjit::x86::Gp retStruct = cc.newUIntPtr("retStruct"); cc.lea(retStruct, retStack); @@ -251,35 +217,22 @@ uint64_t PLH::Callback::getJitFunc(const asmjit::FuncSignature& sig, const asmji for (uint8_t argIdx = 0; argIdx < sig.argCount(); argIdx++) { origInvokeNode->setArg(argIdx, argRegisters.at(argIdx)); } - for (size_t retIdx = 0; retIdx < retRegisters.size(); retIdx++) { - origInvokeNode->setRet(retIdx, retRegisters[retIdx]); - } if (sig.hasRet()) { + asmjit::x86::Reg retRegister; + if (asmjit::TypeUtils::isInt(sig.ret())) { + retRegister = cc.newUIntPtr(); + } else { + retRegister = cc.newXmm(); + } + origInvokeNode->setRet(0, retRegister); + asmjit::x86::Mem retStackIdx(retStack); retStackIdx.setSize(sizeof(uint64_t)); if (asmjit::TypeUtils::isInt(sig.ret())) { - cc.mov(retStackIdx, retRegisters.at(0).as()); - } -#if !_WIN32 - /*else if (isIntPod) { - asmjit::x86::Mem retStackIdxUpper(retStack); - retStackIdxUpper.addOffset(sizeof(uint64_t)); - retStackIdxUpper.setSize(sizeof(uint64_t)); - - cc.mov(retStackIdx, retRegisters.at(0).as()); - cc.mov(retStackIdxUpper, retRegisters.at(1).as()); - } else if (isFloatPod) { - asmjit::x86::Mem retStackIdxUpper(retStack); - retStackIdxUpper.addOffset(sizeof(uint64_t)); - retStackIdxUpper.setSize(sizeof(uint64_t)); - - cc.movq(retStackIdx, retRegisters.at(0).as()); - cc.movq(retStackIdxUpper, retRegisters.at(1).as()); - }*/ -#endif - else { - cc.movq(retStackIdx, retRegisters.at(0).as()); + cc.mov(retStackIdx, retRegister.as()); + } else { + cc.movq(retStackIdx, retRegister.as()); } } @@ -327,47 +280,13 @@ uint64_t PLH::Callback::getJitFunc(const asmjit::FuncSignature& sig, const asmji cc.bind(noPost); if (sig.hasRet()) { - //cc.test(retValue, ReturnFlag::Override); - //cc.jnz(override); -/* -#if !_WIN32 - if (isPod) { - cc.ret(retRegisters.at(0), retRegisters.at(1)); - } else -#endif - { - cc.ret(retRegisters.at(0)); - }*/ - - //cc.bind(override); - asmjit::x86::Mem retStackIdx(retStack); retStackIdx.setSize(sizeof(uint64_t)); if (asmjit::TypeUtils::isInt(sig.ret())) { asmjit::x86::Gp tmp = cc.newUIntPtr(); cc.mov(tmp, retStackIdx); cc.ret(tmp); - } -#if !_WIN32 - /*else if (isIntPod) { - asmjit::x86::Mem retStackIdxUpper(retStack); - retStackIdxUpper.addOffset(sizeof(uint64_t)); - retStackIdxUpper.setSize(sizeof(uint64_t)); - - cc.mov(asmjit::x86::rax, retStackIdx); - cc.mov(asmjit::x86::rdx, retStackIdxUpper); - cc.ret(); - } else if (isFloatPod) { - asmjit::x86::Mem retStackIdxUpper(retStack); - retStackIdxUpper.addOffset(sizeof(uint64_t)); - retStackIdxUpper.setSize(sizeof(uint64_t)); - - cc.movq(asmjit::x86::xmm0, retStackIdx); - cc.movq(asmjit::x86::xmm1, retStackIdxUpper); - cc.ret(); - }*/ -#endif - else { + } else { asmjit::x86::Xmm tmp = cc.newXmm(); cc.movq(tmp, retStackIdx); cc.ret(tmp); @@ -378,35 +297,16 @@ uint64_t PLH::Callback::getJitFunc(const asmjit::FuncSignature& sig, const asmji cc.endFunc(); - // write to buffer cc.finalize(); - // worst case, overestimates for case trampolines needed - code.flatten(); - size_t size = code.codeSize(); - - // Allocate a virtual memory (executable). - m_callbackBuf = std::make_unique(size); - if (!m_callbackBuf) { - PolyHook2DebugBreak(); + if (asmjit::Error err = g_jitRuntime->add(&m_functionPtr, &code)) { + m_functionPtr = 0; + m_errorCode = asmjit::DebugUtils::errorAsString(err); return 0; } - m_callbackPtr = (uint64_t) m_callbackBuf.get(); - - MemoryProtector protector(m_callbackPtr, size, R | W | X, *this, false); - - // if multiple sections, resolve linkage (1 atm) - if (code.hasUnresolvedLinks()) { - code.resolveUnresolvedLinks(); - } - - // Relocate to the base-address of the allocated memory. - code.relocateToBase(m_callbackPtr); - code.copyFlattenedData(m_callbackBuf.get(), size); - //Log::log("JIT Stub:\n" + std::string(log.data()), ErrorLevel::INFO); - return m_callbackPtr; + return m_functionPtr; } uint64_t PLH::Callback::getJitFunc(DataType retType, const std::vector& paramTypes, const asmjit::Arch arch, const CallbackEntry callback) { @@ -485,5 +385,23 @@ uint64_t* PLH::Callback::getTrampolineHolder() { } uint64_t* PLH::Callback::getCallbackHolder() { - return &m_callbackPtr; + return &m_functionPtr; +} + +std::string_view PLH::Callback::getError() const { + return !m_functionPtr && m_errorCode ? m_errorCode : ""; +} + +PLH::Callback::Callback() { +} + +PLH::Callback::Callback(Callback&& other) noexcept : m_callbacks(std::move(m_callbacks)), m_functionPtr(other.m_functionPtr), m_trampolinePtr(other.m_trampolinePtr) { + other.m_functionPtr = 0; + other.m_trampolinePtr = 0; +} + +PLH::Callback::~Callback() { + if (m_functionPtr) { + g_jitRuntime->release(m_functionPtr); + } } \ No newline at end of file diff --git a/src/callback.hpp b/src/callback.hpp index b06bb43..79be36e 100644 --- a/src/callback.hpp +++ b/src/callback.hpp @@ -10,6 +10,9 @@ #include "polyhook2/Enums.hpp" #include "polyhook2/MemAccessor.hpp" +#include +#include + namespace PLH { enum class DataType : uint8_t { Void, @@ -26,12 +29,7 @@ namespace PLH { Double, Pointer, String, - WString, - // TODO: Only support return for now - Vector2, - Vector3, - Vector4, - Matrix4x4 + WString }; enum class ReturnAction : int32_t { @@ -49,21 +47,20 @@ namespace PLH { enum class ReturnFlag : uint8_t { Default = 0, ///< Value means this gives no information about return flag. NoPost = 1, - Override = 2, - Supercede = 4, + Supercede = 2, }; - class Callback final : public MemAccessor { + class Callback { public: struct Parameters { template void setArg(const uint8_t idx, const T val) const { - *(T*)getArgPtr(idx) = val; + *(T*) getArgPtr(idx) = val; } template T getArg(const uint8_t idx) const { - return *(T*)getArgPtr(idx); + return *(T*) getArgPtr(idx); } // asm depends on this specific type @@ -96,16 +93,17 @@ namespace PLH { typedef ReturnFlag (*CallbackEntry)(Callback* callback, CallbackType type, const Parameters* params, uint8_t count, const ReturnValue* ret); typedef ReturnAction (*CallbackHandler)(CallbackType type, const Parameters* params, int count, const ReturnValue* ret); - Callback() = default; - ~Callback() override = default; + Callback(); + Callback(Callback&& callback) noexcept; + ~Callback(); uint64_t getJitFunc(const asmjit::FuncSignature& sig, asmjit::Arch arch, CallbackEntry callback); - uint64_t getJitFunc(DataType retType, const std::vector& paramTypes, asmjit::Arch arch, CallbackEntry callback); uint64_t* getTrampolineHolder(); uint64_t* getCallbackHolder(); std::vector& getCallbacks(CallbackType type); + std::string_view getError() const; bool addCallback(CallbackType type, CallbackHandler callback); bool removeCallback(CallbackType type, CallbackHandler callback); @@ -116,11 +114,15 @@ namespace PLH { private: asmjit::TypeId getTypeId(DataType type); - std::unique_ptr m_callbackBuf; // May be asmjit::JitRuntime ? std::array, 2> m_callbacks; - uint64_t m_callbackPtr = 0; - uint64_t m_trampolinePtr = 0; + uint64_t m_functionPtr = 0; + union { + uint64_t m_trampolinePtr = 0; + const char* m_errorCode; + }; }; + + extern std::unique_ptr g_jitRuntime; } inline PLH::ReturnFlag operator|(PLH::ReturnFlag lhs, PLH::ReturnFlag rhs) noexcept { diff --git a/src/plugin.cpp b/src/plugin.cpp index 0ef0251..abd5071 100644 --- a/src/plugin.cpp +++ b/src/plugin.cpp @@ -3,6 +3,8 @@ PLH::PolyHookPlugin g_polyHookPlugin; EXPOSE_PLUGIN(PLUGIN_API, &g_polyHookPlugin) +std::unique_ptr PLH::g_jitRuntime; + static PLH::ReturnFlag GlobalCallback(PLH::Callback* callback, PLH::CallbackType type, const PLH::Callback::Parameters* params, uint8_t count, const PLH::Callback::ReturnValue* ret) { PLH::ReturnAction returnAction = PLH::ReturnAction::Ignored; @@ -15,22 +17,66 @@ static PLH::ReturnFlag GlobalCallback(PLH::Callback* callback, PLH::CallbackType returnAction = result; } + /* + *if (type == CallbackType::Post) { + ReturnAction lastPreReturnAction = m_lastPreReturnAction.back(); + m_lastPreReturnAction.pop_back(); + if (lastPreReturnAction >= ReturnAction::Override) + m_callingConvention->restoreReturnValue(m_registers); + if (lastPreReturnAction < ReturnAction::Supercede) + m_callingConvention->restoreCallArguments(m_registers); + } + + ReturnAction returnAction = ReturnAction::Ignored; + auto it = m_handlers.find(type); + if (it == m_handlers.end()) { + // still save the arguments for the post hook even if there + // is no pre-handler registered. + if (type == CallbackType::Pre) { + m_lastPreReturnAction.push_back(returnAction); + m_callingConvention->saveCallArguments(m_registers); + } + return returnAction; + } + + const std::vector& callbacks = it->second; + + for (const CallbackHandler callback : callbacks) { + ReturnAction result = callback(type, *this); + if (result > returnAction) + returnAction = result; + } + + if (type == CallbackType::Pre) { + m_lastPreReturnAction.push_back(returnAction); + if (returnAction >= ReturnAction::Override) + m_callingConvention->saveReturnValue(m_registers); + if (returnAction < ReturnAction::Supercede) + m_callingConvention->saveCallArguments(m_registers); + } + + return returnAction; + */ + PLH::ReturnFlag state = PLH::ReturnFlag::Default; if (type == PLH::CallbackType::Pre) { if (!callback->areCallbacksRegistered(PLH::CallbackType::Post)) { state |= PLH::ReturnFlag::NoPost; } - if (returnAction >= PLH::ReturnAction::Override) { - state |= PLH::ReturnFlag::Override; - if (returnAction == PLH::ReturnAction::Supercede) { - state |= PLH::ReturnFlag::Supercede; - } + if (returnAction >= PLH::ReturnAction::Supercede) { + state |= PLH::ReturnFlag::Supercede; } } return state; } -#include +void PLH::PolyHookPlugin::OnPluginStart() { + g_jitRuntime = std::make_unique(); +} + +void PLH::PolyHookPlugin::OnPluginEnd() { + g_jitRuntime.reset(); +} PLH::Callback* PLH::PolyHookPlugin::hookDetour(void* pFunc, DataType returnType, const std::vector& arguments) { if (!pFunc) @@ -47,6 +93,12 @@ PLH::Callback* PLH::PolyHookPlugin::hookDetour(void* pFunc, DataType returnType, } auto callback = std::make_unique(); + auto error = callback->getError(); + if (!error.empty()) { + // Log ? + return nullptr; + } + uint64_t JIT = callback->getJitFunc(returnType, arguments, asmjit::Arch::kHost, &GlobalCallback); auto detour = std::make_unique((uint64_t)pFunc, JIT, callback->getTrampolineHolder()); @@ -74,6 +126,12 @@ PLH::Callback* PLH::PolyHookPlugin::hookVirtual(void* pClass, int index, DataTyp } auto callback = std::make_unique(); + auto error = callback->getError(); + if (!error.empty()) { + // Log ? + return nullptr; + } + uint64_t JIT = callback->getJitFunc(returnType, arguments, asmjit::Arch::kHost, &GlobalCallback); auto& [redirectMap, origVFuncs] = m_tables[pClass]; @@ -86,15 +144,6 @@ PLH::Callback* PLH::PolyHookPlugin::hookVirtual(void* pClass, int index, DataTyp uint64_t origVFunc = origVFuncs[index]; *callback->getTrampolineHolder() = origVFunc; - // Step 1: Load the shared library - /*void* handle = dlopen("/home/qubka/.steam/cs2/game/csgo/addons/plugify/bin/linuxsteamrt64/libplugify.so", RTLD_LAZY); - if (handle != nullptr) { - using SourceHookPatchFunc = void (*)(uint64_t, uint64_t*); - SourceHookPatchFunc SourceHookPatch = (SourceHookPatchFunc) dlsym(handle, "Plugify_SourceHookPatch"); - - SourceHookPatch(origVFunc, callback->getCallbackHolder()); - }*/ - void* key = m_vtables.emplace(pClass, std::move(vtable)).first->second.get(); return m_callbacks.emplace(std::pair{key, index}, std::move(callback)).first->second.get(); } @@ -186,6 +235,19 @@ PLH::Callback* PLH::PolyHookPlugin::findVirtual(void* pClass, void* pFunc) const return findVirtual(pClass, getVTableIndex(pFunc)); } +void* PLH::PolyHookPlugin::findOriginalAddr(void* pClass, void* pAddr) { + auto it = m_tables.find(pClass); + if (it != m_tables.end()) { + auto& [redirectMap, origVFuncs] = it->second; + for (const auto& [index, addr] : redirectMap) { + if ((void*) addr == pAddr) { + return (void*) origVFuncs[index]; + } + } + } + return nullptr; +} + void PLH::PolyHookPlugin::unhookAll() { std::lock_guard m_lock(m_mutex); @@ -293,7 +355,7 @@ int PLH::PolyHookPlugin::getVTableIndex(void* pFunc) const { if (!ok) return -1; - constexpr int PtrSize = static_cast(sizeof(void*)); + constexpr int PtrSize = (int)(sizeof(void*)); if (*addr++ == 0xFF) { if (*addr == 0x60) @@ -360,6 +422,11 @@ PLUGIN_API PLH::Callback* FindVirtualByFunc(void* pClass, void* pFunc) { return g_polyHookPlugin.findVirtual(pClass, pFunc); } +extern "C" +PLUGIN_API void* FindOriginalAddr(void* pClass, void* pAddr) { + return g_polyHookPlugin.findOriginalAddr(pClass, pAddr); +} + extern "C" PLUGIN_API void UnhookAll() { return g_polyHookPlugin.unhookAll(); diff --git a/src/plugin.hpp b/src/plugin.hpp index 26714c3..59dfc0a 100644 --- a/src/plugin.hpp +++ b/src/plugin.hpp @@ -1,8 +1,9 @@ #pragma once +#include "callback.hpp" + #include #include -#include "callback.hpp" #include #include @@ -11,6 +12,8 @@ #include #include +#include +#include namespace PLH { class PolyHookPlugin final : public plg::IPluginEntry, public MemAccessor { @@ -19,19 +22,23 @@ namespace PLH { ~PolyHookPlugin() override = default; private: - void OnPluginStart() override {} - void OnPluginEnd() override {} + void OnPluginStart() override; + void OnPluginEnd() override; public: Callback* hookDetour(void* pFunc, DataType returnType, const std::vector& arguments); Callback* hookVirtual(void* pClass, int index, DataType returnType, const std::vector& arguments); Callback* hookVirtual(void* pClass, void* pFunc, DataType returnType, const std::vector& arguments); + bool unhookDetour(void* pFunc); bool unhookVirtual(void* pClass, int index); bool unhookVirtual(void* pClass, void* pFunc); + Callback* findDetour(void* pFunc) const; - Callback* findVirtual(void* pClass, int index) const; Callback* findVirtual(void* pClass, void* pFunc) const; + Callback* findVirtual(void* pClass, int index) const; + + void* findOriginalAddr(void* pClass, void* pAddr); void unhookAll(); void unhookAllVirtual(void* pClass); diff --git a/sym/exported_symbols.lds b/sym/exported_symbols.lds index 8aacde6..72f3bda 100644 --- a/sym/exported_symbols.lds +++ b/sym/exported_symbols.lds @@ -10,6 +10,7 @@ UnhookVirtualByFunc FindDetour FindVirtual FindVirtualByFunc +FindOriginalAddr UnhookAll HookAllVirtual AddCallback diff --git a/sym/version_script.lds b/sym/version_script.lds index 3854f90..7061085 100644 --- a/sym/version_script.lds +++ b/sym/version_script.lds @@ -12,6 +12,7 @@ FindDetour; FindVirtual; FindVirtualByFunc; + FindOriginalAddr; UnhookAll; HookAllVirtual; AddCallback;