Skip to content

Commit

Permalink
[SPIR-V] Implement SpirvType and SpirvOpaqueType (#6156)
Browse files Browse the repository at this point in the history
Implements hlsl-specs proposal 0011, adding `vk::SpirvType` and
`vk::SpirvOpaqueType` templates which allow users to define and use
SPIR-V level types.
  • Loading branch information
cassiebeckley authored Apr 15, 2024
1 parent dc84d72 commit d60dffe
Show file tree
Hide file tree
Showing 23 changed files with 533 additions and 54 deletions.
3 changes: 2 additions & 1 deletion tools/clang/include/clang/AST/HlslBuiltinTypeDeclBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ class BuiltinTypeDeclBuilder final {

clang::TemplateTypeParmDecl *
addTypeTemplateParam(llvm::StringRef name,
clang::TypeSourceInfo *defaultValue = nullptr);
clang::TypeSourceInfo *defaultValue = nullptr,
bool parameterPack = false);
clang::TemplateTypeParmDecl *
addTypeTemplateParam(llvm::StringRef name, clang::QualType defaultValue);
clang::NonTypeTemplateParmDecl *
Expand Down
10 changes: 10 additions & 0 deletions tools/clang/include/clang/AST/HlslTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,16 @@ DeclareNodeOrRecordType(clang::ASTContext &Ctx, DXIL::NodeIOKind Type,
bool HasGetMethods = false, bool IsArray = false,
bool IsCompleteType = false);

#ifdef ENABLE_SPIRV_CODEGEN
clang::CXXRecordDecl *DeclareInlineSpirvType(clang::ASTContext &context,
clang::DeclContext *declContext,
llvm::StringRef typeName,
bool opaque);
clang::CXXRecordDecl *DeclareVkIntegralConstant(
clang::ASTContext &context, clang::DeclContext *declContext,
llvm::StringRef typeName, clang::ClassTemplateDecl **templateDecl);
#endif

clang::CXXRecordDecl *DeclareNodeOutputArray(clang::ASTContext &Ctx,
DXIL::NodeIOKind Type,
clang::CXXRecordDecl *OutputType,
Expand Down
12 changes: 8 additions & 4 deletions tools/clang/include/clang/SPIRV/SpirvContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -286,9 +286,12 @@ class SpirvContext {

const RayQueryTypeKHR *getRayQueryTypeKHR() const { return rayQueryTypeKHR; }

const SpirvIntrinsicType *
getSpirvIntrinsicType(unsigned typeId, unsigned typeOpCode,
llvm::ArrayRef<SpvIntrinsicTypeOperand> operands);
const SpirvIntrinsicType *getOrCreateSpirvIntrinsicType(
unsigned typeId, unsigned typeOpCode,
llvm::ArrayRef<SpvIntrinsicTypeOperand> operands);

const SpirvIntrinsicType *getOrCreateSpirvIntrinsicType(
unsigned typeOpCode, llvm::ArrayRef<SpvIntrinsicTypeOperand> operands);

SpirvIntrinsicType *getCreatedSpirvIntrinsicType(unsigned typeId);

Expand Down Expand Up @@ -471,7 +474,8 @@ class SpirvContext {
llvm::DenseMap<const SpirvType *, SCToPtrTyMap> pointerTypes;
llvm::SmallVector<const HybridPointerType *, 8> hybridPointerTypes;
llvm::DenseSet<FunctionType *, FunctionTypeMapInfo> functionTypes;
llvm::DenseMap<unsigned, SpirvIntrinsicType *> spirvIntrinsicTypes;
llvm::DenseMap<unsigned, SpirvIntrinsicType *> spirvIntrinsicTypesById;
llvm::SmallVector<const SpirvIntrinsicType *, 8> spirvIntrinsicTypes;
const AccelerationStructureTypeNV *accelerationStructureTypeNV;
const RayQueryTypeKHR *rayQueryTypeKHR;

Expand Down
2 changes: 2 additions & 0 deletions tools/clang/include/clang/SPIRV/SpirvInstruction.h
Original file line number Diff line number Diff line change
Expand Up @@ -1193,6 +1193,8 @@ class SpirvConstant : public SpirvInstruction {
inst->getKind() <= IK_ConstantNull;
}

bool operator==(const SpirvConstant &that) const;

bool isSpecConstant() const;
void setLiteral(bool literal = true) { literalConstant = literal; }
bool isLiteral() { return literalConstant; }
Expand Down
13 changes: 10 additions & 3 deletions tools/clang/include/clang/SPIRV/SpirvType.h
Original file line number Diff line number Diff line change
Expand Up @@ -429,15 +429,16 @@ class RayQueryTypeKHR : public SpirvType {

class SpirvInstruction;
struct SpvIntrinsicTypeOperand {
SpvIntrinsicTypeOperand(SpirvType *type_operand)
SpvIntrinsicTypeOperand(const SpirvType *type_operand)
: operand_as_type(type_operand), isTypeOperand(true) {}
SpvIntrinsicTypeOperand(SpirvInstruction *inst_operand)
: operand_as_inst(inst_operand), isTypeOperand(false) {}
bool operator==(const SpvIntrinsicTypeOperand &that) const;
union {
SpirvType *operand_as_type;
const SpirvType *operand_as_type;
SpirvInstruction *operand_as_inst;
};
bool isTypeOperand;
const bool isTypeOperand;
};

class SpirvIntrinsicType : public SpirvType {
Expand All @@ -453,6 +454,12 @@ class SpirvIntrinsicType : public SpirvType {
return operands;
}

bool operator==(const SpirvIntrinsicType &that) const {
return typeOpCode == that.typeOpCode &&
operands.size() == that.operands.size() &&
std::equal(operands.begin(), operands.end(), that.operands.begin());
}

private:
unsigned typeOpCode;
llvm::SmallVector<SpvIntrinsicTypeOperand, 3> operands;
Expand Down
37 changes: 37 additions & 0 deletions tools/clang/lib/AST/ASTContextHLSL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1241,6 +1241,43 @@ CXXRecordDecl *hlsl::DeclareNodeOrRecordType(
return Builder.getRecordDecl();
}

#ifdef ENABLE_SPIRV_CODEGEN
CXXRecordDecl *hlsl::DeclareInlineSpirvType(clang::ASTContext &context,
clang::DeclContext *declContext,
llvm::StringRef typeName,
bool opaque) {
// template<uint opcode, int size, int alignment> vk::SpirvType { ... }
// template<uint opcode> vk::SpirvOpaqueType { ... }
BuiltinTypeDeclBuilder typeDeclBuilder(declContext, typeName,
clang::TagTypeKind::TTK_Class);
typeDeclBuilder.addIntegerTemplateParam("opcode", context.UnsignedIntTy);
if (!opaque) {
typeDeclBuilder.addIntegerTemplateParam("size", context.UnsignedIntTy);
typeDeclBuilder.addIntegerTemplateParam("alignment", context.UnsignedIntTy);
}
typeDeclBuilder.addTypeTemplateParam("operands", nullptr, true);
typeDeclBuilder.startDefinition();
typeDeclBuilder.addField(
"h", context.UnsignedIntTy); // Add an 'h' field to hold the handle.
return typeDeclBuilder.getRecordDecl();
}

CXXRecordDecl *hlsl::DeclareVkIntegralConstant(
clang::ASTContext &context, clang::DeclContext *declContext,
llvm::StringRef typeName, ClassTemplateDecl **templateDecl) {
// template<typename T, T v> vk::integral_constant { ... }
BuiltinTypeDeclBuilder typeDeclBuilder(declContext, typeName,
clang::TagTypeKind::TTK_Class);
typeDeclBuilder.addTypeTemplateParam("T");
typeDeclBuilder.addIntegerTemplateParam("v", context.UnsignedIntTy);
typeDeclBuilder.startDefinition();
typeDeclBuilder.addField(
"h", context.UnsignedIntTy); // Add an 'h' field to hold the handle.
*templateDecl = typeDeclBuilder.getTemplateDecl();
return typeDeclBuilder.getRecordDecl();
}
#endif

CXXRecordDecl *hlsl::DeclareNodeOutputArray(clang::ASTContext &Ctx,
DXIL::NodeIOKind Type,
CXXRecordDecl *OutputType,
Expand Down
7 changes: 3 additions & 4 deletions tools/clang/lib/AST/HlslBuiltinTypeDeclBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,8 @@ BuiltinTypeDeclBuilder::BuiltinTypeDeclBuilder(DeclContext *declContext,
m_recordDecl->setImplicit(true);
}

TemplateTypeParmDecl *
BuiltinTypeDeclBuilder::addTypeTemplateParam(StringRef name,
TypeSourceInfo *defaultValue) {
TemplateTypeParmDecl *BuiltinTypeDeclBuilder::addTypeTemplateParam(
StringRef name, TypeSourceInfo *defaultValue, bool parameterPack) {
DXASSERT_NOMSG(!m_recordDecl->isBeingDefined() &&
!m_recordDecl->isCompleteDefinition());

Expand All @@ -45,7 +44,7 @@ BuiltinTypeDeclBuilder::addTypeTemplateParam(StringRef name,
astContext, m_recordDecl->getDeclContext(), NoLoc, NoLoc,
/* TemplateDepth */ 0, index,
&astContext.Idents.get(name, tok::TokenKind::identifier),
/* Typename */ false, /* ParameterPack */ false);
/* Typename */ false, parameterPack);
if (defaultValue != nullptr)
decl->setDefaultArgument(defaultValue);
m_templateParams.emplace_back(decl);
Expand Down
16 changes: 16 additions & 0 deletions tools/clang/lib/SPIRV/AlignmentSizeCalculator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

#include "AlignmentSizeCalculator.h"
#include "clang/AST/Attr.h"
#include "clang/AST/DeclTemplate.h"

namespace {

Expand Down Expand Up @@ -264,6 +265,21 @@ std::pair<uint32_t, uint32_t> AlignmentSizeCalculator::getAlignmentAndSize(
return getAlignmentAndSize(desugaredType, rule, isRowMajor, stride);
}

const auto *recordType = type->getAs<RecordType>();
if (recordType != nullptr) {
const llvm::StringRef name = recordType->getDecl()->getName();

if (isTypeInVkNamespace(recordType) && name == "SpirvType") {
const ClassTemplateSpecializationDecl *templateDecl =
cast<ClassTemplateSpecializationDecl>(recordType->getDecl());
const uint64_t size =
templateDecl->getTemplateArgs()[1].getAsIntegral().getZExtValue();
const uint64_t alignment =
templateDecl->getTemplateArgs()[2].getAsIntegral().getZExtValue();
return {alignment, size};
}
}

if (isEnumType(type))
type = astContext.IntTy;

Expand Down
12 changes: 6 additions & 6 deletions tools/clang/lib/SPIRV/ConstEvaluator.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,12 @@ class ConstEvaluator {
SpirvConstant *translateAPFloat(llvm::APFloat floatValue, QualType targetType,
bool isSpecConstantMode);

/// Translates the given frontend APValue into its SPIR-V equivalent for the
/// given targetType.
SpirvConstant *translateAPValue(const APValue &value,
const QualType targetType,
bool isSpecConstantMode);

/// Tries to evaluate the given APInt as a 32-bit integer. If the evaluation
/// can be performed without loss, it returns the <result-id> of the SPIR-V
/// constant for that value.
Expand All @@ -52,12 +58,6 @@ class ConstEvaluator {
bool isSpecConstantMode);

private:
/// Translates the given frontend APValue into its SPIR-V equivalent for the
/// given targetType.
SpirvConstant *translateAPValue(const APValue &value,
const QualType targetType,
bool isSpecConstantMode);

/// Emits error to the diagnostic engine associated with the AST context.
template <unsigned N>
DiagnosticBuilder emitError(const char (&message)[N],
Expand Down
6 changes: 5 additions & 1 deletion tools/clang/lib/SPIRV/EmitVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2577,7 +2577,11 @@ uint32_t EmitTypeHandler::emitType(const SpirvType *type) {
for (const SpvIntrinsicTypeOperand &operand :
spvIntrinsicType->getOperands()) {
if (operand.isTypeOperand) {
curTypeInst.push_back(emitType(operand.operand_as_type));
// calling emitType recursively will potentially replace the contents of
// curTypeInst, so we need to save them and restore after the call
std::vector<uint32_t> outerTypeInst = curTypeInst;
outerTypeInst.push_back(emitType(operand.operand_as_type));
curTypeInst = outerTypeInst;
} else {
auto *literal = dyn_cast<SpirvConstant>(operand.operand_as_inst);
if (literal && literal->isLiteral()) {
Expand Down
3 changes: 2 additions & 1 deletion tools/clang/lib/SPIRV/InitListHandler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -417,7 +417,8 @@ InitListHandler::createInitForStructType(QualType type, SourceLocation srcLoc,
assert(recordType);

LowerTypeVisitor lowerTypeVisitor(astContext, theEmitter.getSpirvContext(),
theEmitter.getSpirvOptions());
theEmitter.getSpirvOptions(),
theEmitter.getSpirvBuilder());
const SpirvType *spirvType =
lowerTypeVisitor.lowerType(type, SpirvLayoutRule::Void, false, srcLoc);

Expand Down
Loading

0 comments on commit d60dffe

Please sign in to comment.