feat(fp16): account for new shader bindings

This commit is contained in:
PancakeTAS 2025-08-04 16:26:02 +02:00 committed by Pancake
parent 6c3571e672
commit 77d1b68b8b
3 changed files with 57 additions and 8 deletions

View file

@ -119,6 +119,10 @@ namespace LSFG::Core {
: descriptorSet(&descriptorSet), device(&device) {}
std::vector<VkWriteDescriptorSet> entries;
size_t bufferIdx{0};
size_t samplerIdx{16};
size_t inputIdx{32};
size_t outputIdx{48};
};
}

View file

@ -12,8 +12,9 @@
#include "core/buffer.hpp"
#include "common/exception.hpp"
#include <memory>
#include <cstddef>
#include <cstdint>
#include <memory>
using namespace LSFG::Core;
@ -55,10 +56,11 @@ void DescriptorSet::bind(const CommandBuffer& commandBuffer, const Pipeline& pip
// updater class
DescriptorSetUpdateBuilder& DescriptorSetUpdateBuilder::add(VkDescriptorType type, const Image& image) {
size_t* idx{type == VK_DESCRIPTOR_TYPE_STORAGE_IMAGE ? &this->outputIdx : &this->inputIdx};
this->entries.push_back({
.sType = VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET,
.dstSet = this->descriptorSet->handle(),
.dstBinding = static_cast<uint32_t>(this->entries.size()),
.dstBinding = static_cast<uint32_t>(*idx),
.descriptorCount = 1,
.descriptorType = type,
.pImageInfo = new VkDescriptorImageInfo {
@ -67,6 +69,7 @@ DescriptorSetUpdateBuilder& DescriptorSetUpdateBuilder::add(VkDescriptorType typ
},
.pBufferInfo = nullptr
});
(*idx)++;
return *this;
}
@ -74,7 +77,7 @@ DescriptorSetUpdateBuilder& DescriptorSetUpdateBuilder::add(VkDescriptorType typ
this->entries.push_back({
.sType = VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET,
.dstSet = this->descriptorSet->handle(),
.dstBinding = static_cast<uint32_t>(this->entries.size()),
.dstBinding = static_cast<uint32_t>(this->samplerIdx++),
.descriptorCount = 1,
.descriptorType = type,
.pImageInfo = new VkDescriptorImageInfo {
@ -89,7 +92,7 @@ DescriptorSetUpdateBuilder& DescriptorSetUpdateBuilder::add(VkDescriptorType typ
this->entries.push_back({
.sType = VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET,
.dstSet = this->descriptorSet->handle(),
.dstBinding = static_cast<uint32_t>(this->entries.size()),
.dstBinding = static_cast<uint32_t>(this->samplerIdx++),
.descriptorCount = 1,
.descriptorType = type,
.pImageInfo = nullptr,
@ -102,16 +105,34 @@ DescriptorSetUpdateBuilder& DescriptorSetUpdateBuilder::add(VkDescriptorType typ
}
DescriptorSetUpdateBuilder& DescriptorSetUpdateBuilder::add(VkDescriptorType type) {
size_t* idx{};
switch (type) {
case VK_DESCRIPTOR_TYPE_SAMPLED_IMAGE:
idx = &this->inputIdx;
break;
case VK_DESCRIPTOR_TYPE_STORAGE_IMAGE:
idx = &this->outputIdx;
break;
case VK_DESCRIPTOR_TYPE_SAMPLER:
idx = &this->samplerIdx;
break;
case VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER:
idx = &this->bufferIdx;
break;
default:
throw LSFG::vulkan_error(VK_ERROR_UNKNOWN, "Unsupported descriptor type");
}
this->entries.push_back({
.sType = VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET,
.dstSet = this->descriptorSet->handle(),
.dstBinding = static_cast<uint32_t>(this->entries.size()),
.dstBinding = static_cast<uint32_t>(*idx),
.descriptorCount = 1,
.descriptorType = type,
.pImageInfo = new VkDescriptorImageInfo {
},
.pBufferInfo = nullptr
});
(*idx)++;
return *this;
}

View file

@ -29,16 +29,40 @@ ShaderModule::ShaderModule(const Core::Device& device, const std::vector<uint8_t
// create descriptor set layout
std::vector<VkDescriptorSetLayoutBinding> layoutBindings;
size_t bindIdx = 0;
size_t bufferIdx{0};
size_t samplerIdx{16};
size_t inputIdx{32};
size_t outputIdx{48};
for (const auto &[count, type] : descriptorTypes)
for (size_t i = 0; i < count; i++, bindIdx++)
for (size_t i = 0; i < count; i++) {
size_t* bindIdx{};
switch (type) {
case VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER:
bindIdx = &bufferIdx;
break;
case VK_DESCRIPTOR_TYPE_SAMPLER:
bindIdx = &samplerIdx;
break;
case VK_DESCRIPTOR_TYPE_SAMPLED_IMAGE:
bindIdx = &inputIdx;
break;
case VK_DESCRIPTOR_TYPE_STORAGE_IMAGE:
bindIdx = &outputIdx;
break;
default:
throw LSFG::vulkan_error(VK_ERROR_UNKNOWN, "Unsupported descriptor type");
}
layoutBindings.emplace_back(VkDescriptorSetLayoutBinding {
.binding = static_cast<uint32_t>(bindIdx),
.binding = static_cast<uint32_t>(*bindIdx),
.descriptorType = type,
.descriptorCount = 1,
.stageFlags = VK_SHADER_STAGE_COMPUTE_BIT
});
(*bindIdx)++;
}
const VkDescriptorSetLayoutCreateInfo layoutDesc{
.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_SET_LAYOUT_CREATE_INFO,
.bindingCount = static_cast<uint32_t>(layoutBindings.size()),