From f313bced3399b606a910c27fa6a04391685ef834 Mon Sep 17 00:00:00 2001 From: Sergey Kosarevsky Date: Thu, 28 Dec 2023 17:08:54 -0800 Subject: [PATCH] igl | vulkan | Check if SPIR-V shader uses push constants Summary: Added a new member field `SpvModuleInfo::hasPushConstants` to check if a SPIR-V shader uses push constants. Reviewed By: mmaurer Differential Revision: D52446637 fbshipit-source-id: 4b05638c04564cc56f9469176ca8f4d308807d32 --- src/igl/vulkan/PipelineState.cpp | 32 +++++++++++++++++++++++++-- src/igl/vulkan/PipelineState.h | 8 +++---- src/igl/vulkan/VulkanHelpers.c | 2 +- src/igl/vulkan/util/SpvReflection.cpp | 5 +++++ src/igl/vulkan/util/SpvReflection.h | 1 + 5 files changed, 41 insertions(+), 7 deletions(-) diff --git a/src/igl/vulkan/PipelineState.cpp b/src/igl/vulkan/PipelineState.cpp index 811f36c7085..3693bcf7c84 100644 --- a/src/igl/vulkan/PipelineState.cpp +++ b/src/igl/vulkan/PipelineState.cpp @@ -15,14 +15,21 @@ namespace igl::vulkan { -void PipelineState::initializeSpvModuleInfoFromShaderStages(IShaderStages* stages) { +void PipelineState::initializeSpvModuleInfoFromShaderStages(const VulkanContext& ctx, + IShaderStages* stages) { auto* smComp = static_cast(stages->getComputeModule().get()); + VkShaderStageFlags pushConstantMask = 0; + if (smComp) { // compute ensureShaderModule(smComp); info_ = smComp->getVulkanShaderModule().getSpvModuleInfo(); + + if (info_.hasPushConstants) { + pushConstantMask |= VK_SHADER_STAGE_COMPUTE_BIT; + } } else { auto* smVert = static_cast(stages->getVertexModule().get()); auto* smFrag = static_cast(stages->getFragmentModule().get()); @@ -34,8 +41,29 @@ void PipelineState::initializeSpvModuleInfoFromShaderStages(IShaderStages* stage const util::SpvModuleInfo& infoVert = smVert->getVulkanShaderModule().getSpvModuleInfo(); const util::SpvModuleInfo& infoFrag = smFrag->getVulkanShaderModule().getSpvModuleInfo(); + if (infoVert.hasPushConstants) { + pushConstantMask |= VK_SHADER_STAGE_VERTEX_BIT; + } + if (infoFrag.hasPushConstants) { + pushConstantMask |= VK_SHADER_STAGE_FRAGMENT_BIT; + } + info_ = util::mergeReflectionData(infoVert, infoFrag); } + + if (pushConstantMask) { + const VkPhysicalDeviceLimits& limits = ctx.getVkPhysicalDeviceProperties().limits; + + constexpr uint32_t kPushConstantsSize = 128; + + if (!IGL_VERIFY(kPushConstantsSize <= limits.maxPushConstantsSize)) { + IGL_LOG_ERROR("Push constants size exceeded %u (max %u bytes)", + kPushConstantsSize, + limits.maxPushConstantsSize); + } + + pushConstantRange_ = ivkGetPushConstantRange(pushConstantMask, 0, kPushConstantsSize); + } } PipelineState::PipelineState(const VulkanContext& ctx, @@ -43,7 +71,7 @@ PipelineState::PipelineState(const VulkanContext& ctx, const char* debugName) { IGL_ASSERT(stages); - initializeSpvModuleInfoFromShaderStages(stages); + initializeSpvModuleInfoFromShaderStages(ctx, stages); // Create all Vulkan descriptor set layouts for this pipeline diff --git a/src/igl/vulkan/PipelineState.h b/src/igl/vulkan/PipelineState.h index d89548f1cdd..7c62fdf6da7 100644 --- a/src/igl/vulkan/PipelineState.h +++ b/src/igl/vulkan/PipelineState.h @@ -33,13 +33,13 @@ class PipelineState { } private: - void initializeSpvModuleInfoFromShaderStages(IShaderStages* stages); - - protected: - friend class ResourcesBinder; + void initializeSpvModuleInfoFromShaderStages(const VulkanContext& ctx, IShaderStages* stages); + public: igl::vulkan::util::SpvModuleInfo info_; + VkPushConstantRange pushConstantRange_ = {}; + mutable std::unique_ptr pipelineLayout_; // the last seen VkDescriptorSetLayout from VulkanContext::dslBindless_ diff --git a/src/igl/vulkan/VulkanHelpers.c b/src/igl/vulkan/VulkanHelpers.c index d58b3059beb..7e768125d78 100644 --- a/src/igl/vulkan/VulkanHelpers.c +++ b/src/igl/vulkan/VulkanHelpers.c @@ -1170,7 +1170,7 @@ VkPipelineLayoutCreateInfo ivkGetPipelineLayoutCreateInfo(uint32_t numLayouts, .sType = VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO, .setLayoutCount = numLayouts, .pSetLayouts = layouts, - .pushConstantRangeCount = 1, + .pushConstantRangeCount = range ? 1u : 0u, .pPushConstantRanges = range, }; return ci; diff --git a/src/igl/vulkan/util/SpvReflection.cpp b/src/igl/vulkan/util/SpvReflection.cpp index 35e4c9434e7..0c9dbc3697f 100644 --- a/src/igl/vulkan/util/SpvReflection.cpp +++ b/src/igl/vulkan/util/SpvReflection.cpp @@ -228,6 +228,9 @@ SpvModuleInfo getReflectionData(const uint32_t* spirv, size_t numBytes) { break; } } + if (id.opCode == SpvOpVariable && id.storageClass == SpvStorageClassPushConstant) { + info.hasPushConstants = true; + } } return info; @@ -260,6 +263,8 @@ SpvModuleInfo mergeReflectionData(const SpvModuleInfo& info1, const SpvModuleInf combineDescriptions(result.storageBuffers, info1.storageBuffers, info2.storageBuffers); combineDescriptions(result.textures, info1.textures, info2.textures); + result.hasPushConstants = info1.hasPushConstants || info2.hasPushConstants; + return result; } diff --git a/src/igl/vulkan/util/SpvReflection.h b/src/igl/vulkan/util/SpvReflection.h index f89beae1650..96d8d64e28e 100644 --- a/src/igl/vulkan/util/SpvReflection.h +++ b/src/igl/vulkan/util/SpvReflection.h @@ -32,6 +32,7 @@ struct SpvModuleInfo { std::vector uniformBuffers; std::vector storageBuffers; std::vector textures; + bool hasPushConstants = false; }; SpvModuleInfo getReflectionData(const uint32_t* spirv, size_t numBytes);