diff --git a/UnleashedRecomp/kernel/imports.cpp b/UnleashedRecomp/kernel/imports.cpp index d1d02103..6067f800 100644 --- a/UnleashedRecomp/kernel/imports.cpp +++ b/UnleashedRecomp/kernel/imports.cpp @@ -16,6 +16,129 @@ #include +struct Event final : KernelObject, HostObject +{ + bool manualReset; + std::atomic signaled; + + Event(XKEVENT* header) + : manualReset(!header->Type), signaled(!!header->SignalState) + { + } + + Event(bool manualReset, bool initialState) + : manualReset(manualReset), signaled(initialState) + { + } + + uint32_t Wait(uint32_t timeout) override + { + if (timeout == 0) + { + if (!signaled) + return STATUS_TIMEOUT; + + if (!manualReset) + signaled = false; + } + else if (timeout == INFINITE) + { + signaled.wait(false); + + if (!manualReset) + signaled = false; + } + else + { + assert(false && "Unhandled timeout value."); + } + + return STATUS_SUCCESS; + } + + bool Set() + { + signaled = true; + signaled.notify_all(); + + return TRUE; + } + + bool Reset() + { + signaled = false; + return TRUE; + } +}; + +static std::atomic g_keSetEventGeneration; + +struct Semaphore final : KernelObject, HostObject +{ + std::atomic count; + uint32_t maximumCount; + + Semaphore(XKSEMAPHORE* semaphore) + : count(semaphore->Header.SignalState), maximumCount(semaphore->Limit) + { + } + + Semaphore(uint32_t count, uint32_t maximumCount) + : count(count), maximumCount(maximumCount) + { + } + + uint32_t Wait(uint32_t timeout) override + { + if (timeout == 0) + { + uint32_t currentCount = count.load(); + if (currentCount != 0) + { + if (count.compare_exchange_weak(currentCount, currentCount - 1)) + return STATUS_SUCCESS; + } + + return STATUS_TIMEOUT; + } + else if (timeout == INFINITE) + { + uint32_t currentCount; + while (true) + { + currentCount = count.load(); + if (currentCount != 0) + { + if (count.compare_exchange_weak(currentCount, currentCount - 1)) + return STATUS_SUCCESS; + } + else + { + count.wait(0); + } + } + + return STATUS_SUCCESS; + } + else + { + assert(false && "Unhandled timeout value."); + return STATUS_TIMEOUT; + } + } + + void Release(uint32_t releaseCount, uint32_t* previousCount) + { + if (previousCount != nullptr) + *previousCount = count; + + assert(count + releaseCount <= maximumCount); + + count += releaseCount; + count.notify_all(); + } +}; + inline void CloseKernelObject(XDISPATCHER_HEADER& header) { if (header.WaitListHead.Flink != OBJECT_SIGNATURE) @@ -265,19 +388,10 @@ DWORD NtWaitForSingleObjectEx(DWORD Handle, DWORD WaitMode, DWORD Alertable, XLP } else { - const auto status = WaitForSingleObjectEx((HANDLE)Handle, timeout, Alertable); - - if (status == WAIT_IO_COMPLETION) - { - return STATUS_USER_APC; - } - else if (status == WAIT_TIMEOUT) - { - return STATUS_TIMEOUT; - } - - return STATUS_SUCCESS; + assert(false && "Unrecognized handle value."); } + + return STATUS_TIMEOUT; } void NtWriteFile() @@ -367,9 +481,9 @@ void MmQueryStatistics() LOG_UTILITY("!!! STUB !!!"); } -uint32_t NtCreateEvent(uint32_t* handle, void* objAttributes, uint32_t eventType, uint32_t initialState) +uint32_t NtCreateEvent(be* handle, void* objAttributes, uint32_t eventType, uint32_t initialState) { - *handle = ByteSwap((uint32_t)CreateEventA(nullptr, !eventType, !!initialState, nullptr)); + *handle = GetKernelHandle(CreateKernelObject(!eventType, !!initialState)); return 0; } @@ -529,92 +643,6 @@ uint32_t KeSetAffinityThread(DWORD Thread, DWORD Affinity, XLPDWORD lpPreviousAf return 0; } -struct Event final : KernelObject, HostObject -{ - HANDLE handle; - - Event(XKEVENT* header) - { - handle = CreateEventA(nullptr, !header->Type, !!header->SignalState, nullptr); - } - - bool Set() - { - return SetEvent(handle); - } - - bool Reset() - { - return ResetEvent(handle); - } -}; - -struct Semaphore final : KernelObject, HostObject -{ - std::atomic count; - uint32_t maximumCount; - - Semaphore(XKSEMAPHORE* semaphore) - : count(semaphore->Header.SignalState), maximumCount(semaphore->Limit) - { - } - - Semaphore(uint32_t count, uint32_t maximumCount) - : count(count), maximumCount(maximumCount) - { - } - - uint32_t Wait(uint32_t timeout) override - { - if (timeout == 0) - { - uint32_t currentCount = count.load(); - if (currentCount != 0) - { - if (count.compare_exchange_weak(currentCount, currentCount - 1)) - return STATUS_SUCCESS; - } - - return STATUS_TIMEOUT; - } - else if (timeout == INFINITE) - { - uint32_t currentCount; - while (true) - { - currentCount = count.load(); - if (currentCount != 0) - { - if (count.compare_exchange_weak(currentCount, currentCount - 1)) - return STATUS_SUCCESS; - } - else - { - count.wait(0); - } - } - - return STATUS_SUCCESS; - } - else - { - assert(false && "Unhandled timeout value."); - return STATUS_TIMEOUT; - } - } - - void Release(uint32_t releaseCount, uint32_t* previousCount) - { - if (previousCount != nullptr) - *previousCount = count; - - assert(count + releaseCount <= maximumCount); - - count += releaseCount; - count.notify_all(); - } -}; - void RtlLeaveCriticalSection(XRTL_CRITICAL_SECTION* cs) { cs->RecursionCount--; @@ -936,7 +964,12 @@ void KeUnlockL2() bool KeSetEvent(XKEVENT* pEvent, DWORD Increment, bool Wait) { - return QueryKernelObject(*pEvent)->Set(); + bool result = QueryKernelObject(*pEvent)->Set(); + + ++g_keSetEventGeneration; + g_keSetEventGeneration.notify_all(); + + return result; } bool KeResetEvent(XKEVENT* pEvent) @@ -953,7 +986,7 @@ DWORD KeWaitForSingleObject(XDISPATCHER_HEADER* Object, DWORD WaitReason, DWORD { case 0: case 1: - WaitForSingleObjectEx(QueryKernelObject(*Object)->handle, timeout, Alertable); + QueryKernelObject(*Object)->Wait(timeout); break; case 5: @@ -1245,9 +1278,10 @@ void MmQueryAllocationSize() LOG_UTILITY("!!! STUB !!!"); } -uint32_t NtClearEvent(uint32_t handle, uint32_t* previousState) +uint32_t NtClearEvent(Event* handle, uint32_t* previousState) { - return ResetEvent((HANDLE)handle) ? 0 : 0xFFFFFFFF; + handle->Reset(); + return 0; } uint32_t NtResumeThread(GuestThreadHandle* hThread, uint32_t* suspendCount) @@ -1260,9 +1294,10 @@ uint32_t NtResumeThread(GuestThreadHandle* hThread, uint32_t* suspendCount) return S_OK; } -uint32_t NtSetEvent(uint32_t handle, uint32_t* previousState) +uint32_t NtSetEvent(Event* handle, uint32_t* previousState) { - return SetEvent((HANDLE)handle) ? 0 : 0xFFFFFFFF; + handle->Set(); + return 0; } NTSTATUS NtCreateSemaphore(XLPDWORD Handle, XOBJECT_ATTRIBUTES* ObjectAttributes, DWORD InitialCount, DWORD MaximumCount) @@ -1431,19 +1466,41 @@ void NetDll_XNetGetTitleXnAddr() DWORD KeWaitForMultipleObjects(DWORD Count, xpointer* Objects, DWORD WaitType, DWORD WaitReason, DWORD WaitMode, DWORD Alertable, XLPQWORD Timeout) { - // TODO: create actual objects by type. + // FIXME: This function is only accounting for events. + const uint64_t timeout = GuestTimeoutToMilliseconds(Timeout); + assert(timeout == INFINITE); - thread_local std::vector events; - events.resize(Count); - - for (size_t i = 0; i < Count; i++) + if (WaitType == 0) // Wait all { - assert(Objects[i]->Type <= 1); - events[i] = QueryKernelObject(*Objects[i].get())->handle; + for (size_t i = 0; i < Count; i++) + QueryKernelObject(*Objects[i])->Wait(timeout); + } + else + { + thread_local std::vector s_events; + s_events.resize(Count); + + for (size_t i = 0; i < Count; i++) + s_events[i] = QueryKernelObject(*Objects[i]); + + while (true) + { + uint32_t generation = g_keSetEventGeneration.load(); + + for (size_t i = 0; i < Count; i++) + { + if (s_events[i]->Wait(0) == STATUS_SUCCESS) + { + return WAIT_OBJECT_0 + i; + } + } + + g_keSetEventGeneration.wait(generation); + } } - return WaitForMultipleObjectsEx(Count, events.data(), WaitType == 0, timeout, Alertable); + return STATUS_SUCCESS; } uint32_t KeRaiseIrqlToDpcLevel()