Skip to content

Commit

Permalink
Bug #108, 修复AllocTlsData可能破坏当前线程的Tls内容
Browse files Browse the repository at this point in the history
  • Loading branch information
mingkuang-Chuyu committed Aug 4, 2024
1 parent 4682629 commit ac63103
Showing 1 changed file with 159 additions and 65 deletions.
224 changes: 159 additions & 65 deletions src/Thunks/DllMainCRTStartup.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,75 @@ namespace YY::Thunks::internal
TlsRawItem* pPrev;
BYTE* pBase;
void** pOldTlsIndex;

void __fastcall Free() noexcept
{
YY::Thunks::internal::Free(pOldTlsIndex);
pOldTlsIndex = nullptr;
auto _pBase = pBase;
pBase = nullptr;
YY::Thunks::internal::Free(_pBase);
}
};

struct TlsHeader
{
TlsRawItem* volatile pRoot = nullptr;

void __fastcall RemoveItem(TlsRawItem* _pItem) noexcept
{
if (pRoot == _pItem)
{
_pItem->pPrev = nullptr;
pRoot = _pItem->pNext;
}
else
{
auto _pPrev = _pItem->pPrev;
auto _pNext = _pItem->pNext;

_pPrev->pNext = _pNext;

if (_pNext)
{
_pNext->pPrev = _pPrev;
}
}
}

void __fastcall AddItem(TlsRawItem* _pFirst, TlsRawItem* _pLast = nullptr) noexcept
{
if (!_pLast)
_pLast = _pFirst;

_pFirst->pPrev = nullptr;
auto _pRoot = pRoot;
_pLast->pNext = _pRoot;
if (_pRoot)
{
_pRoot->pPrev = _pLast;
}
pRoot = _pFirst;
}

TlsRawItem* __fastcall Flush() noexcept
{
return (TlsRawItem*)InterlockedExchange((uintptr_t*)&pRoot, 0);
}
};

enum class TlsStatus
{
// 没有人尝试释放Tls
None,
// 线程Tls数据需要申请或者释放
ThreadLock,
// DLL正在释放
DllUnload
};
static thread_local TlsRawItem s_CurrentNode;

static volatile LONG uStatus = 0;
static TlsRawItem* volatile pRoot = nullptr;
static TlsHeader g_TlsHeader;

static SIZE_T __fastcall GetTlsIndexBufferCount(TEB* _pTeb)
{
Expand Down Expand Up @@ -157,26 +221,47 @@ namespace YY::Thunks::internal

static bool __fastcall AllocTlsData(TEB* _pTeb = nullptr) noexcept
{
const size_t _cbTlsRaw = _tls_used.EndAddressOfRawData - _tls_used.StartAddressOfRawData;
if (_cbTlsRaw == 0)
return true;

if (_tls_index == 0)
return false;

if (!_pTeb)
_pTeb = (TEB*)NtCurrentTeb();

auto _pTlsIndex = (void**)_pTeb->ThreadLocalStoragePointer;
void** _pOldTlsIndex = nullptr;
auto _cTlsIndexLength = GetTlsIndexBufferCount(_pTeb);
const auto _cTlsIndexLength = GetTlsIndexBufferCount(_pTeb);
if (_cTlsIndexLength > _tls_index)
{
InterlockedExchange((uintptr_t*)&_pTlsIndex[_tls_index], 0);
}

auto _pRawTlsData = (BYTE*)Alloc(_cbTlsRaw + sizeof(TlsRawItem), HEAP_ZERO_MEMORY);
if (!_pRawTlsData)
{
return false;
}

auto _pCurrentNode = (TlsRawItem*)(_pRawTlsData + _cbTlsRaw);
if (_cTlsIndexLength <= _tls_index)
{
// Index不足,扩充……
auto _cNewTlsIndexLength = _tls_index + 128;
auto _pNewTlsIndex = (void**)Alloc(_cNewTlsIndexLength * sizeof(void*), HEAP_ZERO_MEMORY);
if (!_pNewTlsIndex)
{
Free(_pRawTlsData);
return false;
}

memcpy(_pNewTlsIndex, _pTlsIndex, _cTlsIndexLength * sizeof(void*));
if ((void*)InterlockedCompareExchange((uintptr_t*)&_pTeb->ThreadLocalStoragePointer, (uintptr_t)_pNewTlsIndex, (uintptr_t)_pTlsIndex) != _pTlsIndex)
{
// 这是什么情况,DllMain期间桌怎么会有其他线程操作Tls?
Free(_pNewTlsIndex);
Free(_pRawTlsData);
return false;
}

Expand All @@ -187,38 +272,32 @@ namespace YY::Thunks::internal
else
{
// 其他线程的无法直接释放,玩意缓存里恰好正在使用这块那么会崩溃的
_pOldTlsIndex = _pTlsIndex;
_pCurrentNode->pOldTlsIndex = _pTlsIndex;
}
_pTlsIndex = _pNewTlsIndex;
}
const size_t _cbTlsRaw = _tls_used.EndAddressOfRawData - _tls_used.StartAddressOfRawData;
auto _pRawTlsData = (BYTE*)Alloc(_cbTlsRaw, HEAP_ZERO_MEMORY);
if (!_pRawTlsData)
{
// 释放不是安全的,极小的概率可能野,但是现在现在就这样吧。
Free(_pOldTlsIndex);
return false;
}

memcpy(_pRawTlsData, (void*)_tls_used.StartAddressOfRawData, _cbTlsRaw);
InterlockedExchange((uintptr_t*)&_pTlsIndex[_tls_index], (uintptr_t)_pRawTlsData);

s_CurrentNode.pBase = _pRawTlsData;
s_CurrentNode.pOldTlsIndex = _pOldTlsIndex;

_pCurrentNode->pBase = _pRawTlsData;
for (;;)
{
if (!_interlockedbittestandset(&uStatus, 0))
const auto _Status = (TlsStatus)InterlockedCompareExchange(&uStatus, LONG(TlsStatus::ThreadLock), LONG(TlsStatus::None));
// 锁定成功?
if (_Status == TlsStatus::None)
{
s_CurrentNode.pNext = pRoot;
if (pRoot)
{
pRoot->pPrev = &s_CurrentNode;
}
pRoot = &s_CurrentNode;
_interlockedbittestandreset(&uStatus, 0);
g_TlsHeader.AddItem(_pCurrentNode);
// 解除锁定
InterlockedExchange(&uStatus, LONG(TlsStatus::None));
InterlockedExchange((uintptr_t*)&_pTlsIndex[_tls_index], (uintptr_t)_pRawTlsData);
break;
}

// 当前Dll正在卸载,不能再添加Tls
if (_Status == TlsStatus::DllUnload)
{
_pCurrentNode->Free();
return false;
}
}
return true;
}
Expand All @@ -227,50 +306,56 @@ namespace YY::Thunks::internal
{
if (_tls_index == 0)
return;

const size_t _cbTlsRaw = _tls_used.EndAddressOfRawData - _tls_used.StartAddressOfRawData;
if (_cbTlsRaw == 0)
return;

auto _pTeb = (TEB*)NtCurrentTeb();
if (_tls_index >= GetTlsIndexBufferCount(_pTeb))
return;

auto _ppTlsIndex = (void**)_pTeb->ThreadLocalStoragePointer;
if (_ppTlsIndex[_tls_index] == nullptr)
auto _pTlsRawData = (BYTE*)InterlockedExchange((uintptr_t*)&_ppTlsIndex[_tls_index], 0);
if (_pTlsRawData == nullptr)
return;

if (s_CurrentNode.pBase != _ppTlsIndex[_tls_index])
return;

if (s_CurrentNode.pOldTlsIndex)
{
Free(s_CurrentNode.pOldTlsIndex);
s_CurrentNode.pOldTlsIndex = nullptr;
}

s_CurrentNode.pBase = nullptr;
auto _pCurrentNode = (TlsRawItem*)(_pTlsRawData + _cbTlsRaw);

for (;;)
{
if (!_interlockedbittestandset(&uStatus, 0))
const auto _Status = (TlsStatus)InterlockedCompareExchange(&uStatus, LONG(TlsStatus::ThreadLock), LONG(TlsStatus::None));
// 锁定成功?
if (_Status == TlsStatus::None)
{
auto pPrev = s_CurrentNode.pPrev;
auto pNext = s_CurrentNode.pNext;
if (pPrev)
__try
{
pPrev->pNext = pNext;
// 检查一下这块Tls数据是否是我们申请的
if (_pCurrentNode->pBase == _pTlsRawData)
{
g_TlsHeader.RemoveItem(_pCurrentNode);
}
else
{
_pCurrentNode = nullptr;
}
}
else
__except (EXCEPTION_EXECUTE_HANDLER)
{
pRoot = pNext;
_pCurrentNode = nullptr;
}

if (pNext)
{
pNext->pPrev = pPrev;
}
_interlockedbittestandreset(&uStatus, 0);
// 解除锁定
InterlockedExchange(&uStatus, LONG(TlsStatus::None));

auto _pTlsRawData = (void*)InterlockedExchange((uintptr_t*)&_ppTlsIndex[_tls_index], 0);
Free(_pTlsRawData);
break;
if (_pCurrentNode)
_pCurrentNode->Free();
return;
}

// 当前Dll正在卸载,这些内存统一由DllMain接管,FreeTlsIndex会统一释放内存
if (_Status == TlsStatus::DllUnload)
return;
}
}

Expand All @@ -280,7 +365,8 @@ namespace YY::Thunks::internal
return false;

_tls_index = GetMaxTlsIndex() + 1;
AllocTlsData((TEB*)NtCurrentTeb());
if (!AllocTlsData((TEB*)NtCurrentTeb()))
return false;

// 同时给所有历史的线程追加新DLL产生的Tls内存
do
Expand Down Expand Up @@ -332,17 +418,24 @@ namespace YY::Thunks::internal
if (_tls_index == 0)
return;

// 故意不加锁……
for (auto _pItem = pRoot; _pItem;)
for (;;)
{
auto _pNext = _pItem->pNext;
if (_pItem->pOldTlsIndex)
const auto _Status = (TlsStatus)InterlockedCompareExchange(&uStatus, LONG(TlsStatus::DllUnload), LONG(TlsStatus::None));
// 锁定成功?
if (_Status == TlsStatus::None)
{
Free(_pItem->pOldTlsIndex);
_pItem->pOldTlsIndex = nullptr;
for (auto _pItem = g_TlsHeader.Flush(); _pItem;)
{
auto _pNext = _pItem->pNext;
_pItem->Free();
_pItem = _pNext;
}
return;
}
Free(_pItem->pBase);
_pItem = _pNext;

// DllUnload 代表全局数据已经释放,只能进入一次
if (_Status == TlsStatus::DllUnload)
return;
}
}

Expand Down Expand Up @@ -411,7 +504,8 @@ namespace YY::Thunks::internal
_tls_index_old = _tls_index;
if (_tls_index == 0 && g_TlsMode == TlsMode::ByDllMainCRTStartupForYY_Thunks)
{
CreateTlsIndex();
if (!CreateTlsIndex())
return FALSE;
}
}
#endif
Expand Down Expand Up @@ -451,9 +545,9 @@ namespace YY::Thunks::internal
CallTlsCallback(_hInstance, _uReason);
auto _bRet = _pfnDllMainCRTStartup(_hInstance, _uReason, _pReserved);

FreeTlsIndex();
if (_pReserved != nullptr)
if (_pReserved == nullptr)
{
FreeTlsIndex();
__YY_uninitialize_winapi_thunks();
}
return _bRet;
Expand All @@ -462,7 +556,7 @@ namespace YY::Thunks::internal
#endif
{
auto _bRet = _pfnDllMainCRTStartup(_hInstance, _uReason, _pReserved);
if (_pReserved != nullptr)
if (_pReserved == nullptr)
{
__YY_uninitialize_winapi_thunks();
}
Expand Down

0 comments on commit ac63103

Please sign in to comment.