Cross-platform event implementation.

This commit is contained in:
Skyth 2024-12-16 13:14:45 +03:00
parent 4770e85573
commit f4453fece6

View file

@ -16,6 +16,129 @@
#include <ntstatus.h>
struct Event final : KernelObject, HostObject<XKEVENT>
{
bool manualReset;
std::atomic<bool> signaled;
Event(XKEVENT* header)
: manualReset(!header->Type), signaled(!!header->SignalState)
{
}
Event(bool manualReset, bool initialState)
: manualReset(manualReset), signaled(initialState)
{
}
uint32_t Wait(uint32_t timeout) override
{
if (timeout == 0)
{
if (!signaled)
return STATUS_TIMEOUT;
if (!manualReset)
signaled = false;
}
else if (timeout == INFINITE)
{
signaled.wait(false);
if (!manualReset)
signaled = false;
}
else
{
assert(false && "Unhandled timeout value.");
}
return STATUS_SUCCESS;
}
bool Set()
{
signaled = true;
signaled.notify_all();
return TRUE;
}
bool Reset()
{
signaled = false;
return TRUE;
}
};
static std::atomic<uint32_t> g_keSetEventGeneration;
struct Semaphore final : KernelObject, HostObject<XKSEMAPHORE>
{
std::atomic<uint32_t> count;
uint32_t maximumCount;
Semaphore(XKSEMAPHORE* semaphore)
: count(semaphore->Header.SignalState), maximumCount(semaphore->Limit)
{
}
Semaphore(uint32_t count, uint32_t maximumCount)
: count(count), maximumCount(maximumCount)
{
}
uint32_t Wait(uint32_t timeout) override
{
if (timeout == 0)
{
uint32_t currentCount = count.load();
if (currentCount != 0)
{
if (count.compare_exchange_weak(currentCount, currentCount - 1))
return STATUS_SUCCESS;
}
return STATUS_TIMEOUT;
}
else if (timeout == INFINITE)
{
uint32_t currentCount;
while (true)
{
currentCount = count.load();
if (currentCount != 0)
{
if (count.compare_exchange_weak(currentCount, currentCount - 1))
return STATUS_SUCCESS;
}
else
{
count.wait(0);
}
}
return STATUS_SUCCESS;
}
else
{
assert(false && "Unhandled timeout value.");
return STATUS_TIMEOUT;
}
}
void Release(uint32_t releaseCount, uint32_t* previousCount)
{
if (previousCount != nullptr)
*previousCount = count;
assert(count + releaseCount <= maximumCount);
count += releaseCount;
count.notify_all();
}
};
inline void CloseKernelObject(XDISPATCHER_HEADER& header)
{
if (header.WaitListHead.Flink != OBJECT_SIGNATURE)
@ -265,19 +388,10 @@ DWORD NtWaitForSingleObjectEx(DWORD Handle, DWORD WaitMode, DWORD Alertable, XLP
}
else
{
const auto status = WaitForSingleObjectEx((HANDLE)Handle, timeout, Alertable);
if (status == WAIT_IO_COMPLETION)
{
return STATUS_USER_APC;
}
else if (status == WAIT_TIMEOUT)
{
return STATUS_TIMEOUT;
}
return STATUS_SUCCESS;
assert(false && "Unrecognized handle value.");
}
return STATUS_TIMEOUT;
}
void NtWriteFile()
@ -367,9 +481,9 @@ void MmQueryStatistics()
LOG_UTILITY("!!! STUB !!!");
}
uint32_t NtCreateEvent(uint32_t* handle, void* objAttributes, uint32_t eventType, uint32_t initialState)
uint32_t NtCreateEvent(be<uint32_t>* handle, void* objAttributes, uint32_t eventType, uint32_t initialState)
{
*handle = ByteSwap((uint32_t)CreateEventA(nullptr, !eventType, !!initialState, nullptr));
*handle = GetKernelHandle(CreateKernelObject<Event>(!eventType, !!initialState));
return 0;
}
@ -529,92 +643,6 @@ uint32_t KeSetAffinityThread(DWORD Thread, DWORD Affinity, XLPDWORD lpPreviousAf
return 0;
}
struct Event final : KernelObject, HostObject<XKEVENT>
{
HANDLE handle;
Event(XKEVENT* header)
{
handle = CreateEventA(nullptr, !header->Type, !!header->SignalState, nullptr);
}
bool Set()
{
return SetEvent(handle);
}
bool Reset()
{
return ResetEvent(handle);
}
};
struct Semaphore final : KernelObject, HostObject<XKSEMAPHORE>
{
std::atomic<uint32_t> count;
uint32_t maximumCount;
Semaphore(XKSEMAPHORE* semaphore)
: count(semaphore->Header.SignalState), maximumCount(semaphore->Limit)
{
}
Semaphore(uint32_t count, uint32_t maximumCount)
: count(count), maximumCount(maximumCount)
{
}
uint32_t Wait(uint32_t timeout) override
{
if (timeout == 0)
{
uint32_t currentCount = count.load();
if (currentCount != 0)
{
if (count.compare_exchange_weak(currentCount, currentCount - 1))
return STATUS_SUCCESS;
}
return STATUS_TIMEOUT;
}
else if (timeout == INFINITE)
{
uint32_t currentCount;
while (true)
{
currentCount = count.load();
if (currentCount != 0)
{
if (count.compare_exchange_weak(currentCount, currentCount - 1))
return STATUS_SUCCESS;
}
else
{
count.wait(0);
}
}
return STATUS_SUCCESS;
}
else
{
assert(false && "Unhandled timeout value.");
return STATUS_TIMEOUT;
}
}
void Release(uint32_t releaseCount, uint32_t* previousCount)
{
if (previousCount != nullptr)
*previousCount = count;
assert(count + releaseCount <= maximumCount);
count += releaseCount;
count.notify_all();
}
};
void RtlLeaveCriticalSection(XRTL_CRITICAL_SECTION* cs)
{
cs->RecursionCount--;
@ -936,7 +964,12 @@ void KeUnlockL2()
bool KeSetEvent(XKEVENT* pEvent, DWORD Increment, bool Wait)
{
return QueryKernelObject<Event>(*pEvent)->Set();
bool result = QueryKernelObject<Event>(*pEvent)->Set();
++g_keSetEventGeneration;
g_keSetEventGeneration.notify_all();
return result;
}
bool KeResetEvent(XKEVENT* pEvent)
@ -953,7 +986,7 @@ DWORD KeWaitForSingleObject(XDISPATCHER_HEADER* Object, DWORD WaitReason, DWORD
{
case 0:
case 1:
WaitForSingleObjectEx(QueryKernelObject<Event>(*Object)->handle, timeout, Alertable);
QueryKernelObject<Event>(*Object)->Wait(timeout);
break;
case 5:
@ -1245,9 +1278,10 @@ void MmQueryAllocationSize()
LOG_UTILITY("!!! STUB !!!");
}
uint32_t NtClearEvent(uint32_t handle, uint32_t* previousState)
uint32_t NtClearEvent(Event* handle, uint32_t* previousState)
{
return ResetEvent((HANDLE)handle) ? 0 : 0xFFFFFFFF;
handle->Reset();
return 0;
}
uint32_t NtResumeThread(GuestThreadHandle* hThread, uint32_t* suspendCount)
@ -1260,9 +1294,10 @@ uint32_t NtResumeThread(GuestThreadHandle* hThread, uint32_t* suspendCount)
return S_OK;
}
uint32_t NtSetEvent(uint32_t handle, uint32_t* previousState)
uint32_t NtSetEvent(Event* handle, uint32_t* previousState)
{
return SetEvent((HANDLE)handle) ? 0 : 0xFFFFFFFF;
handle->Set();
return 0;
}
NTSTATUS NtCreateSemaphore(XLPDWORD Handle, XOBJECT_ATTRIBUTES* ObjectAttributes, DWORD InitialCount, DWORD MaximumCount)
@ -1431,19 +1466,41 @@ void NetDll_XNetGetTitleXnAddr()
DWORD KeWaitForMultipleObjects(DWORD Count, xpointer<XDISPATCHER_HEADER>* Objects, DWORD WaitType, DWORD WaitReason, DWORD WaitMode, DWORD Alertable, XLPQWORD Timeout)
{
// TODO: create actual objects by type.
// FIXME: This function is only accounting for events.
const uint64_t timeout = GuestTimeoutToMilliseconds(Timeout);
assert(timeout == INFINITE);
thread_local std::vector<HANDLE> events;
events.resize(Count);
for (size_t i = 0; i < Count; i++)
if (WaitType == 0) // Wait all
{
assert(Objects[i]->Type <= 1);
events[i] = QueryKernelObject<Event>(*Objects[i].get())->handle;
for (size_t i = 0; i < Count; i++)
QueryKernelObject<Event>(*Objects[i])->Wait(timeout);
}
else
{
thread_local std::vector<Event*> s_events;
s_events.resize(Count);
for (size_t i = 0; i < Count; i++)
s_events[i] = QueryKernelObject<Event>(*Objects[i]);
while (true)
{
uint32_t generation = g_keSetEventGeneration.load();
for (size_t i = 0; i < Count; i++)
{
if (s_events[i]->Wait(0) == STATUS_SUCCESS)
{
return WAIT_OBJECT_0 + i;
}
}
g_keSetEventGeneration.wait(generation);
}
}
return WaitForMultipleObjectsEx(Count, events.data(), WaitType == 0, timeout, Alertable);
return STATUS_SUCCESS;
}
uint32_t KeRaiseIrqlToDpcLevel()