diff --git a/include/core/commandbuffer.hpp b/include/core/commandbuffer.hpp index 4c73ff6..b01cf16 100644 --- a/include/core/commandbuffer.hpp +++ b/include/core/commandbuffer.hpp @@ -54,6 +54,17 @@ namespace Vulkan::Core { /// void begin(); + /// + /// Dispatch a compute command. + /// + /// @param x Number of groups in the X dimension + /// @param y Number of groups in the Y dimension + /// @param z Number of groups in the Z dimension + /// + /// @throws std::logic_error if the command buffer is not in Recording state + /// + void dispatch(uint32_t x, uint32_t y, uint32_t z); + /// /// End recording commands in the command buffer. /// @@ -85,7 +96,7 @@ namespace Vulkan::Core { /// Get the state of the command buffer. [[nodiscard]] CommandBufferState getState() const { return *this->state; } /// Get the Vulkan handle. - [[nodiscard]] auto handle() const { *this->commandBuffer; } + [[nodiscard]] auto handle() const { return *this->commandBuffer; } /// Check whether the object is valid. [[nodiscard]] bool isValid() const { return static_cast(this->commandBuffer); } diff --git a/include/core/commandpool.hpp b/include/core/commandpool.hpp index b2d0909..3936983 100644 --- a/include/core/commandpool.hpp +++ b/include/core/commandpool.hpp @@ -9,12 +9,6 @@ namespace Vulkan::Core { - /// Enumeration for different types of command pools. - enum class CommandPoolType { - /// Used for compute-type command buffers. - Compute - }; - /// /// C++ wrapper class for a Vulkan command pool. /// @@ -26,12 +20,11 @@ namespace Vulkan::Core { /// Create the command pool. /// /// @param device Vulkan device - /// @param type Type of command pool to create. /// /// @throws std::invalid_argument if the device is invalid. /// @throws ls::vulkan_error if object creation fails. /// - CommandPool(const Device& device, CommandPoolType type); + CommandPool(const Device& device); /// Get the Vulkan handle. [[nodiscard]] auto handle() const { return *this->commandPool; } diff --git a/include/core/pipeline.hpp b/include/core/pipeline.hpp index d951359..36c7a65 100644 --- a/include/core/pipeline.hpp +++ b/include/core/pipeline.hpp @@ -1,6 +1,7 @@ #ifndef PIPELINE_HPP #define PIPELINE_HPP +#include "core/commandbuffer.hpp" #include "core/shadermodule.hpp" #include "device.hpp" @@ -28,6 +29,15 @@ namespace Vulkan::Core { /// Pipeline(const Device& device, const ShaderModule& shader); + /// + /// Bind the pipeline to a command buffer. + /// + /// @param commandBuffer Command buffer to bind the pipeline to. + /// + /// @throws std::invalid_argument if the command buffer is invalid. + /// + void bind(const CommandBuffer& commandBuffer) const; + /// Get the Vulkan handle. [[nodiscard]] auto handle() const { return *this->pipeline; } /// Get the pipeline layout. diff --git a/src/core/commandbuffer.cpp b/src/core/commandbuffer.cpp index 3b66840..fa50bab 100644 --- a/src/core/commandbuffer.cpp +++ b/src/core/commandbuffer.cpp @@ -44,6 +44,13 @@ void CommandBuffer::begin() { *this->state = CommandBufferState::Recording; } +void CommandBuffer::dispatch(uint32_t x, uint32_t y, uint32_t z) { + if (*this->state != CommandBufferState::Recording) + throw std::logic_error("Command buffer is not in Recording state"); + + vkCmdDispatch(*this->commandBuffer, x, y, z); +} + void CommandBuffer::end() { if (*this->state != CommandBufferState::Recording) throw std::logic_error("Command buffer is not in Recording state"); diff --git a/src/core/commandpool.cpp b/src/core/commandpool.cpp index 65b7ddf..6061e37 100644 --- a/src/core/commandpool.cpp +++ b/src/core/commandpool.cpp @@ -3,21 +3,14 @@ using namespace Vulkan::Core; -CommandPool::CommandPool(const Device& device, CommandPoolType type) { +CommandPool::CommandPool(const Device& device) { if (!device) throw std::invalid_argument("Invalid Vulkan device"); - uint32_t familyIdx{}; - switch (type) { - case CommandPoolType::Compute: - familyIdx = device.getComputeFamilyIdx(); - break; - } - // create command pool const VkCommandPoolCreateInfo desc = { .sType = VK_STRUCTURE_TYPE_COMMAND_POOL_CREATE_INFO, - .queueFamilyIndex = familyIdx + .queueFamilyIndex = device.getComputeFamilyIdx() }; VkCommandPool commandPoolHandle{}; auto res = vkCreateCommandPool(device.handle(), &desc, nullptr, &commandPoolHandle); diff --git a/src/core/pipeline.cpp b/src/core/pipeline.cpp index b60fb7b..120c6f1 100644 --- a/src/core/pipeline.cpp +++ b/src/core/pipeline.cpp @@ -53,3 +53,10 @@ Pipeline::Pipeline(const Device& device, const ShaderModule& shader) { } ); } + +void Pipeline::bind(const CommandBuffer& commandBuffer) const { + if (!commandBuffer) + throw std::invalid_argument("Invalid command buffer"); + + vkCmdBindPipeline(commandBuffer.handle(), VK_PIPELINE_BIND_POINT_COMPUTE, *this->pipeline); +} diff --git a/src/main.cpp b/src/main.cpp index 5613bb5..69ea2cc 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -17,7 +17,7 @@ int main() { const Device device(instance); // prepare render pass - const Core::CommandPool commandPool(device, Core::CommandPoolType::Compute); + const Core::CommandPool commandPool(device); // prepare shader const Core::ShaderModule computeShader(device, "shaders/downsample.spv",