Skip to content

Commit

Permalink
Simplified optional extensions management
Browse files Browse the repository at this point in the history
  • Loading branch information
corporateshark committed Nov 2, 2024
1 parent 576fb01 commit bf6351c
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 71 deletions.
119 changes: 49 additions & 70 deletions lvk/vulkan/VulkanClasses.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5757,18 +5757,6 @@ uint32_t lvk::VulkanContext::queryDevices(HWDeviceType deviceType, HWDeviceDesc*
return numCompatibleDevices;
}

bool lvk::VulkanContext::isRequestedCustomDeviceExtension(const char* ext) const {
if (!ext)
return false;

for (const char* s : config_.extensionsDevice) {
if (s && strcmp(s, ext) == 0) {
return true;
}
}
return false;
}

void lvk::VulkanContext::addNextPhysicalDeviceProperties(void* properties) {
if (!properties)
return;
Expand All @@ -5789,16 +5777,18 @@ lvk::Result lvk::VulkanContext::initContext(const HWDeviceDesc& desc) {

useStaging_ = !isHostVisibleSingleHeapMemory(vkPhysicalDevice_);

uint32_t count = 0;
VK_ASSERT(vkEnumerateDeviceExtensionProperties(vkPhysicalDevice_, nullptr, &count, nullptr));

std::vector<VkExtensionProperties> allPhysicalDeviceExtensions(count);
VK_ASSERT(vkEnumerateDeviceExtensionProperties(vkPhysicalDevice_, nullptr, &count, allPhysicalDeviceExtensions.data()));
std::vector<VkExtensionProperties> allDeviceExtensions;
getDeviceExtensionProps(vkPhysicalDevice_, allDeviceExtensions);
if (config_.enableValidation) {
for (const char* layer : kDefaultValidationLayers) {
getDeviceExtensionProps(vkPhysicalDevice_, allDeviceExtensions, layer);
}
}

if (hasExtension(VK_KHR_ACCELERATION_STRUCTURE_EXTENSION_NAME, allPhysicalDeviceExtensions)) {
if (hasExtension(VK_KHR_ACCELERATION_STRUCTURE_EXTENSION_NAME, allDeviceExtensions)) {
addNextPhysicalDeviceProperties(&accelerationStructureProperties_);
}
if (hasExtension(VK_KHR_RAY_TRACING_PIPELINE_EXTENSION_NAME, allPhysicalDeviceExtensions) ) {
if (hasExtension(VK_KHR_RAY_TRACING_PIPELINE_EXTENSION_NAME, allDeviceExtensions) ) {
addNextPhysicalDeviceProperties(&rayTracingPipelineProperties_);
}

Expand All @@ -5818,7 +5808,7 @@ lvk::Result lvk::VulkanContext::initContext(const HWDeviceDesc& desc) {
LLOGL("Vulkan physical device extensions:\n");

// log available physical device extensions
for (const auto& ext : allPhysicalDeviceExtensions) {
for (const auto& ext : allDeviceExtensions) {
LLOGL(" %s\n", ext.extensionName);
}

Expand Down Expand Up @@ -5883,46 +5873,12 @@ lvk::Result lvk::VulkanContext::initContext(const HWDeviceDesc& desc) {
#endif
};

std::vector<VkExtensionProperties> allDeviceExtensions;
getDeviceExtensionProps(vkPhysicalDevice_, allDeviceExtensions);
if (config_.enableValidation) {
for (const char* layer : kDefaultValidationLayers) {
getDeviceExtensionProps(vkPhysicalDevice_, allDeviceExtensions, layer);
}
}

for (const char* ext : config_.extensionsDevice) {
if (ext) {
deviceExtensionNames.push_back(ext);
}
}

if (hasExtension(VK_KHR_ACCELERATION_STRUCTURE_EXTENSION_NAME, allPhysicalDeviceExtensions) &&
hasExtension(VK_KHR_DEFERRED_HOST_OPERATIONS_EXTENSION_NAME, allPhysicalDeviceExtensions)) {
hasAccelerationStructure_ = true;
if (!isRequestedCustomDeviceExtension(VK_KHR_ACCELERATION_STRUCTURE_EXTENSION_NAME))
deviceExtensionNames.push_back(VK_KHR_ACCELERATION_STRUCTURE_EXTENSION_NAME);
if (!isRequestedCustomDeviceExtension(VK_KHR_DEFERRED_HOST_OPERATIONS_EXTENSION_NAME))
deviceExtensionNames.push_back(VK_KHR_DEFERRED_HOST_OPERATIONS_EXTENSION_NAME);
}
if (hasExtension(VK_KHR_RAY_QUERY_EXTENSION_NAME, allPhysicalDeviceExtensions)) {
hasRayQuery_ = true;
if (!isRequestedCustomDeviceExtension(VK_KHR_RAY_QUERY_EXTENSION_NAME))
deviceExtensionNames.push_back(VK_KHR_RAY_QUERY_EXTENSION_NAME);
}
if (hasExtension(VK_KHR_RAY_TRACING_PIPELINE_EXTENSION_NAME, allPhysicalDeviceExtensions)) {
hasRayTracingPipeline_ = true;
if (!isRequestedCustomDeviceExtension(VK_KHR_RAY_TRACING_PIPELINE_EXTENSION_NAME))
deviceExtensionNames.push_back(VK_KHR_RAY_TRACING_PIPELINE_EXTENSION_NAME);
}
if (hasExtension(VK_KHR_INDEX_TYPE_UINT8_EXTENSION_NAME, allDeviceExtensions)) {
if (!isRequestedCustomDeviceExtension(VK_KHR_INDEX_TYPE_UINT8_EXTENSION_NAME))
deviceExtensionNames.push_back(VK_KHR_INDEX_TYPE_UINT8_EXTENSION_NAME);
} else if (hasExtension(VK_EXT_INDEX_TYPE_UINT8_EXTENSION_NAME, allDeviceExtensions)) {
if (!isRequestedCustomDeviceExtension(VK_EXT_INDEX_TYPE_UINT8_EXTENSION_NAME))
deviceExtensionNames.push_back(VK_EXT_INDEX_TYPE_UINT8_EXTENSION_NAME);
}

VkPhysicalDeviceFeatures deviceFeatures10 = {
#if !defined(__APPLE__)
.geometryShader = VK_TRUE,
Expand Down Expand Up @@ -6031,22 +5987,45 @@ lvk::Result lvk::VulkanContext::initContext(const HWDeviceDesc& desc) {
.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_INDEX_TYPE_UINT8_FEATURES_KHR,
.indexTypeUint8 = VK_TRUE,
};
if (hasExtension(VK_KHR_ACCELERATION_STRUCTURE_EXTENSION_NAME, allDeviceExtensions)) {
accelerationStructureFeatures.pNext = createInfoNext;
createInfoNext = &accelerationStructureFeatures;
}
if (hasExtension(VK_KHR_RAY_QUERY_EXTENSION_NAME, allDeviceExtensions)) {
rayQueryFeatures.pNext = createInfoNext;
createInfoNext = &rayQueryFeatures;
}
if (hasExtension(VK_KHR_RAY_TRACING_PIPELINE_EXTENSION_NAME, allDeviceExtensions)) {
rayTracingFeatures.pNext = createInfoNext;
createInfoNext = &rayTracingFeatures;
}
if (hasExtension(VK_KHR_INDEX_TYPE_UINT8_EXTENSION_NAME, allDeviceExtensions) ||
hasExtension(VK_EXT_INDEX_TYPE_UINT8_EXTENSION_NAME, allDeviceExtensions)) {
indexTypeUint8Features.pNext = createInfoNext;
createInfoNext = &indexTypeUint8Features;

auto addOptionalExtension = [&allDeviceExtensions, &deviceExtensionNames, &createInfoNext](
const char* name, bool& enabled, void* features = nullptr) mutable -> bool {
if (!hasExtension(name, allDeviceExtensions))
return false;
enabled = true;
deviceExtensionNames.push_back(name);
if (features) {
std::launder(reinterpret_cast<VkBaseOutStructure*>(features))->pNext =
std::launder(reinterpret_cast<VkBaseOutStructure*>(createInfoNext));
createInfoNext = features;
}
return true;
};
auto addOptionalExtensions = [&allDeviceExtensions, &deviceExtensionNames, &createInfoNext](
const char* name1, const char* name2, bool& enabled, void* features = nullptr) mutable {
if (!hasExtension(name1, allDeviceExtensions) || !hasExtension(name2, allDeviceExtensions))
return;
enabled = true;
deviceExtensionNames.push_back(name1);
deviceExtensionNames.push_back(name2);
if (features) {
std::launder(reinterpret_cast<VkBaseOutStructure*>(features))->pNext =
std::launder(reinterpret_cast<VkBaseOutStructure*>(createInfoNext));
createInfoNext = features;
}
};

addOptionalExtensions(VK_KHR_ACCELERATION_STRUCTURE_EXTENSION_NAME,
VK_KHR_DEFERRED_HOST_OPERATIONS_EXTENSION_NAME,
hasAccelerationStructure_,
&accelerationStructureFeatures);
addOptionalExtension(VK_KHR_RAY_QUERY_EXTENSION_NAME, hasRayQuery_, &rayQueryFeatures);
addOptionalExtension(VK_KHR_RAY_TRACING_PIPELINE_EXTENSION_NAME, hasRayTracingPipeline_, &rayTracingFeatures);
#if defined(VK_KHR_INDEX_TYPE_UINT8_EXTENSION_NAME)
if (!addOptionalExtension(VK_KHR_INDEX_TYPE_UINT8_EXTENSION_NAME, has8BitIndices_, &indexTypeUint8Features))
#endif // VK_KHR_INDEX_TYPE_UINT8_EXTENSION_NAME
{
addOptionalExtension(VK_EXT_INDEX_TYPE_UINT8_EXTENSION_NAME, has8BitIndices_, &indexTypeUint8Features);
}

const VkDeviceCreateInfo ci = {
Expand Down
2 changes: 1 addition & 1 deletion lvk/vulkan/VulkanClasses.h
Original file line number Diff line number Diff line change
Expand Up @@ -606,7 +606,6 @@ class VulkanContext final : public IContext {
ShaderModuleState createShaderModuleFromGLSL(ShaderStage stage, const char* source, const char* debugName, Result* outResult) const;
const VkSamplerYcbcrConversionInfo* getOrCreateYcbcrConversionInfo(lvk::Format format);
VkSampler getOrCreateYcbcrSampler(lvk::Format format);
bool isRequestedCustomDeviceExtension(const char* ext) const;
void addNextPhysicalDeviceProperties(void* properties);

private:
Expand Down Expand Up @@ -675,6 +674,7 @@ class VulkanContext final : public IContext {
bool hasAccelerationStructure_ = false;
bool hasRayQuery_ = false;
bool hasRayTracingPipeline_ = false;
bool has8BitIndices_ = false;

TextureHandle dummyTexture_;

Expand Down

0 comments on commit bf6351c

Please sign in to comment.