diff --git a/framegen/include/pool/shaderpool.hpp b/framegen/include/pool/shaderpool.hpp index e572154..39595ed 100644 --- a/framegen/include/pool/shaderpool.hpp +++ b/framegen/include/pool/shaderpool.hpp @@ -27,11 +27,14 @@ namespace LSFG::Pool { /// Create the shader pool. /// /// @param source Function to retrieve shader source code by name. + /// @param fp16 If true, use the FP16 variant of shaders. /// /// @throws std::runtime_error if the shader pool cannot be created. /// - ShaderPool(const std::function(const std::string&)>& source) - : source(source) {} + ShaderPool( + const std::function(const std::string&, bool)>& source, + bool fp16) + : source(source), fp16(fp16) {} /// /// Retrieve a shader module by name or create it. @@ -57,7 +60,9 @@ namespace LSFG::Pool { Core::Pipeline getPipeline( const Core::Device& device, const std::string& name); private: - std::function(const std::string&)> source; + std::function(const std::string&, bool)> source; + bool fp16{false}; + std::unordered_map shaders; std::unordered_map pipelines; }; diff --git a/framegen/public/lsfg_3_1.hpp b/framegen/public/lsfg_3_1.hpp index 6620008..6486612 100644 --- a/framegen/public/lsfg_3_1.hpp +++ b/framegen/public/lsfg_3_1.hpp @@ -23,7 +23,7 @@ namespace LSFG_3_1 { [[gnu::visibility("default")]] void initialize(uint64_t deviceUUID, bool isHdr, float flowScale, uint64_t generationCount, - const std::function(const std::string&)>& loader); + const std::function(const std::string&, bool)>& loader); /// /// Initialize the renderdoc API. diff --git a/framegen/public/lsfg_3_1p.hpp b/framegen/public/lsfg_3_1p.hpp index 14df35c..410541b 100644 --- a/framegen/public/lsfg_3_1p.hpp +++ b/framegen/public/lsfg_3_1p.hpp @@ -23,7 +23,7 @@ namespace LSFG_3_1P { [[gnu::visibility("default")]] void initialize(uint64_t deviceUUID, bool isHdr, float flowScale, uint64_t generationCount, - const std::function(const std::string&)>& loader); + const std::function(const std::string&, bool)>& loader); /// /// Initialize the renderdoc API. diff --git a/framegen/src/pool/shaderpool.cpp b/framegen/src/pool/shaderpool.cpp index 7aaf00b..1e7af9a 100644 --- a/framegen/src/pool/shaderpool.cpp +++ b/framegen/src/pool/shaderpool.cpp @@ -22,7 +22,7 @@ Core::ShaderModule ShaderPool::getShader( return it->second; // grab the shader - auto bytecode = this->source(name); + auto bytecode = this->source(name, this->fp16); if (bytecode.empty()) throw std::runtime_error("Shader code is empty: " + name); diff --git a/framegen/v3.1_src/lsfg.cpp b/framegen/v3.1_src/lsfg.cpp index 1f4c2dc..6b5db5f 100644 --- a/framegen/v3.1_src/lsfg.cpp +++ b/framegen/v3.1_src/lsfg.cpp @@ -35,7 +35,7 @@ namespace { void LSFG_3_1::initialize(uint64_t deviceUUID, bool isHdr, float flowScale, uint64_t generationCount, - const std::function(const std::string&)>& loader) { + const std::function(const std::string&, bool)>& loader) { if (instance.has_value() || device.has_value()) return; @@ -52,7 +52,7 @@ void LSFG_3_1::initialize(uint64_t deviceUUID, device->descriptorPool = Core::DescriptorPool(device->device); device->resources = Pool::ResourcePool(device->isHdr, device->flowScale); - device->shaders = Pool::ShaderPool(loader); + device->shaders = Pool::ShaderPool(loader, device->device.getFP16Support()); std::srand(static_cast(std::time(nullptr))); } diff --git a/framegen/v3.1p_src/lsfg.cpp b/framegen/v3.1p_src/lsfg.cpp index 1bca0ce..9252e82 100644 --- a/framegen/v3.1p_src/lsfg.cpp +++ b/framegen/v3.1p_src/lsfg.cpp @@ -35,7 +35,7 @@ namespace { void LSFG_3_1P::initialize(uint64_t deviceUUID, bool isHdr, float flowScale, uint64_t generationCount, - const std::function(const std::string&)>& loader) { + const std::function(const std::string&, bool)>& loader) { if (instance.has_value() || device.has_value()) return; @@ -52,7 +52,7 @@ void LSFG_3_1P::initialize(uint64_t deviceUUID, device->descriptorPool = Core::DescriptorPool(device->device); device->resources = Pool::ResourcePool(device->isHdr, device->flowScale); - device->shaders = Pool::ShaderPool(loader); + device->shaders = Pool::ShaderPool(loader, device->device.getFP16Support()); std::srand(static_cast(std::time(nullptr))); }