Refactor kernel objects to be lock-free.

This commit is contained in:
Skyth 2024-12-14 15:40:13 +03:00
parent e89012127d
commit 7f45cb942d
7 changed files with 230 additions and 280 deletions

View file

@ -19,7 +19,7 @@ struct Heap
size_t Size(void* ptr); size_t Size(void* ptr);
template<typename T, typename... Args> template<typename T, typename... Args>
T* Alloc(Args... args) T* Alloc(Args&&... args)
{ {
T* obj = (T*)Alloc(sizeof(T)); T* obj = (T*)Alloc(sizeof(T));
new (obj) T(std::forward<Args>(args)...); new (obj) T(std::forward<Args>(args)...);
@ -27,7 +27,7 @@ struct Heap
} }
template<typename T, typename... Args> template<typename T, typename... Args>
T* AllocPhysical(Args... args) T* AllocPhysical(Args&&... args)
{ {
T* obj = (T*)AllocPhysical(sizeof(T), alignof(T)); T* obj = (T*)AllocPhysical(sizeof(T), alignof(T));
new (obj) T(std::forward<Args>(args)...); new (obj) T(std::forward<Args>(args)...);

View file

@ -23,7 +23,7 @@ inline void CloseKernelObject(XDISPATCHER_HEADER& header)
return; return;
} }
ObCloseHandle(header.WaitListHead.Blink); DestroyKernelObject(header.WaitListHead.Blink);
} }
DWORD GuestTimeoutToMilliseconds(XLPQWORD timeout) DWORD GuestTimeoutToMilliseconds(XLPQWORD timeout)
@ -232,12 +232,12 @@ DWORD NtCreateFile
uint32_t NtClose(uint32_t handle) uint32_t NtClose(uint32_t handle)
{ {
if (handle == (uint32_t)INVALID_HANDLE_VALUE) if (handle == GUEST_INVALID_HANDLE_VALUE)
return 0xFFFFFFFF; return 0xFFFFFFFF;
if (CHECK_GUEST_HANDLE(handle)) if (IsKernelObject(handle))
{ {
ObCloseHandle(HOST_HANDLE(handle)); DestroyKernelObject(handle);
return 0; return 0;
} }
@ -531,7 +531,7 @@ uint32_t KeSetAffinityThread(DWORD Thread, DWORD Affinity, XLPDWORD lpPreviousAf
return 0; return 0;
} }
struct Event : HostObject<XKEVENT> struct Event : KernelObject, HostObject<XKEVENT>
{ {
HANDLE handle; HANDLE handle;
@ -551,7 +551,7 @@ struct Event : HostObject<XKEVENT>
} }
}; };
struct Semaphore : HostObject<XKSEMAPHORE> struct Semaphore : KernelObject, HostObject<XKSEMAPHORE>
{ {
HANDLE handle; HANDLE handle;
@ -866,12 +866,12 @@ void KeUnlockL2()
bool KeSetEvent(XKEVENT* pEvent, DWORD Increment, bool Wait) bool KeSetEvent(XKEVENT* pEvent, DWORD Increment, bool Wait)
{ {
return ObQueryObject<Event>(*pEvent)->Set(); return QueryKernelObject<Event>(*pEvent)->Set();
} }
bool KeResetEvent(XKEVENT* pEvent) bool KeResetEvent(XKEVENT* pEvent)
{ {
return ObQueryObject<Event>(*pEvent)->Reset(); return QueryKernelObject<Event>(*pEvent)->Reset();
} }
DWORD KeWaitForSingleObject(XDISPATCHER_HEADER* Object, DWORD WaitReason, DWORD WaitMode, bool Alertable, XLPQWORD Timeout) DWORD KeWaitForSingleObject(XDISPATCHER_HEADER* Object, DWORD WaitReason, DWORD WaitMode, bool Alertable, XLPQWORD Timeout)
@ -884,11 +884,11 @@ DWORD KeWaitForSingleObject(XDISPATCHER_HEADER* Object, DWORD WaitReason, DWORD
{ {
case 0: case 0:
case 1: case 1:
handle = ObQueryObject<Event>(*Object)->handle; handle = QueryKernelObject<Event>(*Object)->handle;
break; break;
case 5: case 5:
handle = ObQueryObject<Semaphore>(*Object)->handle; handle = QueryKernelObject<Semaphore>(*Object)->handle;
break; break;
default: default:
@ -1340,7 +1340,7 @@ DWORD KeWaitForMultipleObjects(DWORD Count, xpointer<XDISPATCHER_HEADER>* Object
for (size_t i = 0; i < Count; i++) for (size_t i = 0; i < Count; i++)
{ {
assert(Objects[i]->Type <= 1); assert(Objects[i]->Type <= 1);
events[i] = ObQueryObject<Event>(*Objects[i].get())->handle; events[i] = QueryKernelObject<Event>(*Objects[i].get())->handle;
} }
return WaitForMultipleObjectsEx(Count, events.data(), WaitType == 0, timeout, Alertable); return WaitForMultipleObjectsEx(Count, events.data(), WaitType == 0, timeout, Alertable);
@ -1355,7 +1355,7 @@ void KfLowerIrql() { }
uint32_t KeReleaseSemaphore(XKSEMAPHORE* semaphore, uint32_t increment, uint32_t adjustment, uint32_t wait) uint32_t KeReleaseSemaphore(XKSEMAPHORE* semaphore, uint32_t increment, uint32_t adjustment, uint32_t wait)
{ {
auto* object = ObQueryObject<Semaphore>(semaphore->Header); auto* object = QueryKernelObject<Semaphore>(semaphore->Header);
return ReleaseSemaphore(object->handle, adjustment, nullptr) ? 0 : 0xFFFFFFFF; return ReleaseSemaphore(object->handle, adjustment, nullptr) ? 0 : 0xFFFFFFFF;
} }
@ -1382,7 +1382,7 @@ void KeInitializeSemaphore(XKSEMAPHORE* semaphore, uint32_t count, uint32_t limi
semaphore->Header.SignalState = count; semaphore->Header.SignalState = count;
semaphore->Limit = limit; semaphore->Limit = limit;
auto* object = ObQueryObject<Semaphore>(semaphore->Header); auto* object = QueryKernelObject<Semaphore>(semaphore->Header);
} }
void XMAReleaseContext() void XMAReleaseContext()

View file

@ -6,15 +6,13 @@
#include <cpu/guest_thread.h> #include <cpu/guest_thread.h>
#include <os/logger.h> #include <os/logger.h>
static constexpr uint32_t GUEST_INVALID_HANDLE_VALUE = 0xFFFFFFFF; struct FileHandle : KernelObject
struct FileHandle
{ {
std::fstream stream; std::fstream stream;
std::filesystem::path path; std::filesystem::path path;
}; };
struct FindHandle struct FindHandle : KernelObject
{ {
std::filesystem::path searchPath; std::filesystem::path searchPath;
std::filesystem::directory_iterator iterator; std::filesystem::directory_iterator iterator;
@ -37,19 +35,7 @@ struct FindHandle
} }
}; };
bool FileHandleCloser(void* handle) SWA_API FileHandle* XCreateFileA
{
delete (FileHandle *)(handle);
return false;
}
bool FindHandleCloser(void* handle)
{
delete (FindHandle *)(handle);
return false;
}
SWA_API uint32_t XCreateFileA
( (
LPCSTR lpFileName, LPCSTR lpFileName,
DWORD dwDesiredAccess, DWORD dwDesiredAccess,
@ -82,89 +68,75 @@ SWA_API uint32_t XCreateFileA
#ifdef _WIN32 #ifdef _WIN32
GuestThread::SetLastError(GetLastError()); GuestThread::SetLastError(GetLastError());
#endif #endif
return GUEST_INVALID_HANDLE_VALUE; return GetInvalidKernelObject<FileHandle>();
} }
FileHandle *fileHandle = new FileHandle(); FileHandle *fileHandle = CreateKernelObject<FileHandle>();
fileHandle->stream = std::move(fileStream); fileHandle->stream = std::move(fileStream);
fileHandle->path = std::move(filePath); fileHandle->path = std::move(filePath);
return GUEST_HANDLE(ObInsertObject(fileHandle, FileHandleCloser)); return fileHandle;
} }
static DWORD XGetFileSizeA(uint32_t hFile, LPDWORD lpFileSizeHigh) static DWORD XGetFileSizeA(FileHandle* hFile, LPDWORD lpFileSizeHigh)
{ {
FileHandle *handle = (FileHandle *)(ObQueryObject(HOST_HANDLE(hFile))); std::error_code ec;
if (handle != nullptr) auto fileSize = std::filesystem::file_size(hFile->path, ec);
if (!ec)
{ {
std::error_code ec; if (lpFileSizeHigh != nullptr)
auto fileSize = std::filesystem::file_size(handle->path, ec);
if (!ec)
{ {
if (lpFileSizeHigh != nullptr) *lpFileSizeHigh = ByteSwap(DWORD(fileSize >> 32U));
{
*lpFileSizeHigh = ByteSwap(DWORD(fileSize >> 32U));
}
return (DWORD)(fileSize);
} }
return (DWORD)(fileSize);
} }
return INVALID_FILE_SIZE; return INVALID_FILE_SIZE;
} }
BOOL XGetFileSizeExA(uint32_t hFile, PLARGE_INTEGER lpFileSize) BOOL XGetFileSizeExA(FileHandle* hFile, PLARGE_INTEGER lpFileSize)
{ {
FileHandle *handle = (FileHandle *)(ObQueryObject(HOST_HANDLE(hFile))); std::error_code ec;
if (handle != nullptr) auto fileSize = std::filesystem::file_size(hFile->path, ec);
if (!ec)
{ {
std::error_code ec; if (lpFileSize != nullptr)
auto fileSize = std::filesystem::file_size(handle->path, ec);
if (!ec)
{ {
if (lpFileSize != nullptr) lpFileSize->QuadPart = ByteSwap(fileSize);
{
lpFileSize->QuadPart = ByteSwap(fileSize);
}
return 1;
} }
return TRUE;
} }
return 0; return FALSE;
} }
BOOL XReadFile BOOL XReadFile
( (
uint32_t hFile, FileHandle* hFile,
LPVOID lpBuffer, LPVOID lpBuffer,
DWORD nNumberOfBytesToRead, DWORD nNumberOfBytesToRead,
XLPDWORD lpNumberOfBytesRead, XLPDWORD lpNumberOfBytesRead,
XOVERLAPPED* lpOverlapped XOVERLAPPED* lpOverlapped
) )
{ {
FileHandle *handle = (FileHandle *)(ObQueryObject(HOST_HANDLE(hFile)));
if (handle == nullptr)
{
return FALSE;
}
BOOL result = FALSE; BOOL result = FALSE;
if (lpOverlapped != nullptr) if (lpOverlapped != nullptr)
{ {
std::streamoff streamOffset = lpOverlapped->Offset + (std::streamoff(lpOverlapped->OffsetHigh.get()) << 32U); std::streamoff streamOffset = lpOverlapped->Offset + (std::streamoff(lpOverlapped->OffsetHigh.get()) << 32U);
handle->stream.clear(); hFile->stream.clear();
handle->stream.seekg(streamOffset, std::ios::beg); hFile->stream.seekg(streamOffset, std::ios::beg);
if (handle->stream.bad()) if (hFile->stream.bad())
{ {
return FALSE; return FALSE;
} }
} }
DWORD numberOfBytesRead; DWORD numberOfBytesRead;
handle->stream.read((char *)(lpBuffer), nNumberOfBytesToRead); hFile->stream.read((char *)(lpBuffer), nNumberOfBytesToRead);
if (!handle->stream.bad()) if (!hFile->stream.bad())
{ {
numberOfBytesRead = DWORD(handle->stream.gcount()); numberOfBytesRead = DWORD(hFile->stream.gcount());
result = TRUE; result = TRUE;
} }
@ -187,14 +159,8 @@ BOOL XReadFile
return result; return result;
} }
DWORD XSetFilePointer(uint32_t hFile, LONG lDistanceToMove, PLONG lpDistanceToMoveHigh, DWORD dwMoveMethod) DWORD XSetFilePointer(FileHandle* hFile, LONG lDistanceToMove, PLONG lpDistanceToMoveHigh, DWORD dwMoveMethod)
{ {
FileHandle *handle = (FileHandle *)(ObQueryObject(HOST_HANDLE(hFile)));
if (handle == nullptr)
{
return INVALID_SET_FILE_POINTER;
}
LONG distanceToMoveHigh = lpDistanceToMoveHigh ? ByteSwap(*lpDistanceToMoveHigh) : 0; LONG distanceToMoveHigh = lpDistanceToMoveHigh ? ByteSwap(*lpDistanceToMoveHigh) : 0;
std::streamoff streamOffset = lDistanceToMove + (std::streamoff(distanceToMoveHigh) << 32U); std::streamoff streamOffset = lDistanceToMove + (std::streamoff(distanceToMoveHigh) << 32U);
std::fstream::seekdir streamSeekDir = {}; std::fstream::seekdir streamSeekDir = {};
@ -214,28 +180,22 @@ DWORD XSetFilePointer(uint32_t hFile, LONG lDistanceToMove, PLONG lpDistanceToMo
break; break;
} }
handle->stream.clear(); hFile->stream.clear();
handle->stream.seekg(streamOffset, streamSeekDir); hFile->stream.seekg(streamOffset, streamSeekDir);
if (handle->stream.bad()) if (hFile->stream.bad())
{ {
return INVALID_SET_FILE_POINTER; return INVALID_SET_FILE_POINTER;
} }
std::streampos streamPos = handle->stream.tellg(); std::streampos streamPos = hFile->stream.tellg();
if (lpDistanceToMoveHigh != nullptr) if (lpDistanceToMoveHigh != nullptr)
*lpDistanceToMoveHigh = ByteSwap(LONG(streamPos >> 32U)); *lpDistanceToMoveHigh = ByteSwap(LONG(streamPos >> 32U));
return DWORD(streamPos); return DWORD(streamPos);
} }
BOOL XSetFilePointerEx(uint32_t hFile, LONG lDistanceToMove, PLARGE_INTEGER lpNewFilePointer, DWORD dwMoveMethod) BOOL XSetFilePointerEx(FileHandle* hFile, LONG lDistanceToMove, PLARGE_INTEGER lpNewFilePointer, DWORD dwMoveMethod)
{ {
FileHandle *handle = (FileHandle *)(ObQueryObject(HOST_HANDLE(hFile)));
if (handle == nullptr)
{
return FALSE;
}
std::fstream::seekdir streamSeekDir = {}; std::fstream::seekdir streamSeekDir = {};
switch (dwMoveMethod) switch (dwMoveMethod)
{ {
@ -253,27 +213,27 @@ BOOL XSetFilePointerEx(uint32_t hFile, LONG lDistanceToMove, PLARGE_INTEGER lpNe
break; break;
} }
handle->stream.clear(); hFile->stream.clear();
handle->stream.seekg(lDistanceToMove, streamSeekDir); hFile->stream.seekg(lDistanceToMove, streamSeekDir);
if (handle->stream.bad()) if (hFile->stream.bad())
{ {
return FALSE; return FALSE;
} }
if (lpNewFilePointer != nullptr) if (lpNewFilePointer != nullptr)
{ {
lpNewFilePointer->QuadPart = ByteSwap(LONGLONG(handle->stream.tellg())); lpNewFilePointer->QuadPart = ByteSwap(LONGLONG(hFile->stream.tellg()));
} }
return TRUE; return TRUE;
} }
uint32_t XFindFirstFileA(LPCSTR lpFileName, LPWIN32_FIND_DATAA lpFindFileData) FindHandle* XFindFirstFileA(LPCSTR lpFileName, LPWIN32_FIND_DATAA lpFindFileData)
{ {
const char *transformedPath = FileSystem::TransformPath(lpFileName); const char *transformedPath = FileSystem::TransformPath(lpFileName);
size_t transformedPathLength = strlen(transformedPath); size_t transformedPathLength = strlen(transformedPath);
if (transformedPathLength == 0) if (transformedPathLength == 0)
return GUEST_INVALID_HANDLE_VALUE; return (FindHandle*)GUEST_INVALID_HANDLE_VALUE;
std::filesystem::path dirPath; std::filesystem::path dirPath;
if (strstr(transformedPath, "\\*") == (&transformedPath[transformedPathLength - 2])) if (strstr(transformedPath, "\\*") == (&transformedPath[transformedPathLength - 2]))
@ -291,56 +251,48 @@ uint32_t XFindFirstFileA(LPCSTR lpFileName, LPWIN32_FIND_DATAA lpFindFileData)
} }
if (!std::filesystem::is_directory(dirPath)) if (!std::filesystem::is_directory(dirPath))
return GUEST_INVALID_HANDLE_VALUE; return GetInvalidKernelObject<FindHandle>();
std::filesystem::directory_iterator dirIterator(dirPath); std::filesystem::directory_iterator dirIterator(dirPath);
if (dirIterator == std::filesystem::directory_iterator()) if (dirIterator == std::filesystem::directory_iterator())
return GUEST_INVALID_HANDLE_VALUE; return GetInvalidKernelObject<FindHandle>();
FindHandle *findHandle = new FindHandle(); FindHandle *findHandle = CreateKernelObject<FindHandle>();
findHandle->searchPath = std::move(dirPath); findHandle->searchPath = std::move(dirPath);
findHandle->iterator = std::move(dirIterator); findHandle->iterator = std::move(dirIterator);
findHandle->fillFindData(lpFindFileData); findHandle->fillFindData(lpFindFileData);
return GUEST_HANDLE(ObInsertObject(findHandle, FindHandleCloser)); return findHandle;
} }
uint32_t XFindNextFileA(uint32_t Handle, LPWIN32_FIND_DATAA lpFindFileData) BOOL XFindNextFileA(FindHandle* Handle, LPWIN32_FIND_DATAA lpFindFileData)
{ {
FindHandle *findHandle = (FindHandle *)(ObQueryObject(HOST_HANDLE(Handle))); Handle->iterator++;
if (findHandle == nullptr)
return FALSE;
findHandle->iterator++; if (Handle->iterator == std::filesystem::directory_iterator())
if (findHandle->iterator == std::filesystem::directory_iterator())
{ {
return FALSE; return FALSE;
} }
else else
{ {
findHandle->fillFindData(lpFindFileData); Handle->fillFindData(lpFindFileData);
return TRUE; return TRUE;
} }
} }
BOOL XReadFileEx(uint32_t hFile, LPVOID lpBuffer, DWORD nNumberOfBytesToRead, XOVERLAPPED* lpOverlapped, uint32_t lpCompletionRoutine) BOOL XReadFileEx(FileHandle* hFile, LPVOID lpBuffer, DWORD nNumberOfBytesToRead, XOVERLAPPED* lpOverlapped, uint32_t lpCompletionRoutine)
{ {
FileHandle *handle = (FileHandle *)(ObQueryObject(HOST_HANDLE(hFile)));
if (handle == nullptr)
return FALSE;
BOOL result = FALSE; BOOL result = FALSE;
DWORD numberOfBytesRead; DWORD numberOfBytesRead;
std::streamoff streamOffset = lpOverlapped->Offset + (std::streamoff(lpOverlapped->OffsetHigh.get()) << 32U); std::streamoff streamOffset = lpOverlapped->Offset + (std::streamoff(lpOverlapped->OffsetHigh.get()) << 32U);
handle->stream.clear(); hFile->stream.clear();
handle->stream.seekg(streamOffset, std::ios::beg); hFile->stream.seekg(streamOffset, std::ios::beg);
if (handle->stream.bad()) if (hFile->stream.bad())
return FALSE; return FALSE;
handle->stream.read((char *)(lpBuffer), nNumberOfBytesToRead); hFile->stream.read((char *)(lpBuffer), nNumberOfBytesToRead);
if (!handle->stream.bad()) if (!hFile->stream.bad())
{ {
numberOfBytesRead = DWORD(handle->stream.gcount()); numberOfBytesRead = DWORD(hFile->stream.gcount());
result = TRUE; result = TRUE;
} }
@ -367,20 +319,16 @@ DWORD XGetFileAttributesA(LPCSTR lpFileName)
return INVALID_FILE_ATTRIBUTES; return INVALID_FILE_ATTRIBUTES;
} }
BOOL XWriteFile(uint32_t hFile, LPCVOID lpBuffer, DWORD nNumberOfBytesToWrite, LPDWORD lpNumberOfBytesWritten, LPOVERLAPPED lpOverlapped) BOOL XWriteFile(FileHandle* hFile, LPCVOID lpBuffer, DWORD nNumberOfBytesToWrite, LPDWORD lpNumberOfBytesWritten, LPOVERLAPPED lpOverlapped)
{ {
assert(lpOverlapped == nullptr && "Overlapped not implemented."); assert(lpOverlapped == nullptr && "Overlapped not implemented.");
FileHandle *handle = (FileHandle *)(ObQueryObject(HOST_HANDLE(hFile))); hFile->stream.write((const char *)(lpBuffer), nNumberOfBytesToWrite);
if (handle == nullptr) if (hFile->stream.bad())
return FALSE;
handle->stream.write((const char *)(lpBuffer), nNumberOfBytesToWrite);
if (handle->stream.bad())
return FALSE; return FALSE;
if (lpNumberOfBytesWritten != nullptr) if (lpNumberOfBytesWritten != nullptr)
*lpNumberOfBytesWritten = DWORD(handle->stream.gcount()); *lpNumberOfBytesWritten = DWORD(hFile->stream.gcount());
return TRUE; return TRUE;
} }

View file

@ -13,6 +13,79 @@
// Needed for commctrl // Needed for commctrl
#pragma comment(linker, "/manifestdependency:\"type='win32' name='Microsoft.Windows.Common-Controls' version='6.0.0.0' processorArchitecture='amd64' publicKeyToken='6595b64144ccf1df' language='*'\"") #pragma comment(linker, "/manifestdependency:\"type='win32' name='Microsoft.Windows.Common-Controls' version='6.0.0.0' processorArchitecture='amd64' publicKeyToken='6595b64144ccf1df' language='*'\"")
struct XamListener : KernelObject
{
uint32_t id{};
uint64_t areas{};
std::vector<std::tuple<DWORD, DWORD>> notifications;
XamListener(const XamListener&) = delete;
XamListener& operator=(const XamListener&) = delete;
XamListener();
~XamListener();
};
struct XamEnumeratorBase : KernelObject
{
virtual uint32_t Next(void* buffer)
{
return -1;
}
};
template<typename TIterator = std::vector<XHOSTCONTENT_DATA>::iterator>
struct XamEnumerator : XamEnumeratorBase
{
uint32_t fetch;
size_t size;
TIterator position;
TIterator begin;
TIterator end;
XamEnumerator() = default;
XamEnumerator(uint32_t fetch, size_t size, TIterator begin, TIterator end) : fetch(fetch), size(size), position(begin), begin(begin), end(end)
{
}
uint32_t Next(void* buffer) override
{
if (position == end)
{
return -1;
}
if (buffer == nullptr)
{
for (size_t i = 0; i < fetch; i++)
{
if (position == end)
{
return i == 0 ? -1 : i;
}
++position;
}
}
for (size_t i = 0; i < fetch; i++)
{
if (position == end)
{
return i == 0 ? -1 : i;
}
memcpy(buffer, &*position, size);
++position;
buffer = (void*)((size_t)buffer + size);
}
return fetch;
}
};
std::array<xxHashMap<XHOSTCONTENT_DATA>, 3> gContentRegistry{}; std::array<xxHashMap<XHOSTCONTENT_DATA>, 3> gContentRegistry{};
std::unordered_set<XamListener*> gListeners{}; std::unordered_set<XamListener*> gListeners{};
xxHashMap<std::string> gRootMap; xxHashMap<std::string> gRootMap;
@ -69,12 +142,11 @@ void XamRegisterContent(DWORD type, const std::string_view name, const std::stri
SWA_API DWORD XamNotifyCreateListener(uint64_t qwAreas) SWA_API DWORD XamNotifyCreateListener(uint64_t qwAreas)
{ {
int handle; auto* listener = CreateKernelObject<XamListener>();
auto* listener = ObCreateObject<XamListener>(handle);
listener->areas = qwAreas; listener->areas = qwAreas;
return GUEST_HANDLE(handle); return GetKernelHandle(listener);
} }
SWA_API void XamNotifyEnqueueEvent(DWORD dwId, DWORD dwParam) SWA_API void XamNotifyEnqueueEvent(DWORD dwId, DWORD dwParam)
@ -90,7 +162,7 @@ SWA_API void XamNotifyEnqueueEvent(DWORD dwId, DWORD dwParam)
SWA_API bool XNotifyGetNext(DWORD hNotification, DWORD dwMsgFilter, XDWORD* pdwId, XDWORD* pParam) SWA_API bool XNotifyGetNext(DWORD hNotification, DWORD dwMsgFilter, XDWORD* pdwId, XDWORD* pParam)
{ {
auto& listener = *ObTryQueryObject<XamListener>(HOST_HANDLE(hNotification)); auto& listener = *GetKernelObject<XamListener>(hNotification);
if (dwMsgFilter) if (dwMsgFilter)
{ {
@ -188,19 +260,19 @@ SWA_API uint32_t XamContentCreateEnumerator(DWORD dwUserIndex, DWORD DeviceID, D
const auto& registry = gContentRegistry[dwContentType - 1]; const auto& registry = gContentRegistry[dwContentType - 1];
const auto& values = registry | std::views::values; const auto& values = registry | std::views::values;
const int handle = ObInsertObject(new XamEnumerator(cItem, sizeof(_XCONTENT_DATA), values.begin(), values.end())); auto* enumerator = CreateKernelObject<XamEnumerator<decltype(values.begin())>>(cItem, sizeof(_XCONTENT_DATA), values.begin(), values.end());
if (pcbBuffer) if (pcbBuffer)
*pcbBuffer = sizeof(_XCONTENT_DATA) * cItem; *pcbBuffer = sizeof(_XCONTENT_DATA) * cItem;
*phEnum = GUEST_HANDLE(handle); *phEnum = GetKernelHandle(enumerator);
return 0; return 0;
} }
SWA_API uint32_t XamEnumerate(uint32_t hEnum, DWORD dwFlags, PVOID pvBuffer, DWORD cbBuffer, XLPDWORD pcItemsReturned, XXOVERLAPPED* pOverlapped) SWA_API uint32_t XamEnumerate(uint32_t hEnum, DWORD dwFlags, PVOID pvBuffer, DWORD cbBuffer, XLPDWORD pcItemsReturned, XXOVERLAPPED* pOverlapped)
{ {
auto* enumerator = ObTryQueryObject<XamEnumeratorBase>(HOST_HANDLE(hEnum)); auto* enumerator = GetKernelObject<XamEnumeratorBase>(hEnum);
const auto count = enumerator->Next(pvBuffer); const auto count = enumerator->Next(pvBuffer);
if (count == -1) if (count == -1)

View file

@ -5,82 +5,6 @@
#define MSG_AREA(msgid) (((msgid) >> 16) & 0xFFFF) #define MSG_AREA(msgid) (((msgid) >> 16) & 0xFFFF)
#define MSG_NUMBER(msgid) ((msgid) & 0xFFFF) #define MSG_NUMBER(msgid) ((msgid) & 0xFFFF)
struct XamListener
{
uint32_t id{};
uint64_t areas{};
std::vector<std::tuple<DWORD, DWORD>> notifications;
XamListener(const XamListener&) = delete;
XamListener& operator=(const XamListener&) = delete;
XamListener();
~XamListener();
};
class XamEnumeratorBase
{
public:
virtual ~XamEnumeratorBase() = default;
virtual uint32_t Next(void* buffer)
{
return -1;
}
};
template<typename TIterator = std::vector<XHOSTCONTENT_DATA>::iterator>
class XamEnumerator : public XamEnumeratorBase
{
public:
uint32_t fetch;
size_t size;
TIterator position;
TIterator begin;
TIterator end;
XamEnumerator() = default;
XamEnumerator(uint32_t fetch, size_t size, TIterator begin, TIterator end) : fetch(fetch), size(size), position(begin), begin(begin), end(end)
{
}
uint32_t Next(void* buffer) override
{
if (position == end)
{
return -1;
}
if (buffer == nullptr)
{
for (size_t i = 0; i < fetch; i++)
{
if (position == end)
{
return i == 0 ? -1 : i;
}
++position;
}
}
for (size_t i = 0; i < fetch; i++)
{
if (position == end)
{
return i == 0 ? -1 : i;
}
memcpy(buffer, &*position, size);
++position;
buffer = (void*)((size_t)buffer + size);
}
return fetch;
}
};
XCONTENT_DATA XamMakeContent(DWORD type, const std::string_view& name); XCONTENT_DATA XamMakeContent(DWORD type, const std::string_view& name);
void XamRegisterContent(const XCONTENT_DATA& data, const std::string_view& root); void XamRegisterContent(const XCONTENT_DATA& data, const std::string_view& root);

View file

@ -2,47 +2,36 @@
#include "xdm.h" #include "xdm.h"
#include "freelist.h" #include "freelist.h"
FreeList<std::tuple<std::unique_ptr<char>, TypeDestructor_t>> gKernelObjects; Mutex g_kernelLock;
Mutex gKernelLock;
void* ObQueryObject(size_t handle) void DestroyKernelObject(KernelObject* obj)
{ {
std::lock_guard guard{ gKernelLock }; obj->~KernelObject();
g_userHeap.Free(obj);
if (handle >= gKernelObjects.items.size())
return nullptr;
return std::get<0>(gKernelObjects[handle]).get();
} }
uint32_t ObInsertObject(void* object, TypeDestructor_t destructor) uint32_t GetKernelHandle(KernelObject* obj)
{ {
std::lock_guard guard{ gKernelLock }; assert(obj != GetInvalidKernelObject());
return g_memory.MapVirtual(obj);
const auto handle = gKernelObjects.Alloc();
auto& holder = gKernelObjects[handle];
std::get<0>(holder).reset(static_cast<char*>(object));
std::get<1>(holder) = destructor;
return handle;
} }
void ObCloseHandle(uint32_t handle) void DestroyKernelObject(uint32_t handle)
{ {
std::lock_guard guard{ gKernelLock }; DestroyKernelObject(GetKernelObject(handle));
}
auto& obj = gKernelObjects[handle];
bool IsKernelObject(uint32_t handle)
if (std::get<1>(obj)(std::get<0>(obj).get())) {
{ return (handle & 0x80000000) != 0;
std::get<0>(obj).reset(); }
}
else bool IsKernelObject(void* obj)
{ {
std::get<0>(obj).release(); return IsKernelObject(g_memory.MapVirtual(obj));
} }
gKernelObjects.Free(handle); bool IsInvalidKernelObject(void* obj)
{
return obj == GetInvalidKernelObject();
} }

View file

@ -1,56 +1,73 @@
#pragma once #pragma once
#define DUMMY_HANDLE (DWORD)('HAND')
#define OBJECT_SIGNATURE (DWORD)'XBOX'
extern Mutex gKernelLock; #include "heap.h"
#include "memory.h"
void* ObQueryObject(size_t handle); #define OBJECT_SIGNATURE (DWORD)'XBOX'
uint32_t ObInsertObject(void* object, TypeDestructor_t destructor); #define GUEST_INVALID_HANDLE_VALUE 0xFFFFFFFF
void ObCloseHandle(uint32_t handle);
struct KernelObject
{
virtual ~KernelObject()
{
;
}
};
template<typename T, typename... Args>
inline T* CreateKernelObject(Args&&... args)
{
static_assert(std::is_base_of_v<KernelObject, T>);
return g_userHeap.AllocPhysical<T>(std::forward<Args>(args)...);
}
template<typename T = KernelObject>
inline T* GetKernelObject(uint32_t handle)
{
assert(handle != GUEST_INVALID_HANDLE_VALUE);
return reinterpret_cast<T*>(g_memory.Translate(handle));
}
uint32_t GetKernelHandle(KernelObject* obj);
void DestroyKernelObject(KernelObject* obj);
void DestroyKernelObject(uint32_t handle);
bool IsKernelObject(uint32_t handle);
bool IsKernelObject(void* obj);
bool IsInvalidKernelObject(void* obj);
template<typename T = void>
inline T* GetInvalidKernelObject()
{
return reinterpret_cast<T*>(g_memory.Translate(GUEST_INVALID_HANDLE_VALUE));
}
extern Mutex g_kernelLock;
template<typename T> template<typename T>
T* ObQueryObject(XDISPATCHER_HEADER& header) inline T* QueryKernelObject(XDISPATCHER_HEADER& header)
{ {
std::lock_guard guard{ gKernelLock }; std::lock_guard guard{ g_kernelLock };
if (header.WaitListHead.Flink != OBJECT_SIGNATURE) if (header.WaitListHead.Flink != OBJECT_SIGNATURE)
{ {
header.WaitListHead.Flink = OBJECT_SIGNATURE; header.WaitListHead.Flink = OBJECT_SIGNATURE;
auto* obj = new T(reinterpret_cast<typename T::guest_type*>(&header)); auto* obj = CreateKernelObject<T>(reinterpret_cast<typename T::guest_type*>(&header));
header.WaitListHead.Blink = ObInsertObject(obj, DestroyObject<T>); header.WaitListHead.Blink = g_memory.MapVirtual(obj);
return obj; return obj;
} }
return static_cast<T*>(ObQueryObject(header.WaitListHead.Blink.get())); return static_cast<T*>(g_memory.Translate(header.WaitListHead.Blink.get()));
}
template<typename T>
size_t ObInsertObject(T* object)
{
return ObInsertObject(object, DestroyObject<T>);
}
template<typename T>
T* ObCreateObject(int& handle)
{
auto* obj = new T();
handle = ::ObInsertObject(obj, DestroyObject<T>);
return obj;
} }
// Get object without initialisation // Get object without initialisation
template<typename T> template<typename T>
T* ObTryQueryObject(XDISPATCHER_HEADER& header) inline T* TryQueryKernelObject(XDISPATCHER_HEADER& header)
{ {
if (header.WaitListHead.Flink != OBJECT_SIGNATURE) if (header.WaitListHead.Flink != OBJECT_SIGNATURE)
return nullptr; return nullptr;
return static_cast<T*>(ObQueryObject(header.WaitListHead.Blink)); return static_cast<T*>(g_memory.Translate(header.WaitListHead.Blink.get()));
} }
template<typename T>
T* ObTryQueryObject(int handle)
{
return static_cast<T*>(ObQueryObject(handle));
}