Skip to content

Commit

Permalink
Use JitAllocator + added new functions
Browse files Browse the repository at this point in the history
  • Loading branch information
qubka committed Sep 16, 2024
1 parent 95a4e91 commit df50466
Show file tree
Hide file tree
Showing 7 changed files with 192 additions and 179 deletions.
17 changes: 17 additions & 0 deletions polyhook.pplugin.in
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,23 @@
"type": "ptr64"
}
},
{
"name": "FindOriginalAddr",
"funcName": "FindOriginalAddr",
"paramTypes": [
{
"type": "ptr64",
"name": "pClass"
},
{
"type": "ptr64",
"name": "pAddr"
}
],
"retType": {
"type": "ptr64"
}
},
{
"name": "UnhookAll",
"funcName": "UnhookAll",
Expand Down
202 changes: 60 additions & 142 deletions src/callback.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,59 +3,54 @@
#include "polyhook2/MemProtector.hpp"

template<typename T>
constexpr asmjit::TypeId GetTypeIdx() noexcept {
constexpr asmjit::TypeId getTypeIdx() noexcept {
return static_cast<asmjit::TypeId>(asmjit::TypeUtils::TypeIdOfT<T>::kTypeId);
}

asmjit::TypeId PLH::Callback::getTypeId(DataType type) {
switch (type) {
case DataType::Void:
return GetTypeIdx<void>();
return getTypeIdx<void>();
case DataType::Bool:
return GetTypeIdx<bool>();
return getTypeIdx<bool>();
case DataType::Int8:
return GetTypeIdx<int8_t>();
return getTypeIdx<int8_t>();
case DataType::Int16:
return GetTypeIdx<int16_t>();
return getTypeIdx<int16_t>();
case DataType::Int32:
return GetTypeIdx<int32_t>();
return getTypeIdx<int32_t>();
case DataType::Int64:
return GetTypeIdx<int64_t>();
return getTypeIdx<int64_t>();
case DataType::UInt8:
return GetTypeIdx<uint8_t>();
return getTypeIdx<uint8_t>();
case DataType::UInt16:
return GetTypeIdx<uint16_t>();
return getTypeIdx<uint16_t>();
case DataType::UInt32:
return GetTypeIdx<uint32_t>();
return getTypeIdx<uint32_t>();
case DataType::UInt64:
return GetTypeIdx<uint64_t>();
return getTypeIdx<uint64_t>();
case DataType::Float:
return GetTypeIdx<float>();
return getTypeIdx<float>();
case DataType::Double:
return GetTypeIdx<double>();
return getTypeIdx<double>();
case DataType::Pointer:
case DataType::String:
case DataType::WString:
case DataType::Matrix4x4:
return asmjit::TypeId::kUIntPtr;
case DataType::Vector2:
#if _WIN32
return asmjit::TypeId::kInt64;
#else
return asmjit::TypeId::kFloat64;
#endif
case DataType::Vector3:
case DataType::Vector4:
#if _WIN32
return TypeId::kUIntPtr;
#else
return asmjit::TypeId::kFloat32x4;
#endif
}
return asmjit::TypeId::kVoid;
}

uint64_t PLH::Callback::getJitFunc(const asmjit::FuncSignature& sig, const asmjit::Arch arch, const PLH::Callback::CallbackEntry callback) {
if (m_functionPtr) {
return m_functionPtr;
}

if (!g_jitRuntime) {
m_errorCode = "JitRuntime invalid";
return 0;
}

/*AsmJit is smart enough to track register allocations and will forward
the proper registers the right values and fixup any it dirtied earlier.
This can only be done if it knows the signature, and ABI, so we give it
Expand Down Expand Up @@ -94,7 +89,6 @@ uint64_t PLH::Callback::getJitFunc(const asmjit::FuncSignature& sig, const asmji

// Create labels
asmjit::Label supercede = cc.newLabel();
asmjit::Label override = cc.newLabel();
asmjit::Label noPost = cc.newLabel();

// map argument slots to registers, following abi.
Expand All @@ -109,43 +103,14 @@ uint64_t PLH::Callback::getJitFunc(const asmjit::FuncSignature& sig, const asmji
} else if (asmjit::TypeUtils::isFloat(argType)) {
arg = cc.newXmm();
} else {
//Log::log("Parameters wider than 64bits not supported", ErrorLevel::SEV);
m_errorCode = "Parameters wider than 64bits not supported";
return 0;
}

func->setArg(argIdx, arg);
argRegisters.push_back(std::move(arg));
}

#if _WIN32
uint32_t retSize = (uint32_t)(sizeof(uint64_t));
#else
bool isPod = asmjit::TypeUtils::isVec128(sig.ret());
bool isIntPod = asmjit::TypeUtils::isBetween(sig.ret(), asmjit::TypeId::kInt8x16, asmjit::TypeId::kUInt64x2);
bool isFloatPod = asmjit::TypeUtils::isBetween(sig.ret(), asmjit::TypeId::kFloat32x4, asmjit::TypeId::kFloat64x2);
uint32_t retSize = (uint32_t)(sizeof(uint64_t) * (isPod ? 2 : 1));
#endif

std::vector<asmjit::x86::Reg> retRegisters;
retRegisters.reserve(2);
if (sig.hasRet()) {
if (asmjit::TypeUtils::isInt(sig.ret())) {
retRegisters.push_back(cc.newUIntPtr());
}
#if !_WIN32
/*else if (isIntPod) {
retRegisters.push_back(cc.newUIntPtr());
retRegisters.push_back(cc.newUIntPtr());
} else if (isFloatPod) {
retRegisters.push_back(cc.newXmm());
retRegisters.push_back(cc.newXmm());
}*/
#endif
else {
retRegisters.push_back(cc.newXmm());
}
}

// setup the stack structure to hold arguments for user callback
uint32_t stackSize = (uint32_t)(sizeof(uint64_t) * sig.argCount());
asmjit::x86::Mem argsStack = cc.newStack(stackSize, 16);
Expand All @@ -172,7 +137,7 @@ uint64_t PLH::Callback::getJitFunc(const asmjit::FuncSignature& sig, const asmji
} else if(asmjit::TypeUtils::isFloat(argType)) {
cc.movq(argsStackIdx, argRegisters.at(argIdx).as<asmjit::x86::Xmm>());
} else {
//Log::log("Parameters wider than 64bits not supported", ErrorLevel::SEV);
m_errorCode = "Parameters wider than 64bits not supported";
return 0;
}

Expand All @@ -197,6 +162,7 @@ uint64_t PLH::Callback::getJitFunc(const asmjit::FuncSignature& sig, const asmji
cc.mov(argCountParam, (uint8_t)sig.argCount());

// create buffer for ret val
uint32_t retSize = (uint32_t)(sizeof(uint64_t));
asmjit::x86::Mem retStack = cc.newStack(retSize, 16);
asmjit::x86::Gp retStruct = cc.newUIntPtr("retStruct");
cc.lea(retStruct, retStack);
Expand Down Expand Up @@ -251,35 +217,22 @@ uint64_t PLH::Callback::getJitFunc(const asmjit::FuncSignature& sig, const asmji
for (uint8_t argIdx = 0; argIdx < sig.argCount(); argIdx++) {
origInvokeNode->setArg(argIdx, argRegisters.at(argIdx));
}
for (size_t retIdx = 0; retIdx < retRegisters.size(); retIdx++) {
origInvokeNode->setRet(retIdx, retRegisters[retIdx]);
}

if (sig.hasRet()) {
asmjit::x86::Reg retRegister;
if (asmjit::TypeUtils::isInt(sig.ret())) {
retRegister = cc.newUIntPtr();
} else {
retRegister = cc.newXmm();
}
origInvokeNode->setRet(0, retRegister);

asmjit::x86::Mem retStackIdx(retStack);
retStackIdx.setSize(sizeof(uint64_t));
if (asmjit::TypeUtils::isInt(sig.ret())) {
cc.mov(retStackIdx, retRegisters.at(0).as<asmjit::x86::Gp>());
}
#if !_WIN32
/*else if (isIntPod) {
asmjit::x86::Mem retStackIdxUpper(retStack);
retStackIdxUpper.addOffset(sizeof(uint64_t));
retStackIdxUpper.setSize(sizeof(uint64_t));
cc.mov(retStackIdx, retRegisters.at(0).as<asmjit::x86::Gp>());
cc.mov(retStackIdxUpper, retRegisters.at(1).as<asmjit::x86::Gp>());
} else if (isFloatPod) {
asmjit::x86::Mem retStackIdxUpper(retStack);
retStackIdxUpper.addOffset(sizeof(uint64_t));
retStackIdxUpper.setSize(sizeof(uint64_t));
cc.movq(retStackIdx, retRegisters.at(0).as<asmjit::x86::Xmm>());
cc.movq(retStackIdxUpper, retRegisters.at(1).as<asmjit::x86::Xmm>());
}*/
#endif
else {
cc.movq(retStackIdx, retRegisters.at(0).as<asmjit::x86::Xmm>());
cc.mov(retStackIdx, retRegister.as<asmjit::x86::Gp>());
} else {
cc.movq(retStackIdx, retRegister.as<asmjit::x86::Xmm>());
}
}

Expand Down Expand Up @@ -327,47 +280,13 @@ uint64_t PLH::Callback::getJitFunc(const asmjit::FuncSignature& sig, const asmji
cc.bind(noPost);

if (sig.hasRet()) {
//cc.test(retValue, ReturnFlag::Override);
//cc.jnz(override);
/*
#if !_WIN32
if (isPod) {
cc.ret(retRegisters.at(0), retRegisters.at(1));
} else
#endif
{
cc.ret(retRegisters.at(0));
}*/

//cc.bind(override);

asmjit::x86::Mem retStackIdx(retStack);
retStackIdx.setSize(sizeof(uint64_t));
if (asmjit::TypeUtils::isInt(sig.ret())) {
asmjit::x86::Gp tmp = cc.newUIntPtr();
cc.mov(tmp, retStackIdx);
cc.ret(tmp);
}
#if !_WIN32
/*else if (isIntPod) {
asmjit::x86::Mem retStackIdxUpper(retStack);
retStackIdxUpper.addOffset(sizeof(uint64_t));
retStackIdxUpper.setSize(sizeof(uint64_t));
cc.mov(asmjit::x86::rax, retStackIdx);
cc.mov(asmjit::x86::rdx, retStackIdxUpper);
cc.ret();
} else if (isFloatPod) {
asmjit::x86::Mem retStackIdxUpper(retStack);
retStackIdxUpper.addOffset(sizeof(uint64_t));
retStackIdxUpper.setSize(sizeof(uint64_t));
cc.movq(asmjit::x86::xmm0, retStackIdx);
cc.movq(asmjit::x86::xmm1, retStackIdxUpper);
cc.ret();
}*/
#endif
else {
} else {
asmjit::x86::Xmm tmp = cc.newXmm();
cc.movq(tmp, retStackIdx);
cc.ret(tmp);
Expand All @@ -378,35 +297,16 @@ uint64_t PLH::Callback::getJitFunc(const asmjit::FuncSignature& sig, const asmji

cc.endFunc();

// write to buffer
cc.finalize();

// worst case, overestimates for case trampolines needed
code.flatten();
size_t size = code.codeSize();

// Allocate a virtual memory (executable).
m_callbackBuf = std::make_unique<char[]>(size);
if (!m_callbackBuf) {
PolyHook2DebugBreak();
if (asmjit::Error err = g_jitRuntime->add(&m_functionPtr, &code)) {
m_functionPtr = 0;
m_errorCode = asmjit::DebugUtils::errorAsString(err);
return 0;
}

m_callbackPtr = (uint64_t) m_callbackBuf.get();

MemoryProtector protector(m_callbackPtr, size, R | W | X, *this, false);

// if multiple sections, resolve linkage (1 atm)
if (code.hasUnresolvedLinks()) {
code.resolveUnresolvedLinks();
}

// Relocate to the base-address of the allocated memory.
code.relocateToBase(m_callbackPtr);
code.copyFlattenedData(m_callbackBuf.get(), size);

//Log::log("JIT Stub:\n" + std::string(log.data()), ErrorLevel::INFO);
return m_callbackPtr;
return m_functionPtr;
}

uint64_t PLH::Callback::getJitFunc(DataType retType, const std::vector<DataType>& paramTypes, const asmjit::Arch arch, const CallbackEntry callback) {
Expand Down Expand Up @@ -485,5 +385,23 @@ uint64_t* PLH::Callback::getTrampolineHolder() {
}

uint64_t* PLH::Callback::getCallbackHolder() {
return &m_callbackPtr;
return &m_functionPtr;
}

std::string_view PLH::Callback::getError() const {
return !m_functionPtr && m_errorCode ? m_errorCode : "";
}

PLH::Callback::Callback() {
}

PLH::Callback::Callback(Callback&& other) noexcept : m_callbacks(std::move(m_callbacks)), m_functionPtr(other.m_functionPtr), m_trampolinePtr(other.m_trampolinePtr) {
other.m_functionPtr = 0;
other.m_trampolinePtr = 0;
}

PLH::Callback::~Callback() {
if (m_functionPtr) {
g_jitRuntime->release(m_functionPtr);
}
}
Loading

0 comments on commit df50466

Please sign in to comment.