chore(cleanup): progress on shader registry

This commit is contained in:
PancakeTAS 2025-09-11 20:49:46 +02:00
parent 76674ff9b0
commit cfbc105b9d
No known key found for this signature in database
13 changed files with 480 additions and 118 deletions

View file

@ -1,70 +0,0 @@
#pragma once
#include "core/device.hpp"
#include "core/pipeline.hpp"
#include "core/shadermodule.hpp"
#include <vulkan/vulkan_core.h>
#include <cstdint>
#include <cstddef>
#include <functional>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
namespace LSFG::Pool {
///
/// Shader pool for each Vulkan device.
///
/// Shader modules and pipelines are cached by name, so each one is only
/// created the first time it is requested.
///
class ShaderPool {
public:
/// Create an empty shader pool without a shader source.
ShaderPool() noexcept = default;
///
/// Create the shader pool.
///
/// @param source Function to retrieve shader source code by name.
/// @param fp16 If true, use the FP16 variant of shaders.
///
/// @throws std::runtime_error if the shader pool cannot be created.
///
ShaderPool(
const std::function<std::vector<uint8_t>(const std::string&, bool)>& source,
bool fp16)
: source(source), fp16(fp16) {}
///
/// Retrieve a shader module by name or create it.
///
/// A cached module is looked up by name only; `device` and `types` are
/// only used when the module has to be created.
///
/// @param device Vulkan device to create the shader module on
/// @param name Name of the shader module
/// @param types Descriptor types for the shader module
/// @return Shader module
///
/// @throws std::runtime_error if the source returns no bytecode for the name.
/// @throws LSFG::vulkan_error if the shader module cannot be created.
///
Core::ShaderModule getShader(
const Core::Device& device, const std::string& name,
const std::vector<std::pair<size_t, VkDescriptorType>>& types);
///
/// Retrieve a pipeline by name or create it.
///
/// The pipeline is built from the shader module of the same name.
///
/// @param device Vulkan device to create the pipeline on
/// @param name Name of the shader module
/// @return Pipeline
///
/// @throws LSFG::vulkan_error if the shader module cannot be created.
///
Core::Pipeline getPipeline(
const Core::Device& device, const std::string& name);
private:
// callback producing the bytecode for a shader name and FP16 flag
std::function<std::vector<uint8_t>(const std::string&, bool)> source;
bool fp16{false};
// caches, keyed by shader name
std::unordered_map<std::string, Core::ShaderModule> shaders;
std::unordered_map<std::string, Core::Pipeline> pipelines;
};
}

View file

@ -1,48 +0,0 @@
#include "pool/shaderpool.hpp"
#include "core/shadermodule.hpp"
#include "core/device.hpp"
#include "core/pipeline.hpp"
#include <vulkan/vulkan_core.h>
#include <cstddef>
#include <stdexcept>
#include <string>
#include <vector>
#include <utility>
using namespace LSFG;
using namespace LSFG::Pool;
Core::ShaderModule ShaderPool::getShader(
const Core::Device& device, const std::string& name,
const std::vector<std::pair<size_t, VkDescriptorType>>& types) {
auto it = shaders.find(name);
if (it != shaders.end())
return it->second;
// grab the shader
auto bytecode = this->source(name, this->fp16);
if (bytecode.empty())
throw std::runtime_error("Shader code is empty: " + name);
// create the shader module
Core::ShaderModule shader(device, bytecode, types);
shaders[name] = shader;
return shader;
}
Core::Pipeline ShaderPool::getPipeline(
const Core::Device& device, const std::string& name) {
auto it = pipelines.find(name);
if (it != pipelines.end())
return it->second;
// grab the shader module
auto shader = this->getShader(device, name, {});
// create the pipeline
Core::Pipeline pipeline(device, shader);
pipelines[name] = pipeline;
return pipeline;
}

View file

@ -0,0 +1,45 @@
#pragma once
#include <vulkan/vulkan_core.h>
#include <stdexcept>
#include <string>
namespace LSFG {
/// Simple exception class for stacking errors.
///
/// The class keeps its own copy of the error text, which is what what()
/// reports.
class error : public std::runtime_error {
public:
///
/// Construct a new error with a message.
///
/// @param message The error message.
///
explicit error(const std::string& message);
///
/// Construct a new error that wraps another exception.
///
/// @param message The error message.
/// @param exe The original exception; its what() text is appended to the message on a new line.
///
explicit error(const std::string& message,
const std::exception& exe);
/// Get the exception as a string.
[[nodiscard]] const char* what() const noexcept override {
return message.c_str();
}
// Copyable, moveable and destructible (not trivially: the class owns a std::string)
error(const error&) = default;
error(error&&) = default;
error& operator=(const error&) = default;
error& operator=(error&&) = default;
~error() noexcept override;
private:
// full error text returned by what()
std::string message;
};
}

36
framegen/include/lsfg.hpp Normal file
View file

@ -0,0 +1,36 @@
#pragma once
#include "vk/core/device.hpp"
#include "vk/core/instance.hpp"
#include "vk/pool/shader_pool.hpp"
#include "vk/registry/shader_registry.hpp"
#include <filesystem>
namespace LSFG {
// FIXME: device picking
///
/// Lossless Scaling Frame Generation instance.
///
/// Holds the Vulkan instance and device together with the shader registry
/// and the shader pool.
///
class Instance {
public:
///
/// Create an instance.
///
/// @param dll Path to the Lossless.dll file.
///
/// @throws LSFG::error if lsfg creation fails.
///
Instance(const std::filesystem::path& dll);
private:
// NOTE: `vk` has to stay declared before `vkd`, the constructor passes it to the device.
VK::Core::Instance vk;
VK::Core::Device vkd;
VK::Registry::ShaderRegistry registry;
VK::Pool::ShaderPool shaders;
};
}

View file

@ -0,0 +1,21 @@
#pragma once
#include <unordered_map>
#include <filesystem>
#include <cstdint>
#include <vector>
namespace Trans::DLL {
///
/// Parse all resources from a DLL file.
///
/// Only resources of type RT_RCDATA are collected; the file has to be a
/// PE32+ image.
///
/// @param filename Path to the DLL file.
/// @return A map of resource IDs to their binary data.
///
/// @throws std::runtime_error on various failure points.
///
std::unordered_map<uint32_t, std::vector<uint8_t>> parseDLL(
const std::filesystem::path& filename);
}

View file

@ -0,0 +1,23 @@
#pragma once
#include "vk/registry/shader_registry.hpp"
#include <filesystem>
namespace Trans::RSRC {
///
/// Load all resources into memory.
///
/// Parses the DLL file and registers the shader modules found in it with
/// the given registry.
///
/// @param filename Path to the DLL file
/// @param registry Shader registry that receives the shader modules
/// @param fp16 Prefer FP16 shaders
///
/// @throws LSFG::error if loading fails.
///
void loadResources(
const std::filesystem::path& filename,
VK::Registry::ShaderRegistry& registry,
bool fp16);
}

View file

@ -11,6 +11,9 @@
namespace VK::Core { namespace VK::Core {
// FIXME: The toggle for fp32 shouldn't be implemented here.
// FIXME: Device UUID needs an overhaul.
/// ///
/// C++ wrapper class for a Vulkan device. /// C++ wrapper class for a Vulkan device.
/// ///

View file

@ -0,0 +1,17 @@
#include "exception.hpp"
#include <stdexcept>
#include <format>
#include <string>
using namespace LSFG;
error::error(const std::string& message)
: std::runtime_error(message), message(message) {}
/// Build an error whose text is `message` followed, on a new line prefixed
/// with "- ", by the text of the exception that caused it.
error::error(const std::string& message, const std::exception& exe)
    : std::runtime_error(message),
      // inside the initializer, `message` names the parameter, not the member
      message(std::format("{}\n- {}", message, exe.what())) {}
/// Destructor, defaulted out of line to match its declaration in the class.
error::~error() noexcept = default;

13
framegen/src/lsfg.cpp Normal file
View file

@ -0,0 +1,13 @@
#include "lsfg.hpp"
#include "trans/rsrc.hpp"
using namespace LSFG;
/// Create the instance: set up the Vulkan device and load the shaders from the DLL.
///
/// NOTE(review): the `0` and `false` handed to the device look like a device
/// index and the fp32 toggle mentioned in the device header - confirm against
/// VK::Core::Device.
Instance::Instance(const std::filesystem::path& dll)
: vkd(vk, 0, false) {
// load shaders from dll file
// (FP16 variants are requested whenever the device reports FP16 support)
const bool fp16 = vkd.supportsFP16();
Trans::RSRC::loadResources(dll, this->registry, fp16);
// ...
}

206
framegen/src/trans/dll.cpp Normal file
View file

@ -0,0 +1,206 @@
#include "trans/dll.hpp"
#include <unordered_map>
#include <filesystem>
#include <stdexcept>
#include <algorithm>
#include <iostream>
#include <optional>
#include <cstddef>
#include <cstdint>
#include <fstream>
#include <utility>
#include <vector>
#include <array>
#include <span>
using namespace Trans;
/// DOS file header (only the fields needed to reach the PE header)
struct DOSHeader {
    uint16_t magic; // 0x5A4D ("MZ")
    std::array<uint16_t, 29> pad;
    int32_t pe_offset; // file offset of the PE header
};
// The struct is overlaid directly onto the raw file bytes, so it must not contain padding.
static_assert(sizeof(DOSHeader) == 64, "DOSHeader must match the on-disk layout");
/// PE header (signature followed by the COFF file header)
struct PEHeader {
    uint32_t signature; // "PE\0\0"
    std::array<uint16_t, 1> pad1;
    uint16_t sect_count; // number of section headers after the optional header
    std::array<uint16_t, 6> pad2;
    uint16_t opt_hdr_size; // size of the optional header in bytes
    std::array<uint16_t, 1> pad3;
};
// The struct is overlaid directly onto the raw file bytes, so it must not contain padding.
static_assert(sizeof(PEHeader) == 24, "PEHeader must match the on-disk layout");
/// (partial!) PE optional header, cut off after the resource table entry
struct PEOptionalHeader {
    uint16_t magic; // 0x20B (PE32+)
    std::array<uint16_t, 63> pad4;
    std::pair<uint32_t, uint32_t> resource_table; // RVA/size (not a file offset)
};
// The struct is overlaid directly onto the raw file bytes, so it must not contain padding.
static_assert(sizeof(PEOptionalHeader) == 136, "PEOptionalHeader must match the on-disk layout");
/// Section header
struct SectionHeader {
    std::array<uint16_t, 4> pad1;
    uint32_t vsize; // virtual
    uint32_t vaddress;
    uint32_t fsize; // raw
    uint32_t foffset;
    std::array<uint16_t, 8> pad2;
};
// The struct is overlaid directly onto the raw file bytes, so it must not contain padding.
static_assert(sizeof(SectionHeader) == 40, "SectionHeader must match the on-disk layout");
/// Resource directory (header of one level of the resource tree)
struct ResourceDirectory {
    std::array<uint16_t, 6> pad;
    uint16_t name_count; // entries identified by name
    uint16_t id_count; // entries identified by numeric ID
};
// The struct is overlaid directly onto the raw file bytes, so it must not contain padding.
static_assert(sizeof(ResourceDirectory) == 16, "ResourceDirectory must match the on-disk layout");
/// Resource directory entry
struct ResourceDirectoryEntry {
    uint32_t id;
    uint32_t offset; // relative to the resource section; high bit = directory
};
// The struct is overlaid directly onto the raw file bytes, so it must not contain padding.
static_assert(sizeof(ResourceDirectoryEntry) == 8, "ResourceDirectoryEntry must match the on-disk layout");
/// Resource data entry
struct ResourceDataEntry {
    uint32_t offset; // RVA of the resource data
    uint32_t size; // size of the resource data in bytes
    std::array<uint32_t, 2> pad;
};
// The struct is overlaid directly onto the raw file bytes, so it must not contain padding.
static_assert(sizeof(ResourceDataEntry) == 16, "ResourceDataEntry must match the on-disk layout");
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wunsafe-buffer-usage-in-container"
namespace {
/// Reinterpret the bytes at `offset` as a T after verifying that a whole T
/// fits inside the buffer.
///
/// @param data Buffer to read from.
/// @param offset Byte offset of the object inside the buffer.
/// @return Pointer into `data`.
///
/// @throws std::runtime_error if the object would extend past the end of the buffer.
template<typename T>
const T* safe_cast(const std::vector<uint8_t>& data, size_t offset) {
    const size_t available = data.size();
    // phrased as a subtraction so that the check itself cannot wrap around
    const bool fits = offset <= available && sizeof(T) <= available - offset;
    if (!fits)
        throw std::runtime_error("Buffer overflow during safe cast");
    const uint8_t& first = data.at(offset);
    return reinterpret_cast<const T*>(&first);
}
/// Reinterpret the bytes at `offset` as `count` consecutive T after verifying
/// that all of them fit inside the buffer.
///
/// @param data Buffer to read from.
/// @param offset Byte offset of the first element inside the buffer.
/// @param count Number of elements.
/// @return Span over `data`; empty if `count` is zero.
///
/// @throws std::runtime_error if the elements would extend past the end of the buffer.
template<typename T>
std::span<const T> span_cast(const std::vector<uint8_t>& data, size_t offset, size_t count) {
    // Dividing the remaining bytes avoids computing count * sizeof(T), which
    // could wrap around and make an oversized request look valid.
    if (offset > data.size() || count > (data.size() - offset) / sizeof(T))
        throw std::runtime_error("Buffer overflow during safe cast");
    // An empty span may start right at the end of the buffer, where at()
    // would throw std::out_of_range.
    if (count == 0)
        return {};
    return std::span<const T>(reinterpret_cast<const T*>(&data.at(offset)), count);
}
}
#pragma clang diagnostic pop
///
/// Parse all RT_RCDATA resources from a PE32+ DLL file.
///
/// The whole file is read into memory and walked by hand: DOS header, PE
/// header, optional header, section table and finally the three levels of
/// the resource directory (type -> name/ID -> language).
///
/// @param filename Path to the DLL file.
/// @return A map of resource IDs to their binary data.
///
/// @throws std::runtime_error if the file cannot be read or is not laid out as expected.
///
std::unordered_map<uint32_t, std::vector<uint8_t>> DLL::parseDLL(
        const std::filesystem::path& filename) {
    std::ifstream file(filename, std::ios::binary | std::ios::ate);
    if (!file.is_open())
        throw std::runtime_error("Failed to open Lossless.dll");

    // tellg() signals failure with -1, which must never become a buffer size
    const std::streamoff size = file.tellg();
    if (size < 0)
        throw std::runtime_error("Failed to read Lossless.dll");
    file.seekg(0, std::ios::beg);

    std::vector<uint8_t> data(static_cast<size_t>(size));
    if (!file.read(reinterpret_cast<char*>(data.data()), static_cast<std::streamsize>(size)))
        throw std::runtime_error("Failed to read Lossless.dll");

    // parse dos header
    size_t fileOffset = 0;
    const auto* dosHdr = safe_cast<const DOSHeader>(data, 0);
    if (dosHdr->magic != 0x5A4D)
        throw std::runtime_error("Invalid DOS header magic number");

    // parse pe header (a negative offset turns into a huge one and is rejected by safe_cast)
    fileOffset += static_cast<size_t>(dosHdr->pe_offset);
    const auto* peHdr = safe_cast<const PEHeader>(data, fileOffset);
    if (peHdr->signature != 0x00004550)
        throw std::runtime_error("Invalid PE header signature");

    // parse optional pe header
    fileOffset += sizeof(PEHeader);
    const auto* peOptHdr = safe_cast<const PEOptionalHeader>(data, fileOffset);
    if (peOptHdr->magic != 0x20B)
        throw std::runtime_error("Unsupported PE format (not PE32+)");
    const auto& [rsrc_rva, rsrc_size] = peOptHdr->resource_table;

    // locate section containing resources
    std::optional<size_t> rsrc_offset;
    fileOffset += peHdr->opt_hdr_size;
    const auto sectHdrs = span_cast<const SectionHeader>(data, fileOffset, peHdr->sect_count);
    for (const auto& sectHdr : sectHdrs) {
        // The RVA has to lie in [vaddress, vaddress + vsize). Comparing the
        // distance keeps the end exclusive and cannot overflow 32 bits.
        if (rsrc_rva < sectHdr.vaddress || (rsrc_rva - sectHdr.vaddress) >= sectHdr.vsize)
            continue;

        rsrc_offset.emplace(static_cast<size_t>(rsrc_rva - sectHdr.vaddress) + sectHdr.foffset);
        break;
    }
    if (!rsrc_offset)
        throw std::runtime_error("Failed to locate resource section");

    // parse resource directory
    fileOffset = rsrc_offset.value();
    const auto* rsrcDir = safe_cast<const ResourceDirectory>(data, fileOffset);
    if (rsrcDir->id_count < 3)
        throw std::runtime_error("Incorrect resource directory");

    // find resource table with data type
    std::optional<size_t> rsrc_tbl_offset;
    fileOffset = rsrc_offset.value() + sizeof(ResourceDirectory);
    const auto rsrcDirEntries = span_cast<const ResourceDirectoryEntry>(
        data, fileOffset, rsrcDir->name_count + rsrcDir->id_count);
    for (const auto& rsrcDirEntry : rsrcDirEntries) {
        if (rsrcDirEntry.id != 10) // RT_RCDATA
            continue;
        if ((rsrcDirEntry.offset & 0x80000000) == 0)
            throw std::runtime_error("Expected resource directory, but found data entry");

        rsrc_tbl_offset.emplace(rsrcDirEntry.offset & 0x7FFFFFFF);
    }
    if (!rsrc_tbl_offset)
        throw std::runtime_error("Failed to locate RT_RCDATA directory");

    // parse data type resource directory
    fileOffset = rsrc_offset.value() + rsrc_tbl_offset.value();
    const auto* rsrcTbl = safe_cast<const ResourceDirectory>(data, fileOffset);
    if (rsrcTbl->id_count < 1)
        throw std::runtime_error("Incorrect RT_RCDATA directory");

    // collect all resources
    fileOffset += sizeof(ResourceDirectory);
    const auto rsrcTblEntries = span_cast<const ResourceDirectoryEntry>(
        data, fileOffset, rsrcTbl->name_count + rsrcTbl->id_count);

    std::unordered_map<uint32_t, std::vector<uint8_t>> resources;
    for (const auto& rsrcTblEntry : rsrcTblEntries) {
        if ((rsrcTblEntry.offset & 0x80000000) == 0)
            throw std::runtime_error("Expected resource directory, but found data entry");

        // skip over language directory (only its first entry is used)
        fileOffset = rsrc_offset.value() + (rsrcTblEntry.offset & 0x7FFFFFFF);
        const auto* langDir = safe_cast<const ResourceDirectory>(data, fileOffset);
        if (langDir->id_count < 1)
            throw std::runtime_error("Incorrect language directory");

        fileOffset += sizeof(ResourceDirectory);
        const auto* langDirEntry = safe_cast<const ResourceDirectoryEntry>(data, fileOffset);
        if ((langDirEntry->offset & 0x80000000) != 0)
            throw std::runtime_error("Expected resource data entry, but found directory");

        // parse resource data entry
        fileOffset = rsrc_offset.value() + (langDirEntry->offset & 0x7FFFFFFF);
        const auto* entry = safe_cast<const ResourceDataEntry>(data, fileOffset);
        if (entry->offset < rsrc_rva || (entry->offset - rsrc_rva) > rsrc_size)
            throw std::runtime_error("Resource data entry points outside resource section");

        // extract resource; the range is validated before any memory is
        // allocated, since the size comes straight from the file
        fileOffset = static_cast<size_t>(entry->offset - rsrc_rva) + rsrc_offset.value();
        if (fileOffset > data.size() || entry->size > data.size() - fileOffset)
            throw std::runtime_error("Resource data entry points outside file");

        // iterators (rather than &data.at()) keep an empty resource at the
        // very end of the file from raising std::out_of_range
        const auto first = data.begin() + static_cast<std::ptrdiff_t>(fileOffset);
        const auto last = first + static_cast<std::ptrdiff_t>(entry->size);
        resources.emplace(rsrcTblEntry.id, std::vector<uint8_t>(first, last));
    }

    return resources;
}

116
framegen/src/trans/rsrc.cpp Normal file
View file

@ -0,0 +1,116 @@
#include "trans/rsrc.hpp"
#include "exception.hpp"
#include "trans/dll.hpp"
#include "vk/registry/shader_registry.hpp"
#include <cstddef>
#include <cstdint>
#include <filesystem>
#include <vector>
using namespace Trans;
namespace {
const size_t OFFSET_FP16 = 49;
std::vector<uint8_t> get(
const std::unordered_map<uint32_t, std::vector<uint8_t>>& source,
uint32_t id,
bool fp16) {
auto it = source.find(id + (fp16 ? OFFSET_FP16 : 0));
if (it == source.end())
throw LSFG::error("Missing resource ID: " + std::to_string(id));
return it->second;
}
}
/*
TODO for normal mode
{ "gamma[0]", 257 + NO },
{ "gamma[1]", 259 + NO },
{ "gamma[2]", 260 + NO },
{ "gamma[3]", 261 + NO },
{ "gamma[4]", 262 + NO },
{ "delta[0]", 257 + NO },
{ "delta[1]", 263 + NO },
{ "delta[2]", 264 + NO },
{ "delta[3]", 265 + NO },
{ "delta[4]", 266 + NO },
{ "delta[5]", 258 + NO },
{ "delta[6]", 271 + NO },
{ "delta[7]", 272 + NO },
{ "delta[8]", 273 + NO },
{ "delta[9]", 274 + NO },
{ "generate", 256 + NO },
}};
*/
///
/// Load all resources into memory.
///
/// Parses the DLL file and registers every known shader module, together
/// with its descriptor counts, in the registry.
///
/// @param filename Path to the DLL file
/// @param registry Shader registry that receives the shader modules
/// @param fp16 Prefer FP16 shaders
///
/// @throws LSFG::error if the file is missing, cannot be parsed or lacks a resource.
///
void RSRC::loadResources(
        const std::filesystem::path& filename,
        VK::Registry::ShaderRegistry& registry,
        bool fp16) {
    // parse dll file
    if (!std::filesystem::exists(filename))
        throw LSFG::error("DLL file does not exist: " + filename.string());

    std::unordered_map<uint32_t, std::vector<uint8_t>> rsrcs;
    try {
        rsrcs = DLL::parseDLL(filename);
    } catch (const std::exception& e) {
        // Any standard exception is wrapped, not just std::runtime_error, so
        // that callers only ever see LSFG::error from a failed parse.
        throw LSFG::error("Unable to parse Lossless.dll file", e);
    }

    // register resources
    registry.registerModule("mipmaps", VK::Registry::ShaderModuleInfo {
        .code = get(rsrcs, 304, fp16),
        .sampledImages = 1, .storageImages = 7,
        .buffers = 1, .samplers = 1
    });
    registry.registerModule("alpha[0]", VK::Registry::ShaderModuleInfo {
        .code = get(rsrcs, 316, fp16),
        .sampledImages = 1, .storageImages = 2,
        .samplers = 1
    });
    registry.registerModule("alpha[1]", VK::Registry::ShaderModuleInfo {
        .code = get(rsrcs, 317, fp16),
        .sampledImages = 2, .storageImages = 2,
        .samplers = 1
    });
    registry.registerModule("alpha[2]", VK::Registry::ShaderModuleInfo {
        .code = get(rsrcs, 318, fp16),
        .sampledImages = 2, .storageImages = 4,
        .samplers = 1
    });
    registry.registerModule("alpha[3]", VK::Registry::ShaderModuleInfo {
        .code = get(rsrcs, 319, fp16),
        .sampledImages = 4, .storageImages = 4,
        .samplers = 1
    });
    registry.registerModule("beta[0]", VK::Registry::ShaderModuleInfo {
        .code = get(rsrcs, 324, fp16),
        .sampledImages = 12, .storageImages = 2,
        .samplers = 1
    });
    registry.registerModule("beta[1]", VK::Registry::ShaderModuleInfo {
        .code = get(rsrcs, 325, fp16),
        .sampledImages = 2, .storageImages = 2,
        .samplers = 1
    });
    registry.registerModule("beta[2]", VK::Registry::ShaderModuleInfo {
        .code = get(rsrcs, 326, fp16),
        .sampledImages = 2, .storageImages = 2,
        .samplers = 1
    });
    registry.registerModule("beta[3]", VK::Registry::ShaderModuleInfo {
        .code = get(rsrcs, 327, fp16),
        .sampledImages = 2, .storageImages = 2,
        .samplers = 1
    });
    registry.registerModule("beta[4]", VK::Registry::ShaderModuleInfo {
        .code = get(rsrcs, 328, fp16),
        .sampledImages = 4, .storageImages = 6,
        .buffers = 1, .samplers = 1
    });
}