diff --git a/src/hooks.cpp b/src/hooks.cpp index 3147cd9..e908dcb 100644 --- a/src/hooks.cpp +++ b/src/hooks.cpp @@ -14,12 +14,53 @@ namespace { bool initialized{false}; std::optional application; + VkResult myvkCreateInstance( + const VkInstanceCreateInfo* pCreateInfo, + const VkAllocationCallbacks* pAllocator, + VkInstance* pInstance) { + // add extensions + std::vector extensions(pCreateInfo->enabledExtensionCount); + std::copy_n(pCreateInfo->ppEnabledExtensionNames, extensions.size(), extensions.data()); + + const std::vector requiredExtensions = { + VK_KHR_GET_PHYSICAL_DEVICE_PROPERTIES_2_EXTENSION_NAME, + VK_KHR_EXTERNAL_MEMORY_CAPABILITIES_EXTENSION_NAME + }; + for (const auto& ext : requiredExtensions) { + auto it = std::ranges::find(extensions, ext); + if (it == extensions.end()) + extensions.push_back(ext); + } + + VkInstanceCreateInfo createInfo = *pCreateInfo; + createInfo.enabledExtensionCount = static_cast(extensions.size()); + createInfo.ppEnabledExtensionNames = extensions.data(); + return vkCreateInstance(&createInfo, pAllocator, pInstance); + } + VkResult myvkCreateDevice( VkPhysicalDevice physicalDevice, const VkDeviceCreateInfo* pCreateInfo, const VkAllocationCallbacks* pAllocator, VkDevice* pDevice) { - auto res = vkCreateDevice(physicalDevice, pCreateInfo, pAllocator, pDevice); + // add extensions + std::vector extensions(pCreateInfo->enabledExtensionCount); + std::copy_n(pCreateInfo->ppEnabledExtensionNames, extensions.size(), extensions.data()); + + const std::vector requiredExtensions = { + VK_KHR_EXTERNAL_MEMORY_EXTENSION_NAME, + VK_KHR_EXTERNAL_MEMORY_FD_EXTENSION_NAME + }; + for (const auto& ext : requiredExtensions) { + auto it = std::ranges::find(extensions, ext); + if (it == extensions.end()) + extensions.push_back(ext); + } + + VkDeviceCreateInfo createInfo = *pCreateInfo; + createInfo.enabledExtensionCount = static_cast(extensions.size()); + createInfo.ppEnabledExtensionNames = extensions.data(); + auto res = vkCreateDevice(physicalDevice, &createInfo, pAllocator, pDevice); // extract graphics and present queues std::vector queueCreateInfos(pCreateInfo->queueCreateInfoCount); @@ -150,7 +191,6 @@ namespace { return VK_SUCCESS; } - void myvkDestroySwapchainKHR( VkDevice device, VkSwapchainKHR swapchain, @@ -195,6 +235,8 @@ void Hooks::initialize() { } // register hooks to vulkan loader + Loader::VK::registerSymbol("vkCreateInstance", + reinterpret_cast(myvkCreateInstance)); Loader::VK::registerSymbol("vkCreateDevice", reinterpret_cast(myvkCreateDevice)); Loader::VK::registerSymbol("vkDestroyDevice", @@ -208,6 +250,8 @@ void Hooks::initialize() { // register hooks to dynamic loader under libvulkan.so.1 Loader::DL::File vk1("libvulkan.so.1"); + vk1.defineSymbol("vkCreateInstance", + reinterpret_cast(myvkCreateInstance)); vk1.defineSymbol("vkCreateDevice", reinterpret_cast(myvkCreateDevice)); vk1.defineSymbol("vkDestroyDevice", @@ -222,6 +266,8 @@ void Hooks::initialize() { // register hooks to dynamic loader under libvulkan.so Loader::DL::File vk2("libvulkan.so"); + vk2.defineSymbol("vkCreateInstance", + reinterpret_cast(myvkCreateInstance)); vk2.defineSymbol("vkCreateDevice", reinterpret_cast(myvkCreateDevice)); vk2.defineSymbol("vkDestroyDevice",