refactor(cleanup): proper device picker

This commit is contained in:
PancakeTAS 2025-12-15 00:10:10 +01:00
parent 5fa3ddc8e3
commit 83c3ce68b0
7 changed files with 72 additions and 12 deletions

View file

@ -4,6 +4,7 @@
#include <filesystem>
#include <functional>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>
@ -38,6 +39,13 @@ namespace lsfgvk {
~error() override;
};
/// Function type for picking a device based on its name and IDs
using DevicePicker = std::function<bool(
const std::string& deviceName,
std::pair<const std::string&, const std::string&> ids, // (vendor ID, device ID) in 0xXXXX format
const std::optional<std::string>& pci // (bus:slot.func) if available, no padded zeros
)>;
///
/// Main entry point of the library
///
@ -46,14 +54,14 @@ namespace lsfgvk {
///
/// Create a lsfg-vk instance
///
/// @param devicePicker Function that picks a physical device based on its name.
/// @param devicePicker Function that picks a physical device based on its name or other identifiers.
/// @param shaderDllPath Path to the Lossless.dll file to load shaders from.
/// @param allowLowPrecision Whether to load low-precision (FP16) shaders if supported by the device.
///
/// @throws lsfgvk::error on failure
///
Instance(
const std::function<bool(const std::string&)>& devicePicker,
const DevicePicker& devicePicker,
const std::filesystem::path& shaderDllPath,
bool allowLowPrecision
);

View file

@ -1,7 +1,9 @@
#include "utils.hpp"
#include <array>
#include <cstddef>
#include <cstdint>
#include <string>
#include <vulkan/vulkan_core.h>
@ -32,3 +34,14 @@ VkExtent2D ls::add_shift_extent(VkExtent2D extent, uint32_t a, uint32_t i) {
.height = (extent.height + a) >> i
};
}
std::string ls::to_hex_id(uint32_t id) {
const std::array<char, 17> chars = std::to_array("0123456789ABCDEF");
std::string result = "0x";
result += chars.at((id >> 12) & 0xF);
result += chars.at((id >> 8) & 0xF);
result += chars.at((id >> 4) & 0xF);
result += chars.at(id & 0xF);
return result;
}

View file

@ -9,6 +9,7 @@
#include <array>
#include <cstddef>
#include <cstdint>
#include <string>
#include <vector>
#include <vulkan/vulkan_core.h>
@ -70,4 +71,7 @@ namespace ls {
/// @param i the amount to shift by
/// @return the shifted extent
VkExtent2D add_shift_extent(VkExtent2D extent, uint32_t a, uint32_t i);
/// convert a device/vendor id into a hex string
std::string to_hex_id(uint32_t id);
}

View file

@ -119,20 +119,48 @@ namespace lsfgvk {
}
Instance::Instance(
const std::function<bool(const std::string&)>& devicePicker,
const DevicePicker& devicePicker,
const std::filesystem::path& shaderDllPath,
bool allowLowPrecision) {
const auto selectFunc = [&devicePicker](const vk::VulkanInstanceFuncs funcs,
const std::vector<VkPhysicalDevice>& devices) {
for (const auto& device : devices) {
VkPhysicalDeviceProperties props;
funcs.GetPhysicalDeviceProperties(device, &props);
// check if the physical device supports VK_EXT_pci_bus_info
uint32_t ext_count{};
funcs.EnumerateDeviceExtensionProperties(device, nullptr, &ext_count, nullptr);
std::array<char, 256> devname = std::to_array(props.deviceName);
std::vector<VkExtensionProperties> extensions(ext_count);
funcs.EnumerateDeviceExtensionProperties(device, nullptr, &ext_count, extensions.data());
const bool has_pci_ext = std::ranges::find_if(extensions,
[](const VkExtensionProperties& ext) {
return std::string(std::to_array(ext.extensionName).data())
== VK_EXT_PCI_BUS_INFO_EXTENSION_NAME;
}) != extensions.end();
// then fetch all available properties
VkPhysicalDevicePCIBusInfoPropertiesEXT pciInfo{
.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PCI_BUS_INFO_PROPERTIES_EXT
};
VkPhysicalDeviceProperties2 props{
.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PROPERTIES_2,
.pNext = has_pci_ext ? &pciInfo : nullptr
};
funcs.GetPhysicalDeviceProperties2(device, &props);
std::array<char, 256> devname = std::to_array(props.properties.deviceName);
devname[255] = '\0'; // ensure null-termination
const std::string& deviceName{devname.data()};
if (devicePicker(deviceName))
if (devicePicker(
std::string(devname.data()),
{ ls::to_hex_id(props.properties.vendorID),
ls::to_hex_id(props.properties.deviceID) },
has_pci_ext ? std::optional<std::string>{
std::to_string(pciInfo.pciBus) + ":" +
std::to_string(pciInfo.pciDevice) + "." +
std::to_string(pciInfo.pciFunction)
} : std::nullopt
))
return device;
}

View file

@ -17,7 +17,8 @@ namespace vk {
struct VulkanInstanceFuncs {
PFN_vkDestroyInstance DestroyInstance;
PFN_vkEnumeratePhysicalDevices EnumeratePhysicalDevices;
PFN_vkGetPhysicalDeviceProperties GetPhysicalDeviceProperties;
PFN_vkEnumerateDeviceExtensionProperties EnumerateDeviceExtensionProperties;
PFN_vkGetPhysicalDeviceProperties2 GetPhysicalDeviceProperties2;
PFN_vkGetPhysicalDeviceQueueFamilyProperties GetPhysicalDeviceQueueFamilyProperties;
PFN_vkGetPhysicalDeviceFeatures2 GetPhysicalDeviceFeatures2;
PFN_vkGetPhysicalDeviceMemoryProperties GetPhysicalDeviceMemoryProperties;

View file

@ -99,8 +99,10 @@ namespace {
.DestroyInstance = ipa<PFN_vkDestroyInstance>(mpa, i, "vkDestroyInstance"),
.EnumeratePhysicalDevices = ipa<PFN_vkEnumeratePhysicalDevices>(mpa, i,
"vkEnumeratePhysicalDevices"),
.GetPhysicalDeviceProperties = ipa<PFN_vkGetPhysicalDeviceProperties>(mpa, i,
"vkGetPhysicalDeviceProperties"),
.EnumerateDeviceExtensionProperties = ipa<PFN_vkEnumerateDeviceExtensionProperties>(mpa, i,
"vkEnumerateDeviceExtensionProperties"),
.GetPhysicalDeviceProperties2 = ipa<PFN_vkGetPhysicalDeviceProperties2>(mpa, i,
"vkGetPhysicalDeviceProperties2"),
.GetPhysicalDeviceQueueFamilyProperties =
ipa<PFN_vkGetPhysicalDeviceQueueFamilyProperties>(mpa, i,
"vkGetPhysicalDeviceQueueFamilyProperties"),

View file

@ -119,7 +119,11 @@ int main() {
// initialize lsfg-vk
lsfgvk::Instance lsfgvk{
[](const std::string&) {
[](
const std::string&,
std::pair<const std::string&, const std::string&>,
const std::optional<std::string>&
) {
return true;
},
"/home/pancake/.steam/steam/steamapps/common/Lossless Scaling/Lossless.dll",