feat(fp16): passthrough fp16 to shader pool

This commit is contained in:
PancakeTAS 2025-07-30 18:58:45 +02:00 committed by Pancake
parent b93a4eeaf2
commit 3fcde7c126
6 changed files with 15 additions and 10 deletions

View file

@ -27,11 +27,14 @@ namespace LSFG::Pool {
/// Create the shader pool. /// Create the shader pool.
/// ///
/// @param source Function to retrieve shader source code by name. /// @param source Function to retrieve shader source code by name.
/// @param fp16 If true, use the FP16 variant of shaders.
/// ///
/// @throws std::runtime_error if the shader pool cannot be created. /// @throws std::runtime_error if the shader pool cannot be created.
/// ///
ShaderPool(const std::function<std::vector<uint8_t>(const std::string&)>& source) ShaderPool(
: source(source) {} const std::function<std::vector<uint8_t>(const std::string&, bool)>& source,
bool fp16)
: source(source), fp16(fp16) {}
/// ///
/// Retrieve a shader module by name or create it. /// Retrieve a shader module by name or create it.
@ -57,7 +60,9 @@ namespace LSFG::Pool {
Core::Pipeline getPipeline( Core::Pipeline getPipeline(
const Core::Device& device, const std::string& name); const Core::Device& device, const std::string& name);
private: private:
std::function<std::vector<uint8_t>(const std::string&)> source; std::function<std::vector<uint8_t>(const std::string&, bool)> source;
bool fp16{false};
std::unordered_map<std::string, Core::ShaderModule> shaders; std::unordered_map<std::string, Core::ShaderModule> shaders;
std::unordered_map<std::string, Core::Pipeline> pipelines; std::unordered_map<std::string, Core::Pipeline> pipelines;
}; };

View file

@ -23,7 +23,7 @@ namespace LSFG_3_1 {
[[gnu::visibility("default")]] [[gnu::visibility("default")]]
void initialize(uint64_t deviceUUID, void initialize(uint64_t deviceUUID,
bool isHdr, float flowScale, uint64_t generationCount, bool isHdr, float flowScale, uint64_t generationCount,
const std::function<std::vector<uint8_t>(const std::string&)>& loader); const std::function<std::vector<uint8_t>(const std::string&, bool)>& loader);
/// ///
/// Initialize the renderdoc API. /// Initialize the renderdoc API.

View file

@ -23,7 +23,7 @@ namespace LSFG_3_1P {
[[gnu::visibility("default")]] [[gnu::visibility("default")]]
void initialize(uint64_t deviceUUID, void initialize(uint64_t deviceUUID,
bool isHdr, float flowScale, uint64_t generationCount, bool isHdr, float flowScale, uint64_t generationCount,
const std::function<std::vector<uint8_t>(const std::string&)>& loader); const std::function<std::vector<uint8_t>(const std::string&, bool)>& loader);
/// ///
/// Initialize the renderdoc API. /// Initialize the renderdoc API.

View file

@ -22,7 +22,7 @@ Core::ShaderModule ShaderPool::getShader(
return it->second; return it->second;
// grab the shader // grab the shader
auto bytecode = this->source(name); auto bytecode = this->source(name, this->fp16);
if (bytecode.empty()) if (bytecode.empty())
throw std::runtime_error("Shader code is empty: " + name); throw std::runtime_error("Shader code is empty: " + name);

View file

@ -35,7 +35,7 @@ namespace {
void LSFG_3_1::initialize(uint64_t deviceUUID, void LSFG_3_1::initialize(uint64_t deviceUUID,
bool isHdr, float flowScale, uint64_t generationCount, bool isHdr, float flowScale, uint64_t generationCount,
const std::function<std::vector<uint8_t>(const std::string&)>& loader) { const std::function<std::vector<uint8_t>(const std::string&, bool)>& loader) {
if (instance.has_value() || device.has_value()) if (instance.has_value() || device.has_value())
return; return;
@ -52,7 +52,7 @@ void LSFG_3_1::initialize(uint64_t deviceUUID,
device->descriptorPool = Core::DescriptorPool(device->device); device->descriptorPool = Core::DescriptorPool(device->device);
device->resources = Pool::ResourcePool(device->isHdr, device->flowScale); device->resources = Pool::ResourcePool(device->isHdr, device->flowScale);
device->shaders = Pool::ShaderPool(loader); device->shaders = Pool::ShaderPool(loader, device->device.getFP16Support());
std::srand(static_cast<uint32_t>(std::time(nullptr))); std::srand(static_cast<uint32_t>(std::time(nullptr)));
} }

View file

@ -35,7 +35,7 @@ namespace {
void LSFG_3_1P::initialize(uint64_t deviceUUID, void LSFG_3_1P::initialize(uint64_t deviceUUID,
bool isHdr, float flowScale, uint64_t generationCount, bool isHdr, float flowScale, uint64_t generationCount,
const std::function<std::vector<uint8_t>(const std::string&)>& loader) { const std::function<std::vector<uint8_t>(const std::string&, bool)>& loader) {
if (instance.has_value() || device.has_value()) if (instance.has_value() || device.has_value())
return; return;
@ -52,7 +52,7 @@ void LSFG_3_1P::initialize(uint64_t deviceUUID,
device->descriptorPool = Core::DescriptorPool(device->device); device->descriptorPool = Core::DescriptorPool(device->device);
device->resources = Pool::ResourcePool(device->isHdr, device->flowScale); device->resources = Pool::ResourcePool(device->isHdr, device->flowScale);
device->shaders = Pool::ShaderPool(loader); device->shaders = Pool::ShaderPool(loader, device->device.getFP16Support());
std::srand(static_cast<uint32_t>(std::time(nullptr))); std::srand(static_cast<uint32_t>(std::time(nullptr)));
} }