From d68e88314f2137cd75a23b679f47550ad9e2bace Mon Sep 17 00:00:00 2001 From: "Skyth (Asilkan)" <19259897+blueskythlikesclouds@users.noreply.github.com> Date: Sun, 15 Dec 2024 23:13:06 +0300 Subject: [PATCH] Cross-platform semaphore implementation. (#43) --- UnleashedRecomp/cpu/guest_thread.cpp | 4 +- UnleashedRecomp/cpu/guest_thread.h | 2 +- UnleashedRecomp/kernel/imports.cpp | 103 +++++++++++++++++++++------ UnleashedRecomp/kernel/xdm.h | 3 +- 4 files changed, 86 insertions(+), 26 deletions(-) diff --git a/UnleashedRecomp/cpu/guest_thread.cpp b/UnleashedRecomp/cpu/guest_thread.cpp index 51e9a0be..32412b40 100644 --- a/UnleashedRecomp/cpu/guest_thread.cpp +++ b/UnleashedRecomp/cpu/guest_thread.cpp @@ -59,12 +59,14 @@ GuestThreadHandle::~GuestThreadHandle() thread.join(); } -void GuestThreadHandle::Wait(uint32_t timeout) +uint32_t GuestThreadHandle::Wait(uint32_t timeout) { assert(timeout == INFINITE); if (thread.joinable()) thread.join(); + + return STATUS_WAIT_0; } uint32_t GuestThread::Start(const GuestThreadParams& params) diff --git a/UnleashedRecomp/cpu/guest_thread.h b/UnleashedRecomp/cpu/guest_thread.h index 91ec9834..15d16bf6 100644 --- a/UnleashedRecomp/cpu/guest_thread.h +++ b/UnleashedRecomp/cpu/guest_thread.h @@ -29,7 +29,7 @@ struct GuestThreadHandle : KernelObject GuestThreadHandle(const GuestThreadParams& params); ~GuestThreadHandle() override; - void Wait(uint32_t timeout) override; + uint32_t Wait(uint32_t timeout) override; }; struct GuestThread diff --git a/UnleashedRecomp/kernel/imports.cpp b/UnleashedRecomp/kernel/imports.cpp index 8b275dc8..ebbf7e4e 100644 --- a/UnleashedRecomp/kernel/imports.cpp +++ b/UnleashedRecomp/kernel/imports.cpp @@ -261,7 +261,7 @@ DWORD NtWaitForSingleObjectEx(DWORD Handle, DWORD WaitMode, DWORD Alertable, XLP if (IsKernelObject(Handle)) { - GetKernelObject(Handle)->Wait(timeout); + return GetKernelObject(Handle)->Wait(timeout); } else { @@ -271,13 +271,13 @@ DWORD NtWaitForSingleObjectEx(DWORD Handle, DWORD WaitMode, DWORD Alertable, XLP { return STATUS_USER_APC; } - else if (status) + else if (status == WAIT_TIMEOUT) { - return STATUS_ALERTED; + return STATUS_TIMEOUT; } - } - return STATUS_SUCCESS; + return STATUS_SUCCESS; + } } void NtWriteFile() @@ -529,7 +529,7 @@ uint32_t KeSetAffinityThread(DWORD Thread, DWORD Affinity, XLPDWORD lpPreviousAf return 0; } -struct Event : KernelObject, HostObject +struct Event final : KernelObject, HostObject { HANDLE handle; @@ -549,13 +549,69 @@ struct Event : KernelObject, HostObject } }; -struct Semaphore : KernelObject, HostObject +struct Semaphore final : KernelObject, HostObject { - HANDLE handle; + std::atomic count; + uint32_t maximumCount; Semaphore(XKSEMAPHORE* semaphore) + : count(semaphore->Header.SignalState), maximumCount(semaphore->Limit) { - handle = CreateSemaphoreA(nullptr, semaphore->Header.SignalState, semaphore->Limit, nullptr); + } + + Semaphore(uint32_t count, uint32_t maximumCount) + : count(count), maximumCount(maximumCount) + { + } + + uint32_t Wait(uint32_t timeout) override + { + if (timeout == 0) + { + uint32_t currentCount = count.load(); + if (currentCount != 0) + { + if (count.compare_exchange_weak(currentCount, currentCount - 1)) + return STATUS_SUCCESS; + } + + return STATUS_TIMEOUT; + } + else if (timeout == INFINITE) + { + uint32_t currentCount; + while (true) + { + currentCount = count.load(); + if (currentCount != 0) + { + if (count.compare_exchange_weak(currentCount, currentCount - 1)) + return STATUS_SUCCESS; + } + else + { + count.wait(0); + } + } + + return STATUS_SUCCESS; + } + else + { + assert(false && "Unhandled timeout value."); + return STATUS_TIMEOUT; + } + } + + void Release(uint32_t releaseCount, uint32_t* previousCount) + { + if (previousCount != nullptr) + *previousCount = count; + + assert(count + releaseCount <= maximumCount); + + count += releaseCount; + count.notify_all(); } }; @@ -876,27 +932,26 @@ bool KeResetEvent(XKEVENT* pEvent) DWORD KeWaitForSingleObject(XDISPATCHER_HEADER* Object, DWORD WaitReason, DWORD WaitMode, bool Alertable, XLPQWORD Timeout) { - const uint64_t timeout = GuestTimeoutToMilliseconds(Timeout); - - HANDLE handle = nullptr; + const uint32_t timeout = GuestTimeoutToMilliseconds(Timeout); + assert(timeout == INFINITE); switch (Object->Type) { case 0: case 1: - handle = QueryKernelObject(*Object)->handle; + WaitForSingleObjectEx(QueryKernelObject(*Object)->handle, timeout, Alertable); break; case 5: - handle = QueryKernelObject(*Object)->handle; + QueryKernelObject(*Object)->Wait(timeout); break; default: - assert(false); - break; + assert(false && "Unrecognized kernel object type."); + return STATUS_TIMEOUT; } - return WaitForSingleObjectEx(handle, timeout, Alertable); + return STATUS_SUCCESS; } static std::vector g_tlsFreeIndices; @@ -1198,16 +1253,17 @@ uint32_t NtSetEvent(uint32_t handle, uint32_t* previousState) NTSTATUS NtCreateSemaphore(XLPDWORD Handle, XOBJECT_ATTRIBUTES* ObjectAttributes, DWORD InitialCount, DWORD MaximumCount) { - *Handle = (uint32_t)CreateSemaphoreA(nullptr, InitialCount, MaximumCount, nullptr); + *Handle = GetKernelHandle(CreateKernelObject(InitialCount, MaximumCount)); return STATUS_SUCCESS; } -NTSTATUS NtReleaseSemaphore(uint32_t Handle, DWORD ReleaseCount, LONG* PreviousCount) +NTSTATUS NtReleaseSemaphore(Semaphore* Handle, DWORD ReleaseCount, LONG* PreviousCount) { - ReleaseSemaphore((HANDLE)Handle, ReleaseCount, PreviousCount); + uint32_t previousCount; + Handle->Release(ReleaseCount, &previousCount); - if (PreviousCount) - *PreviousCount = ByteSwap(*PreviousCount); + if (PreviousCount != nullptr) + *PreviousCount = ByteSwap(previousCount); return STATUS_SUCCESS; } @@ -1386,7 +1442,8 @@ void KfLowerIrql() { } uint32_t KeReleaseSemaphore(XKSEMAPHORE* semaphore, uint32_t increment, uint32_t adjustment, uint32_t wait) { auto* object = QueryKernelObject(semaphore->Header); - return ReleaseSemaphore(object->handle, adjustment, nullptr) ? 0 : 0xFFFFFFFF; + object->Release(adjustment, nullptr); + return STATUS_SUCCESS; } void XAudioGetVoiceCategoryVolume() diff --git a/UnleashedRecomp/kernel/xdm.h b/UnleashedRecomp/kernel/xdm.h index 4b50d714..4c933228 100644 --- a/UnleashedRecomp/kernel/xdm.h +++ b/UnleashedRecomp/kernel/xdm.h @@ -12,9 +12,10 @@ struct KernelObject { } - virtual void Wait(uint32_t timeout) + virtual uint32_t Wait(uint32_t timeout) { assert(false && "Wait not implemented for this kernel object."); + return STATUS_TIMEOUT; } };