Skip to content

Commit

Permalink
core: Swap TCB in and out of FS segment around guest code on Linux.
Browse files Browse the repository at this point in the history
  • Loading branch information
squidbus committed Sep 19, 2024
1 parent 383a761 commit 49b6317
Show file tree
Hide file tree
Showing 11 changed files with 88 additions and 60 deletions.
11 changes: 1 addition & 10 deletions src/core/cpu_patches.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -600,11 +600,8 @@ struct PatchInfo {
};

static const std::unordered_map<ZydisMnemonic, PatchInfo> Patches = {
#if defined(_WIN32)
// Windows needs a trampoline.
#ifdef _WIN32
{ZYDIS_MNEMONIC_MOV, {FilterTcbAccess, GenerateTcbAccess, true}},
#elif !defined(__APPLE__)
{ZYDIS_MNEMONIC_MOV, {FilterTcbAccess, GenerateTcbAccess, false}},
#endif

#ifdef __APPLE__
Expand Down Expand Up @@ -789,12 +786,6 @@ void PrePatchInstructions(u64 segment_addr, u64 segment_size) {
code_page += 0x1000;
}
}
#elif !defined(_WIN32)
// Linux and others have an FS segment pointing to valid memory, so continue to do full
// ahead-of-time patching for now until a better solution is worked out.
if (!Patches.empty()) {
TryPatchAot(reinterpret_cast<void*>(segment_addr), segment_size);
}
#endif
}

Expand Down
2 changes: 1 addition & 1 deletion src/core/libraries/avplayer/avplayer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,7 @@ void RegisterlibSceAvPlayer(Core::Loader::SymbolsResolver* sym) {
LIB_FUNCTION("XC9wM+xULz8", "libSceAvPlayer", 1, "libSceAvPlayer", 1, 0, sceAvPlayerJumpToTime);
LIB_FUNCTION("9y5v+fGN4Wk", "libSceAvPlayer", 1, "libSceAvPlayer", 1, 0, sceAvPlayerPause);
LIB_FUNCTION("HD1YKVU26-M", "libSceAvPlayer", 1, "libSceAvPlayer", 1, 0, sceAvPlayerPostInit);
LIB_FUNCTION("agig-iDRrTE", "libSceAvPlayer", 1, "libSceAvPlayer", 1, 0, sceAvPlayerPrintf);
// LIB_FUNCTION("agig-iDRrTE", "libSceAvPlayer", 1, "libSceAvPlayer", 1, 0, sceAvPlayerPrintf);
LIB_FUNCTION("w5moABNwnRY", "libSceAvPlayer", 1, "libSceAvPlayer", 1, 0, sceAvPlayerResume);
LIB_FUNCTION("k-q+xOxdc3E", "libSceAvPlayer", 1, "libSceAvPlayer", 1, 0,
sceAvPlayerSetAvSyncMode);
Expand Down
4 changes: 2 additions & 2 deletions src/core/libraries/avplayer/avplayer_state.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ void PS4_SYSV_ABI AvPlayerState::AutoPlayEventCallback(void* opaque, SceAvPlayer
const auto callback = self->m_event_replacement.event_callback;
const auto ptr = self->m_event_replacement.object_ptr;
if (callback != nullptr) {
auto* linker = Common::Singleton<Core::Linker>::Instance();
const auto* linker = Common::Singleton<Core::Linker>::Instance();
linker->ExecuteGuest(callback, ptr, event_id, 0, event_data);
}
}
Expand Down Expand Up @@ -365,7 +365,7 @@ void AvPlayerState::EmitEvent(SceAvPlayerEvents event_id, void* event_data) {
const auto callback = m_init_data.event_replacement.event_callback;
if (callback) {
const auto ptr = m_init_data.event_replacement.object_ptr;
auto* linker = Common::Singleton<Core::Linker>::Instance();
const auto* linker = Common::Singleton<Core::Linker>::Instance();
linker->ExecuteGuest(callback, ptr, event_id, 0, event_data);
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/core/libraries/kernel/thread_management.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -989,7 +989,7 @@ static void cleanup_thread(void* arg) {
static void* run_thread(void* arg) {
auto* thread = static_cast<ScePthread>(arg);
Common::SetCurrentThreadName(thread->name.c_str());
auto* linker = Common::Singleton<Core::Linker>::Instance();
const auto* linker = Common::Singleton<Core::Linker>::Instance();
void* ret = nullptr;
g_pthread_self = thread;
pthread_cleanup_push(cleanup_thread, thread);
Expand Down
42 changes: 5 additions & 37 deletions src/core/libraries/libs.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,45 +6,12 @@
#include <functional>

#include "common/logging/log.h"
#include "common/singleton.h"
#include "core/linker.h"
#include "core/loader/elf.h"
#include "core/loader/symbols_resolver.h"

template <size_t N>
struct StringLiteral {
constexpr StringLiteral(const char (&str)[N]) {
std::copy_n(str, N, value);
}

char value[N];
};

template <StringLiteral name, class F, F f>
struct wrapper_impl;

template <StringLiteral name, class R, class... Args, PS4_SYSV_ABI R (*f)(Args...)>
struct wrapper_impl<name, PS4_SYSV_ABI R (*)(Args...), f> {
static R PS4_SYSV_ABI wrap(Args... args) {
if (std::string_view(name.value) != "scePthreadEqual" &&
std::string_view(name.value) != "sceUserServiceGetEvent") {
// LOG_WARNING(Core_Linker, "Function {} called", name.value);
}
if constexpr (std::is_same_v<R, s32> || std::is_same_v<R, u32>) {
const u32 ret = f(args...);
if (ret != 0 && std::string_view(name.value) != "scePthreadEqual") {
LOG_WARNING(Core_Linker, "Function {} returned {:#x}", name.value, ret);
}
return ret;
}
// stuff
return f(args...);
}
};

template <StringLiteral name, class F, F f>
constexpr auto wrapper = wrapper_impl<name, F, f>::wrap;

// #define W(foo) wrapper<#foo, decltype(&foo), foo>
#define W(foo) foo
#define W(linker, foo) linker->WrapHost<#foo, decltype(&foo), foo>()

#define LIB_FUNCTION(nid, lib, libversion, mod, moduleVersionMajor, moduleVersionMinor, function) \
{ \
Expand All @@ -56,7 +23,8 @@ constexpr auto wrapper = wrapper_impl<name, F, f>::wrap;
sr.module_version_major = moduleVersionMajor; \
sr.module_version_minor = moduleVersionMinor; \
sr.type = Core::Loader::SymbolType::Function; \
auto func = reinterpret_cast<u64>(W(function)); \
const auto* linker = Common::Singleton<Core::Linker>::Instance(); \
auto func = reinterpret_cast<u64>(W(linker, function)); \
sym->AddSymbol(sr, func); \
}

Expand Down
4 changes: 2 additions & 2 deletions src/core/libraries/network/net_ctl_obj.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ s32 Libraries::NetCtl::NetCtlInternal::registerNpToolkitCallback(

void Libraries::NetCtl::NetCtlInternal::checkCallback() {
std::unique_lock lock{m_mutex};
auto* linker = Common::Singleton<Core::Linker>::Instance();
const auto* linker = Common::Singleton<Core::Linker>::Instance();
for (auto& callback : callbacks) {
if (callback.func != nullptr) {
linker->ExecuteGuest(callback.func, ORBIS_NET_CTL_EVENT_TYPE_DISCONNECTED,
Expand All @@ -70,7 +70,7 @@ void Libraries::NetCtl::NetCtlInternal::checkCallback() {

void Libraries::NetCtl::NetCtlInternal::checkNpToolkitCallback() {
std::unique_lock lock{m_mutex};
auto* linker = Common::Singleton<Core::Linker>::Instance();
const auto* linker = Common::Singleton<Core::Linker>::Instance();
for (auto& callback : nptoolCallbacks) {
if (callback.func != nullptr) {
linker->ExecuteGuest(callback.func, ORBIS_NET_CTL_EVENT_TYPE_DISCONNECTED,
Expand Down
4 changes: 2 additions & 2 deletions src/core/linker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,7 @@ void* Linker::TlsGetAddr(u64 module_index, u64 offset) {

thread_local std::once_flag init_tls_flag;

void Linker::EnsureThreadInitialized(bool is_primary) {
void Linker::EnsureThreadInitialized(bool is_primary) const {
std::call_once(init_tls_flag, [this, is_primary] {
#ifdef ARCH_X86_64
InitializeThreadPatchStack();
Expand All @@ -343,7 +343,7 @@ void Linker::EnsureThreadInitialized(bool is_primary) {
});
}

void Linker::InitTlsForThread(bool is_primary) {
void Linker::InitTlsForThread(bool is_primary) const {
static constexpr size_t TcbSize = 0x40;
static constexpr size_t TlsAllocAlign = 0x20;
const size_t total_tls_size = Common::AlignUp(static_tls_size, TlsAllocAlign) + TcbSize;
Expand Down
56 changes: 52 additions & 4 deletions src/core/linker.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include <mutex>
#include <vector>
#include "core/module.h"
#include "core/tls.h"

namespace Core {

Expand Down Expand Up @@ -109,16 +110,39 @@ class Linker {
void DebugDump();

template <class ReturnType, class... FuncArgs, class... CallArgs>
ReturnType ExecuteGuest(PS4_SYSV_ABI ReturnType (*func)(FuncArgs...), CallArgs&&... args) {
ReturnType ExecuteGuest(PS4_SYSV_ABI ReturnType (*func)(FuncArgs...),
CallArgs&&... args) const {
// Make sure TLS is initialized for the thread before entering guest.
EnsureThreadInitialized();
return func(std::forward<CallArgs>(args)...);
SwapTls();
if constexpr (std::is_same_v<ReturnType, void>) {
func(std::forward<CallArgs>(args)...);
SwapTls();
return;
} else {
ReturnType ret = func(std::forward<CallArgs>(args)...);
SwapTls();
return ret;
}
}

template <size_t N>
struct StringLiteral {
constexpr StringLiteral(const char (&str)[N]) {
std::copy_n(str, N, value);
}
char value[N];
};

template <StringLiteral name, class FuncType, FuncType func>
FuncType WrapHost() const {
return host_wrapper_impl<name, FuncType, func>::wrap;
}

private:
const Module* FindExportedModule(const ModuleInfo& m, const LibraryInfo& l);
void EnsureThreadInitialized(bool is_primary = false);
void InitTlsForThread(bool is_primary);
void EnsureThreadInitialized(bool is_primary = false) const;
void InitTlsForThread(bool is_primary) const;

MemoryManager* memory;
std::mutex mutex;
Expand All @@ -129,6 +153,30 @@ class Linker {
AppHeapAPI heap_api{};
std::vector<std::unique_ptr<Module>> m_modules;
Loader::SymbolsResolver m_hle_symbols{};

template <StringLiteral name, class FuncType, FuncType func>
struct host_wrapper_impl;

template <StringLiteral name, class ReturnType, class... FuncArgs,
PS4_SYSV_ABI ReturnType (*func)(FuncArgs...)>
struct host_wrapper_impl<name, PS4_SYSV_ABI ReturnType (*)(FuncArgs...), func> {
static ReturnType PS4_SYSV_ABI wrap(FuncArgs... args) {
/*if (std::string_view(name.value) != "scePthreadEqual" &&
std::string_view(name.value) != "sceUserServiceGetEvent") {
LOG_WARNING(Core_Linker, "Function {} called", name.value);
}*/
SwapTls();
if constexpr (std::is_same_v<ReturnType, void>) {
func(args...);
SwapTls();
return;
} else {
ReturnType ret = func(args...);
SwapTls();
return ret;
}
}
};
};

} // namespace Core
2 changes: 1 addition & 1 deletion src/core/module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ Module::~Module() = default;

s32 Module::Start(size_t args, const void* argp, void* param) {
LOG_INFO(Core_Linker, "Module started : {}", name);
auto* linker = Common::Singleton<Core::Linker>::Instance();
const auto* linker = Common::Singleton<Core::Linker>::Instance();
const VAddr addr = dynamic_info.init_virtual_addr + GetBaseAddress();
return linker->ExecuteGuest(reinterpret_cast<EntryFunc>(addr), args, argp, param);
}
Expand Down
18 changes: 18 additions & 0 deletions src/core/tls.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ Tcb* GetTcbBase() {
return reinterpret_cast<Tcb*>(TlsGetValue(GetTcbKey()));
}

void SwapTls() {}

#elif defined(__APPLE__) && defined(ARCH_X86_64)

// Apple x86_64
Expand Down Expand Up @@ -139,6 +141,8 @@ Tcb* GetTcbBase() {
return tcb;
}

void SwapTls() {}

#elif defined(ARCH_X86_64)

// Other POSIX x86_64
Expand All @@ -153,6 +157,18 @@ Tcb* GetTcbBase() {
return tcb;
}

void SwapTls() {
void* fsbase;
void* gsbase;
asm volatile("rdfsbase %0;"
"rdgsbase %1"
: "=r"(fsbase), "=r"(gsbase)::"memory");
asm volatile("wrfsbase %0;"
"wrgsbase %1" ::"r"(gsbase),
"r"(fsbase)
: "memory");
}

#else

// POSIX non-x86_64
Expand All @@ -179,6 +195,8 @@ Tcb* GetTcbBase() {
return static_cast<Tcb*>(pthread_getspecific(GetTcbKey()));
}

void SwapTls() {}

#endif

} // namespace Core
3 changes: 3 additions & 0 deletions src/core/tls.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,7 @@ void SetTcbBase(void* image_address);
/// Retrieves Tcb structure for the calling thread.
Tcb* GetTcbBase();

/// Swaps between the host and guest TLS area if needed.
void SwapTls();

} // namespace Core

0 comments on commit 49b6317

Please sign in to comment.