diff --git a/lsfg-vk-backend/src/modules/pipeline/signature.hpp b/lsfg-vk-backend/src/modules/pipeline/signature.hpp index 544ecf8..fa95fff 100644 --- a/lsfg-vk-backend/src/modules/pipeline/signature.hpp +++ b/lsfg-vk-backend/src/modules/pipeline/signature.hpp @@ -187,6 +187,92 @@ namespace lsfgvk::pipeline { } } + // Calculate pipeline stages by reordering passes with dependencies as constraints + std::vector writtenImages; + for (size_t i = 0; i < this->m_images.size(); i++) { + const auto& image{this->m_images.at(i)}; + if (image.flags & ImageFlag::ExternalInput) + writtenImages.push_back(i); + } + + std::vector remainingPasses(this->m_passes.size()); + std::iota(remainingPasses.begin(), remainingPasses.end(), 0); + + size_t currentStageIndex{0}; + std::pair currentStageBounds{ + 0, + this->m_splitIndices.empty() ? this->m_passes.size() : this->m_splitIndices.front() + }; + + while (!remainingPasses.empty()) { + auto& currentStage{s.stages.emplace_back()}; + + // Find all passes that may be executed next + std::vector validPasses{}; + for (const auto& passIdx : remainingPasses) { + if (passIdx < currentStageBounds.first || passIdx >= currentStageBounds.second) + continue; // Skip passes that are not in the current stage + + const auto& pass{this->m_passes.at(passIdx)}; + + bool isValid{true}; + for (const auto& image : pass.inputs) { + if (!image.idx()) + continue; + if (std::ranges::find(writtenImages, *image.idx()) != writtenImages.end()) + continue; + + isValid = false; + break; + } + + if (!isValid) + continue; + + validPasses.push_back(passIdx); + } + + // If no valid pass exists in the current stage, move on to the next stage + if (validPasses.empty() && currentStageIndex < this->m_splitIndices.size()) { + currentStageIndex++; + currentStageBounds = { + currentStageBounds.second, + currentStageIndex < this->m_splitIndices.size() ? + this->m_splitIndices.at(currentStageIndex) : this->m_passes.size() + }; + + s.stages.pop_back(); + s.splitIndices.emplace_back(s.stages.size()); + continue; + } + + // Sort valid passes by shader name + auto begin = std::ranges::begin(validPasses); + auto end = std::ranges::end(validPasses); + for (auto i = begin; i != end; i++) { + std::rotate( + std::upper_bound(begin, i, *i, [this](size_t a, size_t b) { + return this->m_passes.at(a).shader < this->m_passes.at(b).shader; + }), + i, std::next(i) + ); + } + + // Merge passes into execution step + for (const auto& passIdx : validPasses) { + const auto& pass{this->m_passes.at(passIdx)}; + + for (const auto& resource : pass.outputs) { + if (!resource.idx()) + continue; + writtenImages.push_back(*resource.idx()); + } + + currentStage.passes.push_back(passIdx); + remainingPasses.erase(std::ranges::find(remainingPasses, passIdx)); + } + } + // Copy remaining resources into signature for (const auto& shader : shaderInfos) s.shaders.emplace_back(shader.id, shader.hasHdrVariant);