Cross-platform semaphore implementation.

This commit is contained in:
Skyth 2024-12-15 22:31:04 +03:00
parent afce26fc35
commit a8163b1e29
4 changed files with 86 additions and 26 deletions

View file

@ -59,12 +59,14 @@ GuestThreadHandle::~GuestThreadHandle()
thread.join(); thread.join();
} }
void GuestThreadHandle::Wait(uint32_t timeout) uint32_t GuestThreadHandle::Wait(uint32_t timeout)
{ {
assert(timeout == INFINITE); assert(timeout == INFINITE);
if (thread.joinable()) if (thread.joinable())
thread.join(); thread.join();
return STATUS_WAIT_0;
} }
uint32_t GuestThread::Start(const GuestThreadParams& params) uint32_t GuestThread::Start(const GuestThreadParams& params)

View file

@ -29,7 +29,7 @@ struct GuestThreadHandle : KernelObject
GuestThreadHandle(const GuestThreadParams& params); GuestThreadHandle(const GuestThreadParams& params);
~GuestThreadHandle() override; ~GuestThreadHandle() override;
void Wait(uint32_t timeout) override; uint32_t Wait(uint32_t timeout) override;
}; };
struct GuestThread struct GuestThread

View file

@ -261,7 +261,7 @@ DWORD NtWaitForSingleObjectEx(DWORD Handle, DWORD WaitMode, DWORD Alertable, XLP
if (IsKernelObject(Handle)) if (IsKernelObject(Handle))
{ {
GetKernelObject(Handle)->Wait(timeout); return GetKernelObject(Handle)->Wait(timeout);
} }
else else
{ {
@ -271,14 +271,14 @@ DWORD NtWaitForSingleObjectEx(DWORD Handle, DWORD WaitMode, DWORD Alertable, XLP
{ {
return STATUS_USER_APC; return STATUS_USER_APC;
} }
else if (status) else if (status == WAIT_TIMEOUT)
{ {
return STATUS_ALERTED; return STATUS_TIMEOUT;
}
} }
return STATUS_SUCCESS; return STATUS_SUCCESS;
} }
}
void NtWriteFile() void NtWriteFile()
{ {
@ -529,7 +529,7 @@ uint32_t KeSetAffinityThread(DWORD Thread, DWORD Affinity, XLPDWORD lpPreviousAf
return 0; return 0;
} }
struct Event : KernelObject, HostObject<XKEVENT> struct Event final : KernelObject, HostObject<XKEVENT>
{ {
HANDLE handle; HANDLE handle;
@ -549,13 +549,69 @@ struct Event : KernelObject, HostObject<XKEVENT>
} }
}; };
struct Semaphore : KernelObject, HostObject<XKSEMAPHORE> struct Semaphore final : KernelObject, HostObject<XKSEMAPHORE>
{ {
HANDLE handle; std::atomic<uint32_t> count;
uint32_t maximumCount;
Semaphore(XKSEMAPHORE* semaphore) Semaphore(XKSEMAPHORE* semaphore)
: count(semaphore->Header.SignalState), maximumCount(semaphore->Limit)
{ {
handle = CreateSemaphoreA(nullptr, semaphore->Header.SignalState, semaphore->Limit, nullptr); }
Semaphore(uint32_t count, uint32_t maximumCount)
: count(count), maximumCount(maximumCount)
{
}
uint32_t Wait(uint32_t timeout) override
{
if (timeout == 0)
{
uint32_t currentCount = count.load();
if (currentCount != 0)
{
if (count.compare_exchange_weak(currentCount, currentCount - 1))
return STATUS_SUCCESS;
}
return STATUS_TIMEOUT;
}
else if (timeout == INFINITE)
{
uint32_t currentCount;
while (true)
{
currentCount = count.load();
if (currentCount != 0)
{
if (count.compare_exchange_weak(currentCount, currentCount - 1))
return STATUS_SUCCESS;
}
else
{
count.wait(0);
}
}
return STATUS_SUCCESS;
}
else
{
assert(false && "Unhandled timeout value.");
return STATUS_TIMEOUT;
}
}
void Release(uint32_t releaseCount, uint32_t* previousCount)
{
if (previousCount != nullptr)
*previousCount = count;
assert(count + releaseCount <= maximumCount);
count += releaseCount;
count.notify_all();
} }
}; };
@ -876,27 +932,26 @@ bool KeResetEvent(XKEVENT* pEvent)
DWORD KeWaitForSingleObject(XDISPATCHER_HEADER* Object, DWORD WaitReason, DWORD WaitMode, bool Alertable, XLPQWORD Timeout) DWORD KeWaitForSingleObject(XDISPATCHER_HEADER* Object, DWORD WaitReason, DWORD WaitMode, bool Alertable, XLPQWORD Timeout)
{ {
const uint64_t timeout = GuestTimeoutToMilliseconds(Timeout); const uint32_t timeout = GuestTimeoutToMilliseconds(Timeout);
assert(timeout == INFINITE);
HANDLE handle = nullptr;
switch (Object->Type) switch (Object->Type)
{ {
case 0: case 0:
case 1: case 1:
handle = QueryKernelObject<Event>(*Object)->handle; WaitForSingleObjectEx(QueryKernelObject<Event>(*Object)->handle, timeout, Alertable);
break; break;
case 5: case 5:
handle = QueryKernelObject<Semaphore>(*Object)->handle; QueryKernelObject<Semaphore>(*Object)->Wait(timeout);
break; break;
default: default:
assert(false); assert(false && "Unrecognized kernel object type.");
break; return STATUS_TIMEOUT;
} }
return WaitForSingleObjectEx(handle, timeout, Alertable); return STATUS_SUCCESS;
} }
static std::vector<size_t> g_tlsFreeIndices; static std::vector<size_t> g_tlsFreeIndices;
@ -1198,16 +1253,17 @@ uint32_t NtSetEvent(uint32_t handle, uint32_t* previousState)
NTSTATUS NtCreateSemaphore(XLPDWORD Handle, XOBJECT_ATTRIBUTES* ObjectAttributes, DWORD InitialCount, DWORD MaximumCount) NTSTATUS NtCreateSemaphore(XLPDWORD Handle, XOBJECT_ATTRIBUTES* ObjectAttributes, DWORD InitialCount, DWORD MaximumCount)
{ {
*Handle = (uint32_t)CreateSemaphoreA(nullptr, InitialCount, MaximumCount, nullptr); *Handle = GetKernelHandle(CreateKernelObject<Semaphore>(InitialCount, MaximumCount));
return STATUS_SUCCESS; return STATUS_SUCCESS;
} }
NTSTATUS NtReleaseSemaphore(uint32_t Handle, DWORD ReleaseCount, LONG* PreviousCount) NTSTATUS NtReleaseSemaphore(Semaphore* Handle, DWORD ReleaseCount, LONG* PreviousCount)
{ {
ReleaseSemaphore((HANDLE)Handle, ReleaseCount, PreviousCount); uint32_t previousCount;
Handle->Release(ReleaseCount, &previousCount);
if (PreviousCount) if (PreviousCount != nullptr)
*PreviousCount = ByteSwap(*PreviousCount); *PreviousCount = ByteSwap(previousCount);
return STATUS_SUCCESS; return STATUS_SUCCESS;
} }
@ -1386,7 +1442,8 @@ void KfLowerIrql() { }
uint32_t KeReleaseSemaphore(XKSEMAPHORE* semaphore, uint32_t increment, uint32_t adjustment, uint32_t wait) uint32_t KeReleaseSemaphore(XKSEMAPHORE* semaphore, uint32_t increment, uint32_t adjustment, uint32_t wait)
{ {
auto* object = QueryKernelObject<Semaphore>(semaphore->Header); auto* object = QueryKernelObject<Semaphore>(semaphore->Header);
return ReleaseSemaphore(object->handle, adjustment, nullptr) ? 0 : 0xFFFFFFFF; object->Release(adjustment, nullptr);
return STATUS_SUCCESS;
} }
void XAudioGetVoiceCategoryVolume() void XAudioGetVoiceCategoryVolume()

View file

@ -12,9 +12,10 @@ struct KernelObject
{ {
} }
virtual void Wait(uint32_t timeout) virtual uint32_t Wait(uint32_t timeout)
{ {
assert(false && "Wait not implemented for this kernel object."); assert(false && "Wait not implemented for this kernel object.");
return STATUS_TIMEOUT;
} }
}; };