mirror of
https://github.com/PancakeTAS/lsfg-vk.git
synced 2025-10-30 07:01:10 +00:00
173 lines
6.6 KiB
C++
173 lines
6.6 KiB
C++
#include "shaderchains/alpha.hpp"
|
|
#include "utils.hpp"
|
|
|
|
using namespace LSFG::Shaderchains;
|
|
|
|
/// Builds the alpha shader chain for a given input image.
///
/// Loads the four alpha compute shaders, creates one pipeline per shader,
/// allocates the intermediate and output images, wires every descriptor set
/// and finally clears the output images.
///
/// @param device Device used to create all Vulkan-backed objects.
/// @param pool   Descriptor pool the descriptor sets are allocated from.
/// @param inImg  Input image sampled by the first pass (moved into the chain).
Alpha::Alpha(const Core::Device& device, const Core::DescriptorPool& pool,
        Core::Image inImg)
        : inImg(std::move(inImg)) {
    // One shader module per pass. Each pair is { count, descriptor type }
    // (presumably the number of bindings of that type — confirm against
    // Core::ShaderModule).
    this->shaderModules = {{
        Core::ShaderModule(device, "rsc/shaders/alpha/0.spv",
            { { 1, VK_DESCRIPTOR_TYPE_SAMPLER },
              { 1, VK_DESCRIPTOR_TYPE_SAMPLED_IMAGE },
              { 2, VK_DESCRIPTOR_TYPE_STORAGE_IMAGE } }),
        Core::ShaderModule(device, "rsc/shaders/alpha/1.spv",
            { { 1, VK_DESCRIPTOR_TYPE_SAMPLER },
              { 2, VK_DESCRIPTOR_TYPE_SAMPLED_IMAGE },
              { 2, VK_DESCRIPTOR_TYPE_STORAGE_IMAGE } }),
        Core::ShaderModule(device, "rsc/shaders/alpha/2.spv",
            { { 1, VK_DESCRIPTOR_TYPE_SAMPLER },
              { 2, VK_DESCRIPTOR_TYPE_SAMPLED_IMAGE },
              { 4, VK_DESCRIPTOR_TYPE_STORAGE_IMAGE } }),
        Core::ShaderModule(device, "rsc/shaders/alpha/3.spv",
            { { 1, VK_DESCRIPTOR_TYPE_SAMPLER },
              { 4, VK_DESCRIPTOR_TYPE_SAMPLED_IMAGE },
              { 4, VK_DESCRIPTOR_TYPE_STORAGE_IMAGE } })
    }};

    // Every pass gets a pipeline; only the first three get a regular
    // descriptor set.
    for (size_t i = 0; i < 4; i++) {
        this->pipelines.at(i) = Core::Pipeline(device,
            this->shaderModules.at(i));
        if (i == 3) continue; // last shader is special
        this->descriptorSets.at(i) = Core::DescriptorSet(device, pool,
            this->shaderModules.at(i));
    }
    // The last shader writes to one of three output image sets, so it gets
    // three descriptor sets that all share its layout.
    for (size_t i = 0; i < 3; i++)
        this->specialDescriptorSets.at(i) = Core::DescriptorSet(device, pool,
            this->shaderModules.at(3));

    const auto extent = this->inImg.getExtent();

    // Intermediate images are only written as storage images and read back
    // as sampled images.
    const VkImageUsageFlags tempUsage =
        VK_IMAGE_USAGE_STORAGE_BIT | VK_IMAGE_USAGE_SAMPLED_BIT;
    // Output images are additionally cleared below. The original code OR'ed in
    // VK_IMAGE_LAYOUT_TRANSFER_DST_OPTIMAL, which is a VkImageLayout (value 7)
    // rather than a usage bit; the intended flag is the transfer-destination
    // usage bit (presumably required by Utils::clearImage — confirm).
    const VkImageUsageFlags outUsage =
        tempUsage | VK_IMAGE_USAGE_TRANSFER_DST_BIT;

    // Half resolution, rounded up.
    const VkExtent2D halfExtent = {
        .width = (extent.width + 1) >> 1,
        .height = (extent.height + 1) >> 1
    };
    for (size_t i = 0; i < 2; i++) {
        this->tempImgs1.at(i) = Core::Image(device,
            halfExtent,
            VK_FORMAT_R8G8B8A8_UNORM,
            tempUsage,
            VK_IMAGE_ASPECT_COLOR_BIT);
        this->tempImgs2.at(i) = Core::Image(device,
            halfExtent,
            VK_FORMAT_R8G8B8A8_UNORM,
            tempUsage,
            VK_IMAGE_ASPECT_COLOR_BIT);
    }

    // Quarter resolution, rounded up.
    const VkExtent2D quarterExtent = {
        .width = (extent.width + 3) >> 2,
        .height = (extent.height + 3) >> 2
    };
    for (size_t i = 0; i < 4; i++) {
        this->tempImgs3.at(i) = Core::Image(device,
            quarterExtent,
            VK_FORMAT_R8G8B8A8_UNORM,
            tempUsage,
            VK_IMAGE_ASPECT_COLOR_BIT);
        this->outImgs_0.at(i) = Core::Image(device,
            quarterExtent,
            VK_FORMAT_R8G8B8A8_UNORM,
            outUsage,
            VK_IMAGE_ASPECT_COLOR_BIT);
        this->outImgs_1.at(i) = Core::Image(device,
            quarterExtent,
            VK_FORMAT_R8G8B8A8_UNORM,
            outUsage,
            VK_IMAGE_ASPECT_COLOR_BIT);
        this->outImgs_2.at(i) = Core::Image(device,
            quarterExtent,
            VK_FORMAT_R8G8B8A8_UNORM,
            outUsage,
            VK_IMAGE_ASPECT_COLOR_BIT);
    }

    // Pass 0: inImg -> tempImgs1
    this->descriptorSets.at(0).update(device)
        .add(VK_DESCRIPTOR_TYPE_SAMPLER, Globals::samplerClampBorder)
        .add(VK_DESCRIPTOR_TYPE_SAMPLED_IMAGE, this->inImg)
        .add(VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, this->tempImgs1)
        .build();
    // Pass 1: tempImgs1 -> tempImgs2
    this->descriptorSets.at(1).update(device)
        .add(VK_DESCRIPTOR_TYPE_SAMPLER, Globals::samplerClampBorder)
        .add(VK_DESCRIPTOR_TYPE_SAMPLED_IMAGE, this->tempImgs1)
        .add(VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, this->tempImgs2)
        .build();
    // Pass 2: tempImgs2 -> tempImgs3
    this->descriptorSets.at(2).update(device)
        .add(VK_DESCRIPTOR_TYPE_SAMPLER, Globals::samplerClampBorder)
        .add(VK_DESCRIPTOR_TYPE_SAMPLED_IMAGE, this->tempImgs2)
        .add(VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, this->tempImgs3)
        .build();
    // Pass 3: tempImgs3 -> outImgs_{0,1,2}, one descriptor set per output set.
    for (size_t fc = 0; fc < 3; fc++) {
        auto* outImgs = &this->outImgs_0;
        if (fc == 1) outImgs = &this->outImgs_1;
        else if (fc == 2) outImgs = &this->outImgs_2;
        this->specialDescriptorSets.at(fc).update(device)
            .add(VK_DESCRIPTOR_TYPE_SAMPLER, Globals::samplerClampBorder)
            .add(VK_DESCRIPTOR_TYPE_SAMPLED_IMAGE, this->tempImgs3)
            .add(VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, *outImgs)
            .build();
    }

    // clear the output images (so they're not undefined)
    for (size_t i = 0; i < 4; i++) {
        Utils::clearImage(device, this->outImgs_0.at(i));
        Utils::clearImage(device, this->outImgs_1.at(i));
        Utils::clearImage(device, this->outImgs_2.at(i));
    }
}
|
|
|
|
/// Records the four alpha compute passes into a command buffer.
///
/// Passes 0 and 1 are dispatched over the extent of tempImgs1, passes 2 and 3
/// over the extent of tempImgs3. The last pass writes to outImgs_0, outImgs_1
/// or outImgs_2, selected by `fc % 3`.
///
/// @param buf Command buffer the barriers, binds and dispatches are recorded to.
/// @param fc  Frame counter; only its value modulo 3 is used.
void Alpha::Dispatch(const Core::CommandBuffer& buf, uint64_t fc) {
    const auto halfSize = this->tempImgs1.at(0).getExtent();
    const auto quarterSize = this->tempImgs3.at(0).getExtent();

    // Group counts are ceil(pixels / 8) per axis (presumably matching an
    // 8x8 local size in the shaders — confirm).
    const uint32_t halfGroupsX = (halfSize.width + 7) >> 3;
    const uint32_t halfGroupsY = (halfSize.height + 7) >> 3;
    const uint32_t quarterGroupsX = (quarterSize.width + 7) >> 3;
    const uint32_t quarterGroupsY = (quarterSize.height + 7) >> 3;

    // pass 0: inImg -> tempImgs1
    {
        Utils::BarrierBuilder(buf)
            .addW2R(this->inImg)
            .addR2W(this->tempImgs1)
            .build();

        auto& pipeline = this->pipelines.at(0);
        pipeline.bind(buf);
        this->descriptorSets.at(0).bind(buf, pipeline);
        buf.dispatch(halfGroupsX, halfGroupsY, 1);
    }

    // pass 1: tempImgs1 -> tempImgs2
    {
        Utils::BarrierBuilder(buf)
            .addW2R(this->tempImgs1)
            .addR2W(this->tempImgs2)
            .build();

        auto& pipeline = this->pipelines.at(1);
        pipeline.bind(buf);
        this->descriptorSets.at(1).bind(buf, pipeline);
        buf.dispatch(halfGroupsX, halfGroupsY, 1);
    }

    // pass 2: tempImgs2 -> tempImgs3
    {
        Utils::BarrierBuilder(buf)
            .addW2R(this->tempImgs2)
            .addR2W(this->tempImgs3)
            .build();

        auto& pipeline = this->pipelines.at(2);
        pipeline.bind(buf);
        this->descriptorSets.at(2).bind(buf, pipeline);
        buf.dispatch(quarterGroupsX, quarterGroupsY, 1);
    }

    // pass 3: tempImgs3 -> output set chosen by the frame counter
    {
        const uint64_t slot = fc % 3;

        auto* target = &this->outImgs_0;
        switch (slot) {
            case 1: target = &this->outImgs_1; break;
            case 2: target = &this->outImgs_2; break;
            default: break;
        }

        Utils::BarrierBuilder(buf)
            .addW2R(this->tempImgs3)
            .addR2W(*target)
            .build();

        auto& pipeline = this->pipelines.at(3);
        pipeline.bind(buf);
        this->specialDescriptorSets.at(slot).bind(buf, pipeline);
        buf.dispatch(quarterGroupsX, quarterGroupsY, 1);
    }
}
|