diff --git a/src/Thunks/DllMainCRTStartup.hpp b/src/Thunks/DllMainCRTStartup.hpp index f6f9b38..df4fc75 100644 --- a/src/Thunks/DllMainCRTStartup.hpp +++ b/src/Thunks/DllMainCRTStartup.hpp @@ -58,11 +58,75 @@ namespace YY::Thunks::internal TlsRawItem* pPrev; BYTE* pBase; void** pOldTlsIndex; + + void __fastcall Free() noexcept + { + YY::Thunks::internal::Free(pOldTlsIndex); + pOldTlsIndex = nullptr; + auto _pBase = pBase; + pBase = nullptr; + YY::Thunks::internal::Free(_pBase); + } + }; + + struct TlsHeader + { + TlsRawItem* volatile pRoot = nullptr; + + void __fastcall RemoveItem(TlsRawItem* _pItem) noexcept + { + if (pRoot == _pItem) + { + _pItem->pPrev = nullptr; + pRoot = _pItem->pNext; + } + else + { + auto _pPrev = _pItem->pPrev; + auto _pNext = _pItem->pNext; + + _pPrev->pNext = _pNext; + + if (_pNext) + { + _pNext->pPrev = _pPrev; + } + } + } + + void __fastcall AddItem(TlsRawItem* _pFirst, TlsRawItem* _pLast = nullptr) noexcept + { + if (!_pLast) + _pLast = _pFirst; + + _pFirst->pPrev = nullptr; + auto _pRoot = pRoot; + _pLast->pNext = _pRoot; + if (_pRoot) + { + _pRoot->pPrev = _pLast; + } + pRoot = _pFirst; + } + + TlsRawItem* __fastcall Flush() noexcept + { + return (TlsRawItem*)InterlockedExchange((uintptr_t*)&pRoot, 0); + } + }; + + enum class TlsStatus + { + // 没有人尝试释放Tls + None, + // 线程Tls数据需要申请或者释放 + ThreadLock, + // DLL正在释放 + DllUnload }; - static thread_local TlsRawItem s_CurrentNode; static volatile LONG uStatus = 0; - static TlsRawItem* volatile pRoot = nullptr; + static TlsHeader g_TlsHeader; static SIZE_T __fastcall GetTlsIndexBufferCount(TEB* _pTeb) { @@ -157,26 +221,47 @@ namespace YY::Thunks::internal static bool __fastcall AllocTlsData(TEB* _pTeb = nullptr) noexcept { + const size_t _cbTlsRaw = _tls_used.EndAddressOfRawData - _tls_used.StartAddressOfRawData; + if (_cbTlsRaw == 0) + return true; + if (_tls_index == 0) return false; + if (!_pTeb) _pTeb = (TEB*)NtCurrentTeb(); auto _pTlsIndex = (void**)_pTeb->ThreadLocalStoragePointer; - void** _pOldTlsIndex = nullptr; - auto _cTlsIndexLength = GetTlsIndexBufferCount(_pTeb); + const auto _cTlsIndexLength = GetTlsIndexBufferCount(_pTeb); + if (_cTlsIndexLength > _tls_index) + { + InterlockedExchange((uintptr_t*)&_pTlsIndex[_tls_index], 0); + } + + auto _pRawTlsData = (BYTE*)Alloc(_cbTlsRaw + sizeof(TlsRawItem), HEAP_ZERO_MEMORY); + if (!_pRawTlsData) + { + return false; + } + + auto _pCurrentNode = (TlsRawItem*)(_pRawTlsData + _cbTlsRaw); if (_cTlsIndexLength <= _tls_index) { // Index不足,扩充…… auto _cNewTlsIndexLength = _tls_index + 128; auto _pNewTlsIndex = (void**)Alloc(_cNewTlsIndexLength * sizeof(void*), HEAP_ZERO_MEMORY); if (!_pNewTlsIndex) + { + Free(_pRawTlsData); return false; + } memcpy(_pNewTlsIndex, _pTlsIndex, _cTlsIndexLength * sizeof(void*)); if ((void*)InterlockedCompareExchange((uintptr_t*)&_pTeb->ThreadLocalStoragePointer, (uintptr_t)_pNewTlsIndex, (uintptr_t)_pTlsIndex) != _pTlsIndex) { + // 这是什么情况,DllMain期间桌怎么会有其他线程操作Tls? Free(_pNewTlsIndex); + Free(_pRawTlsData); return false; } @@ -187,38 +272,32 @@ namespace YY::Thunks::internal else { // 其他线程的无法直接释放,玩意缓存里恰好正在使用这块那么会崩溃的 - _pOldTlsIndex = _pTlsIndex; + _pCurrentNode->pOldTlsIndex = _pTlsIndex; } _pTlsIndex = _pNewTlsIndex; } - const size_t _cbTlsRaw = _tls_used.EndAddressOfRawData - _tls_used.StartAddressOfRawData; - auto _pRawTlsData = (BYTE*)Alloc(_cbTlsRaw, HEAP_ZERO_MEMORY); - if (!_pRawTlsData) - { - // 释放不是安全的,极小的概率可能野,但是现在现在就这样吧。 - Free(_pOldTlsIndex); - return false; - } memcpy(_pRawTlsData, (void*)_tls_used.StartAddressOfRawData, _cbTlsRaw); - InterlockedExchange((uintptr_t*)&_pTlsIndex[_tls_index], (uintptr_t)_pRawTlsData); - - s_CurrentNode.pBase = _pRawTlsData; - s_CurrentNode.pOldTlsIndex = _pOldTlsIndex; - + _pCurrentNode->pBase = _pRawTlsData; for (;;) { - if (!_interlockedbittestandset(&uStatus, 0)) + const auto _Status = (TlsStatus)InterlockedCompareExchange(&uStatus, LONG(TlsStatus::ThreadLock), LONG(TlsStatus::None)); + // 锁定成功? + if (_Status == TlsStatus::None) { - s_CurrentNode.pNext = pRoot; - if (pRoot) - { - pRoot->pPrev = &s_CurrentNode; - } - pRoot = &s_CurrentNode; - _interlockedbittestandreset(&uStatus, 0); + g_TlsHeader.AddItem(_pCurrentNode); + // 解除锁定 + InterlockedExchange(&uStatus, LONG(TlsStatus::None)); + InterlockedExchange((uintptr_t*)&_pTlsIndex[_tls_index], (uintptr_t)_pRawTlsData); break; } + + // 当前Dll正在卸载,不能再添加Tls + if (_Status == TlsStatus::DllUnload) + { + _pCurrentNode->Free(); + return false; + } } return true; } @@ -227,50 +306,56 @@ namespace YY::Thunks::internal { if (_tls_index == 0) return; + + const size_t _cbTlsRaw = _tls_used.EndAddressOfRawData - _tls_used.StartAddressOfRawData; + if (_cbTlsRaw == 0) + return; + auto _pTeb = (TEB*)NtCurrentTeb(); if (_tls_index >= GetTlsIndexBufferCount(_pTeb)) return; auto _ppTlsIndex = (void**)_pTeb->ThreadLocalStoragePointer; - if (_ppTlsIndex[_tls_index] == nullptr) + auto _pTlsRawData = (BYTE*)InterlockedExchange((uintptr_t*)&_ppTlsIndex[_tls_index], 0); + if (_pTlsRawData == nullptr) return; - if (s_CurrentNode.pBase != _ppTlsIndex[_tls_index]) - return; - - if (s_CurrentNode.pOldTlsIndex) - { - Free(s_CurrentNode.pOldTlsIndex); - s_CurrentNode.pOldTlsIndex = nullptr; - } - - s_CurrentNode.pBase = nullptr; + auto _pCurrentNode = (TlsRawItem*)(_pTlsRawData + _cbTlsRaw); for (;;) { - if (!_interlockedbittestandset(&uStatus, 0)) + const auto _Status = (TlsStatus)InterlockedCompareExchange(&uStatus, LONG(TlsStatus::ThreadLock), LONG(TlsStatus::None)); + // 锁定成功? + if (_Status == TlsStatus::None) { - auto pPrev = s_CurrentNode.pPrev; - auto pNext = s_CurrentNode.pNext; - if (pPrev) + __try { - pPrev->pNext = pNext; + // 检查一下这块Tls数据是否是我们申请的 + if (_pCurrentNode->pBase == _pTlsRawData) + { + g_TlsHeader.RemoveItem(_pCurrentNode); + } + else + { + _pCurrentNode = nullptr; + } } - else + __except (EXCEPTION_EXECUTE_HANDLER) { - pRoot = pNext; + _pCurrentNode = nullptr; } - if (pNext) - { - pNext->pPrev = pPrev; - } - _interlockedbittestandreset(&uStatus, 0); + // 解除锁定 + InterlockedExchange(&uStatus, LONG(TlsStatus::None)); - auto _pTlsRawData = (void*)InterlockedExchange((uintptr_t*)&_ppTlsIndex[_tls_index], 0); - Free(_pTlsRawData); - break; + if (_pCurrentNode) + _pCurrentNode->Free(); + return; } + + // 当前Dll正在卸载,这些内存统一由DllMain接管,FreeTlsIndex会统一释放内存 + if (_Status == TlsStatus::DllUnload) + return; } } @@ -280,7 +365,8 @@ namespace YY::Thunks::internal return false; _tls_index = GetMaxTlsIndex() + 1; - AllocTlsData((TEB*)NtCurrentTeb()); + if (!AllocTlsData((TEB*)NtCurrentTeb())) + return false; // 同时给所有历史的线程追加新DLL产生的Tls内存 do @@ -332,17 +418,24 @@ namespace YY::Thunks::internal if (_tls_index == 0) return; - // 故意不加锁…… - for (auto _pItem = pRoot; _pItem;) + for (;;) { - auto _pNext = _pItem->pNext; - if (_pItem->pOldTlsIndex) + const auto _Status = (TlsStatus)InterlockedCompareExchange(&uStatus, LONG(TlsStatus::DllUnload), LONG(TlsStatus::None)); + // 锁定成功? + if (_Status == TlsStatus::None) { - Free(_pItem->pOldTlsIndex); - _pItem->pOldTlsIndex = nullptr; + for (auto _pItem = g_TlsHeader.Flush(); _pItem;) + { + auto _pNext = _pItem->pNext; + _pItem->Free(); + _pItem = _pNext; + } + return; } - Free(_pItem->pBase); - _pItem = _pNext; + + // DllUnload 代表全局数据已经释放,只能进入一次 + if (_Status == TlsStatus::DllUnload) + return; } } @@ -411,7 +504,8 @@ namespace YY::Thunks::internal _tls_index_old = _tls_index; if (_tls_index == 0 && g_TlsMode == TlsMode::ByDllMainCRTStartupForYY_Thunks) { - CreateTlsIndex(); + if (!CreateTlsIndex()) + return FALSE; } } #endif @@ -451,9 +545,9 @@ namespace YY::Thunks::internal CallTlsCallback(_hInstance, _uReason); auto _bRet = _pfnDllMainCRTStartup(_hInstance, _uReason, _pReserved); - FreeTlsIndex(); - if (_pReserved != nullptr) + if (_pReserved == nullptr) { + FreeTlsIndex(); __YY_uninitialize_winapi_thunks(); } return _bRet; @@ -462,7 +556,7 @@ namespace YY::Thunks::internal #endif { auto _bRet = _pfnDllMainCRTStartup(_hInstance, _uReason, _pReserved); - if (_pReserved != nullptr) + if (_pReserved == nullptr) { __YY_uninitialize_winapi_thunks(); }