Skip to content

Commit

Permalink
Automatically limiting push constants to those shader stages that are…
Browse files Browse the repository at this point in the history
… actually used with pipeline creation
  • Loading branch information
johannesugb committed Aug 2, 2024
1 parent 3d2206f commit 749c5d6
Showing 1 changed file with 47 additions and 3 deletions.
50 changes: 47 additions & 3 deletions src/avk.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3040,7 +3040,7 @@ namespace avk
result.mPushConstantRanges.reserve(aConfig.mPushConstantsBindings.size()); // Important! Otherwise the vector might realloc and .data() will become invalid!
for (const auto& pcBinding : aConfig.mPushConstantsBindings) {
result.mPushConstantRanges.push_back(vk::PushConstantRange{}
.setStageFlags(to_vk_shader_stages(pcBinding.mShaderStages))
.setStageFlags(to_vk_shader_stages(pcBinding.mShaderStages & shader_type::compute)) // <-- nothing else makes sense
.setOffset(static_cast<uint32_t>(pcBinding.mOffset))
.setSize(static_cast<uint32_t>(pcBinding.mSize))
);
Expand Down Expand Up @@ -4539,10 +4539,17 @@ namespace avk
.setPrimitiveRestartEnable(VK_FALSE);

// 5. Compile and store the shaders:
std::optional<shader_type> usedShaders; // <-- gather all the types of shaders used in this pipeline
result.mShaders.reserve(aConfig.mShaderInfos.size()); // Important! Otherwise the vector might realloc and .data() will become invalid!
result.mShaderStageCreateInfos.reserve(aConfig.mShaderInfos.size()); // Important! Otherwise the vector might realloc and .data() will become invalid!
result.mSpecializationInfos.reserve(aConfig.mShaderInfos.size()); // Important! Otherwise the vector might realloc and .data() will become invalid!
for (auto& shaderInfo : aConfig.mShaderInfos) {
if (usedShaders.has_value()) {
usedShaders = *usedShaders | shaderInfo.mShaderType;
}
else {
usedShaders = shaderInfo.mShaderType;
}
// 5.0 Sanity check
if (result.mShaders.end() != std::find_if(std::begin(result.mShaders), std::end(result.mShaders), [&shaderInfo](const shader& existing) { return existing.info().mShaderType == shaderInfo.mShaderType; })) {
throw avk::runtime_error("There's already a " + vk::to_string(to_vk_shader_stages(shaderInfo.mShaderType)) + "-type shader contained in this graphics pipeline. Can not add another one of the same type.");
Expand Down Expand Up @@ -4825,7 +4832,7 @@ namespace avk
result.mPushConstantRanges.reserve(aConfig.mPushConstantsBindings.size()); // Important! Otherwise the vector might realloc and .data() will become invalid!
for (const auto& pcBinding : aConfig.mPushConstantsBindings) {
result.mPushConstantRanges.push_back(vk::PushConstantRange{}
.setStageFlags(to_vk_shader_stages(pcBinding.mShaderStages))
.setStageFlags(to_vk_shader_stages(pcBinding.mShaderStages & usedShaders.value_or(pcBinding.mShaderStages))) // <-- limit shader stages to those which are actually there
.setOffset(static_cast<uint32_t>(pcBinding.mOffset))
.setSize(static_cast<uint32_t>(pcBinding.mSize))
);
Expand Down Expand Up @@ -6175,11 +6182,20 @@ namespace avk
vk::DeviceSize byteOffset = 0;
shader_group_info* curEdited = nullptr;

std::optional<shader_type> usedShaders; // <-- gather all the types of shaders used in this pipeline
for (auto& tableEntry : aConfig.mShaderTableConfig.mShaderTableEntries) {
group_type curType = group_type::none;

if (std::holds_alternative<shader_info>(tableEntry)) {
const auto& shaderInfo = std::get<shader_info>(tableEntry);

if (usedShaders.has_value()) {
usedShaders = *usedShaders | shaderInfo.mShaderType;
}
else {
usedShaders = shaderInfo.mShaderType;
}

switch (shaderInfo.mShaderType) {
case shader_type::ray_generation: curType = group_type::raygen; break;
case shader_type::miss: curType = group_type::miss; break;
Expand All @@ -6204,10 +6220,24 @@ namespace avk
const auto& hitGroup = std::get<triangles_hit_group>(tableEntry);
uint32_t rahitShaderIndex = VK_SHADER_UNUSED_KHR;
if (hitGroup.mAnyHitShader.has_value()) {
if (usedShaders.has_value()) {
usedShaders = *usedShaders | hitGroup.mAnyHitShader->mShaderType;
}
else {
usedShaders = hitGroup.mAnyHitShader->mShaderType;
}

rahitShaderIndex = static_cast<uint32_t>(index_of(orderedUniqueShaderInfos, hitGroup.mAnyHitShader.value()));
}
uint32_t rchitShaderIndex = VK_SHADER_UNUSED_KHR;
if (hitGroup.mClosestHitShader.has_value()) {
if (usedShaders.has_value()) {
usedShaders = *usedShaders | hitGroup.mClosestHitShader->mShaderType;
}
else {
usedShaders = hitGroup.mClosestHitShader->mShaderType;
}

rchitShaderIndex = static_cast<uint32_t>(index_of(orderedUniqueShaderInfos, hitGroup.mClosestHitShader.value()));
}
result.mShaderGroupCreateInfos.emplace_back()
Expand All @@ -6224,10 +6254,24 @@ namespace avk
uint32_t rintShaderIndex = static_cast<uint32_t>(index_of(orderedUniqueShaderInfos, hitGroup.mIntersectionShader));
uint32_t rahitShaderIndex = VK_SHADER_UNUSED_KHR;
if (hitGroup.mAnyHitShader.has_value()) {
if (usedShaders.has_value()) {
usedShaders = *usedShaders | hitGroup.mAnyHitShader->mShaderType;
}
else {
usedShaders = hitGroup.mAnyHitShader->mShaderType;
}

rahitShaderIndex = static_cast<uint32_t>(index_of(orderedUniqueShaderInfos, hitGroup.mAnyHitShader.value()));
}
uint32_t rchitShaderIndex = VK_SHADER_UNUSED_KHR;
if (hitGroup.mClosestHitShader.has_value()) {
if (usedShaders.has_value()) {
usedShaders = *usedShaders | hitGroup.mClosestHitShader->mShaderType;
}
else {
usedShaders = hitGroup.mClosestHitShader->mShaderType;
}

rchitShaderIndex = static_cast<uint32_t>(index_of(orderedUniqueShaderInfos, hitGroup.mClosestHitShader.value()));
}
result.mShaderGroupCreateInfos.emplace_back()
Expand Down Expand Up @@ -6294,7 +6338,7 @@ namespace avk
result.mPushConstantRanges.reserve(aConfig.mPushConstantsBindings.size()); // Important! Otherwise the vector might realloc and .data() will become invalid!
for (const auto& pcBinding : aConfig.mPushConstantsBindings) {
result.mPushConstantRanges.push_back(vk::PushConstantRange{}
.setStageFlags(to_vk_shader_stages(pcBinding.mShaderStages))
.setStageFlags(to_vk_shader_stages(pcBinding.mShaderStages & usedShaders.value_or(pcBinding.mShaderStages))) // <-- limit to ray tracing shader stages
.setOffset(static_cast<uint32_t>(pcBinding.mOffset))
.setSize(static_cast<uint32_t>(pcBinding.mSize))
);
Expand Down

0 comments on commit 749c5d6

Please sign in to comment.