builder for descriptor set updates

This commit is contained in:
PancakeTAS 2025-06-29 23:31:34 +02:00
parent 80f27ad188
commit b463d71f69
No known key found for this signature in database
3 changed files with 111 additions and 63 deletions

View file

@ -17,6 +17,8 @@
namespace Vulkan::Core {
class DescriptorSetUpdateBuilder;
///
/// C++ wrapper class for a Vulkan descriptor set.
///
@ -36,16 +38,18 @@ namespace Vulkan::Core {
DescriptorSet(const Device& device,
DescriptorPool pool, const ShaderModule& shaderModule);
using ResourcePair = std::pair<VkDescriptorType, std::variant<Image, Sampler, Buffer>>;
using ResourceList = std::variant<
std::pair<VkDescriptorType, const std::vector<Image>&>,
std::pair<VkDescriptorType, const std::vector<Sampler>&>,
std::pair<VkDescriptorType, const std::vector<Buffer>&>
>;
///
/// Update the descriptor set with resources.
///
/// @param device Vulkan device
/// @param resources Resources to update the descriptor set with
///
void update(const Device& device,
const std::vector<std::vector<ResourcePair>>& resources) const;
[[nodiscard]] DescriptorSetUpdateBuilder update(const Device& device) const;
///
/// Bind a descriptor set to a command buffer.
@ -68,6 +72,37 @@ namespace Vulkan::Core {
std::shared_ptr<VkDescriptorSet> descriptorSet;
};
///
/// Builder class for updating a descriptor set.
///
class DescriptorSetUpdateBuilder {
friend class DescriptorSet;
public:
/// Add a resource to the descriptor set update.
DescriptorSetUpdateBuilder& add(VkDescriptorType type, const Image& image);
DescriptorSetUpdateBuilder& add(VkDescriptorType type, const Sampler& sampler);
DescriptorSetUpdateBuilder& add(VkDescriptorType type, const Buffer& buffer);
/// Add a list of resources to the descriptor set update.
DescriptorSetUpdateBuilder& add(VkDescriptorType type, const std::vector<Image>& images) {
for (const auto& image : images) this->add(type, image); return *this; }
DescriptorSetUpdateBuilder& add(VkDescriptorType type, const std::vector<Sampler>& samplers) {
for (const auto& sampler : samplers) this->add(type, sampler); return *this; }
DescriptorSetUpdateBuilder& add(VkDescriptorType type, const std::vector<Buffer>& buffers) {
for (const auto& buffer : buffers) this->add(type, buffer); return *this; }
/// Finish building the descriptor set update.
void build() const;
private:
const DescriptorSet* descriptorSet;
const Device* device;
DescriptorSetUpdateBuilder(const DescriptorSet& descriptorSet, const Device& device)
: descriptorSet(&descriptorSet), device(&device) {}
std::vector<VkWriteDescriptorSet> entries;
};
}
#endif // DESCRIPTORSET_HPP

View file

@ -27,47 +27,8 @@ DescriptorSet::DescriptorSet(const Device& device,
);
}
void DescriptorSet::update(const Device& device,
const std::vector<std::vector<ResourcePair>>& resources) const {
std::vector<VkWriteDescriptorSet> writeDescriptorSets;
uint32_t bindingIndex = 0;
for (const auto& list : resources) {
for (const auto& [type, resource] : list) {
VkWriteDescriptorSet writeDesc{
.sType = VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET,
.dstSet = this->handle(),
.dstBinding = bindingIndex++,
.descriptorCount = 1,
.descriptorType = type,
};
if (std::holds_alternative<Image>(resource)) {
const VkDescriptorImageInfo imageInfo{
.imageView = std::get<Image>(resource).getView(),
.imageLayout = VK_IMAGE_LAYOUT_GENERAL
};
writeDesc.pImageInfo = &imageInfo;
} else if (std::holds_alternative<Sampler>(resource)) {
const VkDescriptorImageInfo imageInfo{
.sampler = std::get<Sampler>(resource).handle()
};
writeDesc.pImageInfo = &imageInfo;
} else if (std::holds_alternative<Buffer>(resource)) {
const auto& buffer = std::get<Buffer>(resource);
const VkDescriptorBufferInfo bufferInfo{
.buffer = buffer.handle(),
.range = buffer.getSize()
};
writeDesc.pBufferInfo = &bufferInfo;
}
writeDescriptorSets.push_back(writeDesc);
}
}
vkUpdateDescriptorSets(device.handle(),
static_cast<uint32_t>(writeDescriptorSets.size()),
writeDescriptorSets.data(), 0, nullptr);
DescriptorSetUpdateBuilder DescriptorSet::update(const Device& device) const {
return { *this, device };
}
void DescriptorSet::bind(const CommandBuffer& commandBuffer, const Pipeline& pipeline) const {
@ -76,3 +37,67 @@ void DescriptorSet::bind(const CommandBuffer& commandBuffer, const Pipeline& pip
VK_PIPELINE_BIND_POINT_COMPUTE, pipeline.getLayout(),
0, 1, &descriptorSetHandle, 0, nullptr);
}
// updater class
DescriptorSetUpdateBuilder& DescriptorSetUpdateBuilder::add(VkDescriptorType type, const Image& image) {
this->entries.push_back({
.sType = VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET,
.dstSet = this->descriptorSet->handle(),
.dstBinding = static_cast<uint32_t>(this->entries.size()),
.descriptorCount = 1,
.descriptorType = type,
.pImageInfo = new VkDescriptorImageInfo {
.imageView = image.getView(),
.imageLayout = VK_IMAGE_LAYOUT_GENERAL
},
.pBufferInfo = nullptr
});
return *this;
}
DescriptorSetUpdateBuilder& DescriptorSetUpdateBuilder::add(VkDescriptorType type, const Sampler& sampler) {
this->entries.push_back({
.sType = VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET,
.dstSet = this->descriptorSet->handle(),
.dstBinding = static_cast<uint32_t>(this->entries.size()),
.descriptorCount = 1,
.descriptorType = type,
.pImageInfo = new VkDescriptorImageInfo {
.sampler = sampler.handle(),
},
.pBufferInfo = nullptr
});
return *this;
}
DescriptorSetUpdateBuilder& DescriptorSetUpdateBuilder::add(VkDescriptorType type, const Buffer& buffer) {
this->entries.push_back({
.sType = VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET,
.dstSet = this->descriptorSet->handle(),
.dstBinding = static_cast<uint32_t>(this->entries.size()),
.descriptorCount = 1,
.descriptorType = type,
.pImageInfo = nullptr,
.pBufferInfo = new VkDescriptorBufferInfo {
.buffer = buffer.handle(),
.range = buffer.getSize()
}
});
return *this;
}
void DescriptorSetUpdateBuilder::build() const {
if (this->entries.empty()) return;
vkUpdateDescriptorSets(this->device->handle(),
static_cast<uint32_t>(this->entries.size()),
this->entries.data(), 0, nullptr);
// NOLINTBEGIN
for (const auto& entry : this->entries) {
delete entry.pImageInfo;
delete entry.pBufferInfo;
}
// NOLINTEND
}

View file

@ -12,7 +12,6 @@
#include "instance.hpp"
#include "utils/memorybarriers.hpp"
#include <algorithm>
#include <array>
#include <cassert>
#include <iostream>
@ -75,23 +74,12 @@ int main() {
// load descriptor set
const Core::DescriptorSet descriptorSet(device, descriptorPool, computeShader);
descriptorSet.update(
device,
{
{{ VK_DESCRIPTOR_TYPE_SAMPLER, sampler }},
{{ VK_DESCRIPTOR_TYPE_SAMPLED_IMAGE, inputImages[0] }},
{
{ VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, outputImages[0] },
{ VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, outputImages[1] },
{ VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, outputImages[2] },
{ VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, outputImages[3] },
{ VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, outputImages[4] },
{ VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, outputImages[5] },
{ VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, outputImages[6] }
},
{{ VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER, buffer }}
}
);
descriptorSet.update(device)
.add(VK_DESCRIPTOR_TYPE_SAMPLER, sampler)
.add(VK_DESCRIPTOR_TYPE_SAMPLED_IMAGE, inputImages)
.add(VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, outputImages)
.add(VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER, buffer)
.build();
// start pass
Core::Fence fence(device);