Implement critical sections using WaitOnAddress.

This commit is contained in:
Skyth 2024-11-04 13:16:58 +03:00
parent cd99cf2c04
commit fa5fc9aae7
2 changed files with 17 additions and 54 deletions

View file

@ -132,6 +132,7 @@ target_link_libraries(UnleashedRecomp PRIVATE
PkgConfig::tomlplusplus PkgConfig::tomlplusplus
zstd::libzstd_static zstd::libzstd_static
unofficial::concurrentqueue::concurrentqueue unofficial::concurrentqueue::concurrentqueue
Synchronization
) )
target_include_directories(UnleashedRecomp PRIVATE target_include_directories(UnleashedRecomp PRIVATE

View file

@ -586,62 +586,32 @@ struct Semaphore : HostObject<XKSEMAPHORE>
} }
}; };
extern "C" NTSYSAPI NTSTATUS NTAPI NtReleaseKeyedEvent( // https://devblogs.microsoft.com/oldnewthing/20160825-00/?p=94165
IN HANDLE KeyedEventHandle,
IN PVOID Key,
IN BOOLEAN Alertable,
IN PLARGE_INTEGER Timeout OPTIONAL);
void RtlLeaveCriticalSection(XRTL_CRITICAL_SECTION* cs) void RtlLeaveCriticalSection(XRTL_CRITICAL_SECTION* cs)
{ {
//printf("!!! STUB !!! RtlLeaveCriticalSection\n"); cs->RecursionCount--;
if (cs->RecursionCount != 0) {
if (--cs->RecursionCount != 0)
{
InterlockedDecrement(&cs->LockCount);
return; return;
} }
cs->OwningThread = NULL; InterlockedExchange(&cs->OwningThread, 0);
WakeByAddressSingle(&cs->OwningThread);
if (InterlockedDecrement(&cs->LockCount) != -1)
NtReleaseKeyedEvent(nullptr, cs, FALSE, nullptr);
} }
extern "C" NTSYSAPI NTSTATUS NTAPI NtWaitForKeyedEvent(
IN HANDLE KeyedEventHandle,
IN PVOID Key,
IN BOOLEAN Alertable,
IN PLARGE_INTEGER Timeout OPTIONAL);
void RtlEnterCriticalSection(XRTL_CRITICAL_SECTION* cs) void RtlEnterCriticalSection(XRTL_CRITICAL_SECTION* cs)
{ {
//printf("!!! STUB !!! RtlEnterCriticalSection\n"); DWORD thisThread = GetCurrentThreadId();
while (true)
const uint32_t thread = static_cast<uint32_t>(GetPPCContext()->r13.u64);
if (cs->OwningThread == thread)
{ {
InterlockedIncrement(&cs->LockCount); DWORD previousOwner = InterlockedCompareExchangeAcquire(&cs->OwningThread, thisThread, 0);
++cs->RecursionCount; if (previousOwner == 0 || previousOwner == thisThread) {
return; cs->RecursionCount++;
}
uint32_t spinCount = cs->Header.Absolute * 256;
while (spinCount--)
{
if (InterlockedCompareExchange(&cs->LockCount, 0, -1) == -1)
{
cs->OwningThread = thread;
cs->RecursionCount = 1;
return; return;
} }
WaitOnAddress(&cs->OwningThread, &previousOwner, sizeof(previousOwner), INFINITE);
} }
if (InterlockedIncrement(&cs->LockCount) != 0)
NtWaitForKeyedEvent(nullptr, cs, FALSE, nullptr);
cs->OwningThread = thread;
cs->RecursionCount = 1;
} }
void RtlImageXexHeaderField() void RtlImageXexHeaderField()
@ -1110,19 +1080,11 @@ void XexGetModuleHandle()
bool RtlTryEnterCriticalSection(XRTL_CRITICAL_SECTION* cs) bool RtlTryEnterCriticalSection(XRTL_CRITICAL_SECTION* cs)
{ {
const uint32_t thread = static_cast<uint32_t>(GetPPCContext()->r13.u64); DWORD thisThread = GetCurrentThreadId();
DWORD previousOwner = InterlockedCompareExchangeAcquire(&cs->OwningThread, thisThread, 0);
if (InterlockedCompareExchange(&cs->LockCount, 0, -1) == -1) if (previousOwner == 0 || previousOwner == thisThread) {
{ cs->RecursionCount++;
cs->OwningThread = thread;
cs->RecursionCount = 1;
return true;
}
if (cs->OwningThread == thread)
{
InterlockedIncrement(&cs->LockCount);
++cs->RecursionCount;
return true; return true;
} }