diff --git a/framegen/.old/include/pool/shaderpool.hpp b/framegen/.old/include/pool/shaderpool.hpp deleted file mode 100644 index 39595ed..0000000 --- a/framegen/.old/include/pool/shaderpool.hpp +++ /dev/null @@ -1,70 +0,0 @@ -#pragma once - -#include "core/device.hpp" -#include "core/pipeline.hpp" -#include "core/shadermodule.hpp" - -#include - -#include -#include -#include -#include -#include -#include -#include - -namespace LSFG::Pool { - - /// - /// Shader pool for each Vulkan device. - /// - class ShaderPool { - public: - ShaderPool() noexcept = default; - - /// - /// Create the shader pool. - /// - /// @param source Function to retrieve shader source code by name. - /// @param fp16 If true, use the FP16 variant of shaders. - /// - /// @throws std::runtime_error if the shader pool cannot be created. - /// - ShaderPool( - const std::function(const std::string&, bool)>& source, - bool fp16) - : source(source), fp16(fp16) {} - - /// - /// Retrieve a shader module by name or create it. - /// - /// @param name Name of the shader module - /// @param types Descriptor types for the shader module - /// @return Shader module - /// - /// @throws LSFG::vulkan_error if the shader module cannot be created. - /// - Core::ShaderModule getShader( - const Core::Device& device, const std::string& name, - const std::vector>& types); - - /// - /// Retrieve a pipeline shader module by name or create it. - /// - /// @param name Name of the shader module - /// @return Pipeline shader module or empty - /// - /// @throws LSFG::vulkan_error if the shader module cannot be created. 
- /// - Core::Pipeline getPipeline( - const Core::Device& device, const std::string& name); - private: - std::function(const std::string&, bool)> source; - bool fp16{false}; - - std::unordered_map shaders; - std::unordered_map pipelines; - }; - -} diff --git a/framegen/.old/src/pool/shaderpool.cpp b/framegen/.old/src/pool/shaderpool.cpp deleted file mode 100644 index 1e7af9a..0000000 --- a/framegen/.old/src/pool/shaderpool.cpp +++ /dev/null @@ -1,48 +0,0 @@ -#include "pool/shaderpool.hpp" -#include "core/shadermodule.hpp" -#include "core/device.hpp" -#include "core/pipeline.hpp" - -#include - -#include -#include -#include -#include -#include - -using namespace LSFG; -using namespace LSFG::Pool; - -Core::ShaderModule ShaderPool::getShader( - const Core::Device& device, const std::string& name, - const std::vector>& types) { - auto it = shaders.find(name); - if (it != shaders.end()) - return it->second; - - // grab the shader - auto bytecode = this->source(name, this->fp16); - if (bytecode.empty()) - throw std::runtime_error("Shader code is empty: " + name); - - // create the shader module - Core::ShaderModule shader(device, bytecode, types); - shaders[name] = shader; - return shader; -} - -Core::Pipeline ShaderPool::getPipeline( - const Core::Device& device, const std::string& name) { - auto it = pipelines.find(name); - if (it != pipelines.end()) - return it->second; - - // grab the shader module - auto shader = this->getShader(device, name, {}); - - // create the pipeline - Core::Pipeline pipeline(device, shader); - pipelines[name] = pipeline; - return pipeline; -} diff --git a/framegen/.old/src/pool/resourcepool.cpp b/framegen/.old2/resourcepool.cpp similarity index 100% rename from framegen/.old/src/pool/resourcepool.cpp rename to framegen/.old2/resourcepool.cpp diff --git a/framegen/.old/include/pool/resourcepool.hpp b/framegen/.old2/resourcepool.hpp similarity index 100% rename from framegen/.old/include/pool/resourcepool.hpp rename to 
framegen/.old2/resourcepool.hpp
diff --git a/framegen/include/exception.hpp b/framegen/include/exception.hpp
new file mode 100644
index 0000000..ec12ae8
--- /dev/null
+++ b/framegen/include/exception.hpp
@@ -0,0 +1,45 @@
+#pragma once
+
+#include
+
+#include
+#include
+
+namespace LSFG {
+
+    /// Exception type that stacks a message on top of an underlying error's text.
+    class error : public std::runtime_error {
+    public:
+        ///
+        /// Construct a new error with a message.
+        ///
+        /// @param message The error message.
+        ///
+        explicit error(const std::string& message);
+
+        ///
+        /// Construct a new error stacked on top of another exception.
+        ///
+        /// @param message The error message.
+        /// @param exe The underlying exception; its what() text is appended to the message.
+        ///
+        explicit error(const std::string& message,
+            const std::exception& exe);
+
+        /// Get the stored message (including any stacked text) as a C string.
+        [[nodiscard]] const char* what() const noexcept override {
+            return message.c_str();
+        }
+
+        // Copyable and movable (not trivially: the class owns a std::string)
+        error(const error&) = default;
+        error(error&&) = default;
+        error& operator=(const error&) = default;
+        error& operator=(error&&) = default;
+        ~error() noexcept override;
+    private:
+        std::string message;
+    };
+
+
+}
diff --git a/framegen/include/lsfg.hpp b/framegen/include/lsfg.hpp
new file mode 100644
index 0000000..0e55d7c
--- /dev/null
+++ b/framegen/include/lsfg.hpp
@@ -0,0 +1,36 @@
+#pragma once
+
+#include "vk/core/device.hpp"
+#include "vk/core/instance.hpp"
+#include "vk/pool/shader_pool.hpp"
+#include "vk/registry/shader_registry.hpp"
+
+#include
+
+namespace LSFG {
+
+    // FIXME: device picking
+
+    ///
+    /// Lossless Scaling Frame Generation instance.
+    ///
+    class Instance {
+    public:
+        ///
+        /// Create an instance and register the shaders found in the DLL.
+        ///
+        /// @param dll Path to the Lossless.dll file.
+        ///
+        /// @throws LSFG::error if lsfg creation fails.
+        ///
+        Instance(const std::filesystem::path& dll);
+
+    private:
+        VK::Core::Instance vk; // must stay declared before vkd: vkd is constructed from vk
+        VK::Core::Device vkd;
+
+        VK::Registry::ShaderRegistry registry;
+        VK::Pool::ShaderPool shaders;
+    };
+
+}
diff --git a/framegen/include/trans/dll.hpp b/framegen/include/trans/dll.hpp
new file mode 100644
index 0000000..072b8e4
--- /dev/null
+++ b/framegen/include/trans/dll.hpp
@@ -0,0 +1,21 @@
+#pragma once
+
+#include
+#include
+#include
+#include
+
+namespace Trans::DLL {
+
+    ///
+    /// Parse all RT_RCDATA resources from a PE32+ DLL file.
+    ///
+    /// @param filename Path to the DLL file.
+    /// @return A map of resource IDs to their binary data.
+    ///
+    /// @throws std::runtime_error if the file cannot be read or is not a well-formed PE32+ image.
+    ///
+    std::unordered_map> parseDLL(
+        const std::filesystem::path& filename);
+
+}
diff --git a/framegen/include/trans/rsrc.hpp b/framegen/include/trans/rsrc.hpp
new file mode 100644
index 0000000..8f7e34a
--- /dev/null
+++ b/framegen/include/trans/rsrc.hpp
@@ -0,0 +1,23 @@
+#pragma once
+
+#include "vk/registry/shader_registry.hpp"
+
+#include
+
+namespace Trans::RSRC {
+
+    ///
+    /// Parse the DLL and register its shader modules in the registry.
+    ///
+    /// @param filename Path to the DLL file
+    /// @param registry Shader registry that receives the modules
+    /// @param fp16 Select the FP16 variant of every shader
+    ///
+    /// @throws LSFG::error if the file is missing, cannot be parsed, or lacks a shader.
+    ///
+    void loadResources(
+        const std::filesystem::path& filename,
+        VK::Registry::ShaderRegistry& registry,
+        bool fp16);
+
+}
diff --git a/framegen/include/vk/core/device.hpp b/framegen/include/vk/core/device.hpp
index 4952b34..46ea50a 100644
--- a/framegen/include/vk/core/device.hpp
+++ b/framegen/include/vk/core/device.hpp
@@ -11,6 +11,9 @@
 
 namespace VK::Core {
 
+    // FIXME: The toggle for fp32 shouldn't be implemented here.
+    // FIXME: Device UUID needs an overhaul.
+
     ///
     /// C++ wrapper class for a Vulkan device.
///
diff --git a/framegen/src/exception.cpp b/framegen/src/exception.cpp
new file mode 100644
index 0000000..9cfd801
--- /dev/null
+++ b/framegen/src/exception.cpp
@@ -0,0 +1,17 @@
+#include "exception.hpp"
+
+#include
+#include
+#include
+
+using namespace LSFG;
+
+error::error(const std::string& message)
+    : std::runtime_error(message), message(message) {}
+
+error::error(const std::string& message, const std::exception& exe)
+    : std::runtime_error(std::format("{}\n- {}", message, exe.what())),
+      // qualified call on purpose: error::what() reads the member that is being initialised here
+      message(std::runtime_error::what()) {}
+
+error::~error() noexcept = default;
diff --git a/framegen/src/lsfg.cpp b/framegen/src/lsfg.cpp
new file mode 100644
index 0000000..9f61dc2
--- /dev/null
+++ b/framegen/src/lsfg.cpp
@@ -0,0 +1,13 @@
+#include "lsfg.hpp"
+#include "trans/rsrc.hpp"
+
+using namespace LSFG;
+
+Instance::Instance(const std::filesystem::path& dll)
+    : vkd(vk, 0, false) { // NOTE(review): 0/false are presumably device index and an fp32 toggle - confirm
+    // register the shaders from the dll, using the FP16 variants if the device reports support
+    const bool fp16 = vkd.supportsFP16();
+    Trans::RSRC::loadResources(dll, this->registry, fp16);
+
+    // ...
+}
diff --git a/framegen/src/trans/dll.cpp b/framegen/src/trans/dll.cpp
new file mode 100644
index 0000000..df05002
--- /dev/null
+++ b/framegen/src/trans/dll.cpp
@@ -0,0 +1,206 @@
+#include "trans/dll.hpp"
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+using namespace Trans;
+
+/// DOS file header (only the fields parseDLL reads; everything else is padding)
+struct DOSHeader {
+    uint16_t magic; // 0x5A4D ("MZ")
+    std::array pad;
+    int32_t pe_offset; // file offset of the PE header
+};
+
+/// PE header
+struct PEHeader {
+    uint32_t signature; // "PE\0\0"
+    std::array pad1;
+    uint16_t sect_count; // number of section headers after the optional header
+    std::array pad2;
+    uint16_t opt_hdr_size; // byte size of the optional header that follows
+    std::array pad3;
+};
+
+/// (partial!)
PE optional header
+struct PEOptionalHeader {
+    uint16_t magic; // 0x20B (PE32+)
+    std::array pad4;
+    std::pair resource_table; // RVA/size of the resource directory (an RVA, not a file offset)
+};
+
+/// Section header
+struct SectionHeader {
+    std::array pad1;
+    uint32_t vsize; // virtual size
+    uint32_t vaddress; // virtual address (RVA)
+    uint32_t fsize; // raw size
+    uint32_t foffset; // raw (file) offset
+    std::array pad2;
+};
+
+/// Resource directory
+struct ResourceDirectory {
+    std::array pad;
+    uint16_t name_count;
+    uint16_t id_count;
+};
+
+/// Resource directory entry
+struct ResourceDirectoryEntry {
+    uint32_t id;
+    uint32_t offset; // high bit = directory, low 31 bits = offset from the resource section start
+};
+
+/// Resource data entry
+struct ResourceDataEntry {
+    uint32_t offset; // RVA of the resource bytes
+    uint32_t size;
+    std::array pad;
+};
+
+#pragma clang diagnostic push
+#pragma clang diagnostic ignored "-Wunsafe-buffer-usage-in-container"
+namespace {
+    /// Bounds-checked view of the bytes at `offset` as a T; throws std::runtime_error if T does not fit
+    template
+    const T* safe_cast(const std::vector& data, size_t offset) {
+        const size_t end = offset + sizeof(T);
+        if (end > data.size() || end < offset) // second test catches wrap-around of `end`
+            throw std::runtime_error("Buffer overflow during safe cast");
+        return reinterpret_cast(&data.at(offset));
+    }
+
+    /// Bounds-checked view of `count` consecutive T at `offset`; throws std::runtime_error if out of range
+    template
+    std::span span_cast(const std::vector& data, size_t offset, size_t count) {
+        const size_t end = offset + (count * sizeof(T));
+        if (offset >= data.size() || end > data.size() || end < offset) // data.at(size()) would throw out_of_range
+            throw std::runtime_error("Buffer overflow during safe cast");
+        return std::span(reinterpret_cast(&data.at(offset)), count);
+    }
+}
+#pragma clang diagnostic pop
+
+std::unordered_map> DLL::parseDLL(
+        const std::filesystem::path& filename) {
+    std::ifstream file(filename, std::ios::binary | std::ios::ate);
+    if (!file.is_open())
+        throw std::runtime_error("Failed to open Lossless.dll");
+
+    const std::streamoff size = file.tellg();
+    if (size < 0 || !file.seekg(0, std::ios::beg)) // tellg() yields -1 on failure; never size a buffer from it
+        throw std::runtime_error("Failed to read Lossless.dll");
+    std::vector data(static_cast(size));
+    if (!file.read(reinterpret_cast(data.data()), size))
+        throw std::runtime_error("Failed to read Lossless.dll");
+
+    // parse dos header
+    size_t fileOffset = 0;
+    const auto* dosHdr = safe_cast(data, 0);
+    if (dosHdr->magic != 0x5A4D)
+        throw std::runtime_error("Invalid DOS header magic number");
+
+    // parse pe header (a negative pe_offset becomes a huge offset and is rejected by safe_cast)
+    fileOffset += static_cast(dosHdr->pe_offset);
+    const auto* peHdr = safe_cast(data, fileOffset);
+    if (peHdr->signature != 0x00004550)
+        throw std::runtime_error("Invalid PE header signature");
+
+    // parse optional pe header
+    fileOffset += sizeof(PEHeader);
+    const auto* peOptHdr = safe_cast(data, fileOffset);
+    if (peOptHdr->magic != 0x20B)
+        throw std::runtime_error("Unsupported PE format (not PE32+)");
+    const auto& [rsrc_rva, rsrc_size] = peOptHdr->resource_table;
+
+    // locate section containing resources (section range is half-open: [vaddress, vaddress + vsize))
+    std::optional rsrc_offset;
+    fileOffset += peHdr->opt_hdr_size;
+    const auto sectHdrs = span_cast(data, fileOffset, peHdr->sect_count);
+    for (const auto& sectHdr : sectHdrs) {
+        if (rsrc_rva < sectHdr.vaddress || (rsrc_rva - sectHdr.vaddress) >= sectHdr.vsize)
+            continue;
+
+        rsrc_offset.emplace((rsrc_rva - sectHdr.vaddress) + sectHdr.foffset);
+        break;
+    }
+    if (!rsrc_offset)
+        throw std::runtime_error("Failed to locate resource section");
+
+    // parse resource directory
+    fileOffset = rsrc_offset.value();
+    const auto* rsrcDir = safe_cast(data, fileOffset);
+    if (rsrcDir->id_count < 3) // NOTE(review): presumably Lossless.dll always carries >= 3 ID'd types - confirm
+        throw std::runtime_error("Incorrect resource directory");
+
+    // find resource table with data type
+    std::optional rsrc_tbl_offset;
+    fileOffset = rsrc_offset.value() + sizeof(ResourceDirectory);
+    const auto rsrcDirEntries = span_cast(
+        data, fileOffset, rsrcDir->name_count + rsrcDir->id_count);
+    for (const auto& rsrcDirEntry : rsrcDirEntries) {
+        if (rsrcDirEntry.id != 10) // RT_RCDATA
+            continue;
+        if ((rsrcDirEntry.offset & 0x80000000) == 0)
+            throw std::runtime_error("Expected resource directory, but found data entry");
+
+        rsrc_tbl_offset.emplace(rsrcDirEntry.offset & 0x7FFFFFFF);
+    }
+    if (!rsrc_tbl_offset)
+        throw std::runtime_error("Failed to locate RT_RCDATA directory");
+
+    // parse data type resource directory
+    fileOffset = rsrc_offset.value() + rsrc_tbl_offset.value();
+    const auto* rsrcTbl = safe_cast(data, fileOffset);
+    if (rsrcTbl->id_count < 1)
+        throw std::runtime_error("Incorrect RT_RCDATA directory");
+
+    // collect all resources
+    fileOffset += sizeof(ResourceDirectory);
+    const auto rsrcTblEntries = span_cast(
+        data, fileOffset, rsrcTbl->name_count + rsrcTbl->id_count);
+    std::unordered_map> resources;
+    for (const auto& rsrcTblEntry : rsrcTblEntries) {
+        if ((rsrcTblEntry.offset & 0x80000000) == 0)
+            throw std::runtime_error("Expected resource directory, but found data entry");
+
+        // skip over language directory (only its first entry is used)
+        fileOffset = rsrc_offset.value() + (rsrcTblEntry.offset & 0x7FFFFFFF);
+        const auto* langDir = safe_cast(data, fileOffset);
+        if (langDir->id_count < 1)
+            throw std::runtime_error("Incorrect language directory");
+
+        fileOffset += sizeof(ResourceDirectory);
+        const auto* langDirEntry = safe_cast(data, fileOffset);
+        if ((langDirEntry->offset & 0x80000000) != 0)
+            throw std::runtime_error("Expected resource data entry, but found directory");
+
+        // parse resource data entry (difference form cannot wrap, unlike rsrc_rva + rsrc_size)
+        fileOffset = rsrc_offset.value() + (langDirEntry->offset & 0x7FFFFFFF);
+        const auto* entry = safe_cast(data, fileOffset);
+        if (entry->offset < rsrc_rva || (entry->offset - rsrc_rva) > rsrc_size)
+            throw std::runtime_error("Resource data entry points outside resource section");
+        // extract resource
+        std::vector resource(entry->size);
+        fileOffset = (entry->offset - rsrc_rva) + rsrc_offset.value();
+        if (fileOffset > data.size() || entry->size > data.size() - fileOffset)
+            throw std::runtime_error("Resource data entry points outside file");
+        if (entry->size != 0) // data.at() rejects fileOffset == size(), even for an empty copy
+            std::copy_n(&data.at(fileOffset), entry->size, resource.data());
+        resources.emplace(rsrcTblEntry.id, std::move(resource));
+    }
+
+    return resources;
+}
diff --git a/framegen/src/trans/rsrc.cpp b/framegen/src/trans/rsrc.cpp
new file mode 100644
index 0000000..9cd0ef0
--- /dev/null
+++ b/framegen/src/trans/rsrc.cpp
@@ -0,0
+1,119 @@
+#include "trans/rsrc.hpp"
+#include "exception.hpp"
+#include "trans/dll.hpp"
+#include "vk/registry/shader_registry.hpp"
+#include
+#include
+#include
+#include
+
+using namespace Trans;
+
+namespace {
+    constexpr uint32_t OFFSET_FP16 = 49; // added to a base resource ID to select its FP16 variant
+
+    /// Fetch a copy of a shader resource, selecting the FP16 variant when requested.
+    ///
+    /// @throws LSFG::error naming the effective (offset-adjusted) ID if it is missing.
+    std::vector get(
+            const std::unordered_map>& source,
+            uint32_t id,
+            bool fp16) {
+        const uint32_t key = fp16 ? id + OFFSET_FP16 : id;
+        const auto it = source.find(key);
+        if (it == source.end())
+            throw LSFG::error("Missing resource ID: " + std::to_string(key));
+        return it->second;
+    }
+}
+
+/*
+TODO for normal mode
+    { "gamma[0]", 257 + NO },
+    { "gamma[1]", 259 + NO },
+    { "gamma[2]", 260 + NO },
+    { "gamma[3]", 261 + NO },
+    { "gamma[4]", 262 + NO },
+    { "delta[0]", 257 + NO },
+    { "delta[1]", 263 + NO },
+    { "delta[2]", 264 + NO },
+    { "delta[3]", 265 + NO },
+    { "delta[4]", 266 + NO },
+    { "delta[5]", 258 + NO },
+    { "delta[6]", 271 + NO },
+    { "delta[7]", 272 + NO },
+    { "delta[8]", 273 + NO },
+    { "delta[9]", 274 + NO },
+    { "generate", 256 + NO },
+}};
+ */
+
+void RSRC::loadResources(
+        const std::filesystem::path& filename,
+        VK::Registry::ShaderRegistry& registry,
+        bool fp16) {
+    // parse dll file
+    if (!std::filesystem::exists(filename))
+        throw LSFG::error("DLL file does not exist: " + filename.string());
+
+    std::unordered_map> rsrcs;
+    try {
+        rsrcs = DLL::parseDLL(filename);
+    } catch (const std::exception& e) { // also wraps non-runtime_error failures (allocation, range) from parsing
+        throw LSFG::error("Unable to parse Lossless.dll file", e);
+    }
+
+    // register resources (IDs are the FP32 base IDs; get() applies the FP16 offset)
+    registry.registerModule("mipmaps", VK::Registry::ShaderModuleInfo {
+        .code = get(rsrcs, 304, fp16),
+        .sampledImages = 1, .storageImages = 7,
+        .buffers = 1, .samplers = 1
+    });
+
+    registry.registerModule("alpha[0]", VK::Registry::ShaderModuleInfo {
+        .code = get(rsrcs, 316, fp16),
+        .sampledImages = 1, .storageImages = 2,
+        .samplers = 1
+    });
+    registry.registerModule("alpha[1]", VK::Registry::ShaderModuleInfo {
+        .code = get(rsrcs, 317, fp16),
+        .sampledImages = 2, .storageImages = 2,
+        .samplers = 1
+    });
+    registry.registerModule("alpha[2]", VK::Registry::ShaderModuleInfo {
+        .code = get(rsrcs, 318, fp16),
+        .sampledImages = 2, .storageImages = 4,
+        .samplers = 1
+    });
+    registry.registerModule("alpha[3]", VK::Registry::ShaderModuleInfo {
+        .code = get(rsrcs, 319, fp16),
+        .sampledImages = 4, .storageImages = 4,
+        .samplers = 1
+    });
+
+    registry.registerModule("beta[0]", VK::Registry::ShaderModuleInfo {
+        .code = get(rsrcs, 324, fp16),
+        .sampledImages = 12, .storageImages = 2,
+        .samplers = 1
+    });
+    registry.registerModule("beta[1]", VK::Registry::ShaderModuleInfo {
+        .code = get(rsrcs, 325, fp16),
+        .sampledImages = 2, .storageImages = 2,
+        .samplers = 1
+    });
+    registry.registerModule("beta[2]", VK::Registry::ShaderModuleInfo {
+        .code = get(rsrcs, 326, fp16),
+        .sampledImages = 2, .storageImages = 2,
+        .samplers = 1
+    });
+    registry.registerModule("beta[3]", VK::Registry::ShaderModuleInfo {
+        .code = get(rsrcs, 327, fp16),
+        .sampledImages = 2, .storageImages = 2,
+        .samplers = 1
+    });
+    registry.registerModule("beta[4]", VK::Registry::ShaderModuleInfo {
+        .code = get(rsrcs, 328, fp16),
+        .sampledImages = 4, .storageImages = 6,
+        .buffers = 1, .samplers = 1
+    });
+}