Skip to content

Commit

Permalink
[SPIR-V] Implement vk::ext_builtin_input and vk::ext_builtin_output (#…
Browse files Browse the repository at this point in the history
…6027)

I definitely think it would look better if we allowed these attributes
on variables, ie microsoft/hlsl-specs#76. I
haven't fully investigated how involved it would be to implement, but my
intuition is that it wouldn't take that much more work.

Fixes #4217.
  • Loading branch information
cassiebeckley authored Jan 2, 2024
1 parent a743e97 commit 9dacbe3
Show file tree
Hide file tree
Showing 13 changed files with 256 additions and 29 deletions.
16 changes: 16 additions & 0 deletions tools/clang/include/clang/Basic/Attr.td
Original file line number Diff line number Diff line change
Expand Up @@ -1267,6 +1267,22 @@ def VKBuiltIn : InheritableAttr {
let Documentation = [Undocumented];
}

def VKExtBuiltinInput : InheritableAttr {
let Spellings = [CXX11<"vk", "ext_builtin_input">];
let Subjects = SubjectList<[Var], ErrorDiag>;
let Args = [IntArgument<"BuiltInID">];
let LangOpts = [SPIRV];
let Documentation = [Undocumented];
}

def VKExtBuiltinOutput : InheritableAttr {
let Spellings = [CXX11<"vk", "ext_builtin_output">];
let Subjects = SubjectList<[Var], ErrorDiag>;
let Args = [IntArgument<"BuiltInID">];
let LangOpts = [SPIRV];
let Documentation = [Undocumented];
}

def VKLocation : InheritableAttr {
let Spellings = [CXX11<"vk", "location">];
let Subjects = SubjectList<[Function, ParmVar, Field], ErrorDiag>;
Expand Down
81 changes: 53 additions & 28 deletions tools/clang/lib/SPIRV/DeclResultIdMapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -968,6 +968,15 @@ SpirvInstruction *DeclResultIdMapper::getDeclEvalInfo(const ValueDecl *decl,
/* spvArgs */ {}, /* isInst */ false,
loc);
}

if (auto *builtinAttr = decl->getAttr<VKExtBuiltinInputAttr>()) {
return getBuiltinVar(spv::BuiltIn(builtinAttr->getBuiltInID()),
decl->getType(), spv::StorageClass::Input, loc);
} else if (auto *builtinAttr = decl->getAttr<VKExtBuiltinOutputAttr>()) {
return getBuiltinVar(spv::BuiltIn(builtinAttr->getBuiltInID()),
decl->getType(), spv::StorageClass::Output, loc);
}

if (hlsl::IsHLSLDynamicResourceType(decl->getType()) ||
hlsl::IsHLSLDynamicSamplerType(decl->getType())) {
emitError("HLSL object %0 not yet supported with -spirv",
Expand Down Expand Up @@ -3808,23 +3817,60 @@ void DeclResultIdMapper::decorateInterpolationMode(

SpirvVariable *DeclResultIdMapper::getBuiltinVar(spv::BuiltIn builtIn,
QualType type,
spv::StorageClass sc,
SourceLocation loc) {
// Guarantee uniqueness
uint32_t spvBuiltinId = static_cast<uint32_t>(builtIn);
const auto builtInVar = builtinToVarMap.find(spvBuiltinId);
if (builtInVar != builtinToVarMap.end()) {
return builtInVar->second;
}
bool mayNeedFlatDecoration = false;
switch (builtIn) {
case spv::BuiltIn::HelperInvocation:
case spv::BuiltIn::SubgroupSize:
case spv::BuiltIn::SubgroupLocalInvocationId:
needsLegalization = true;
break;
}

// Create a dummy StageVar for this builtin variable
auto var = spvBuilder.addStageBuiltinVar(type, sc, builtIn,
/*isPrecise*/ false, loc);

if (spvContext.isPS() && sc == spv::StorageClass::Input) {
if (isUintOrVecMatOfUintType(type) || isSintOrVecMatOfSintType(type) ||
isBoolOrVecMatOfBoolType(type)) {
spvBuilder.decorateFlat(var, loc);
}
}

const hlsl::SigPoint *sigPoint =
hlsl::SigPoint::GetSigPoint(hlsl::SigPointFromInputQual(
hlsl::DxilParamInputQual::In, spvContext.getCurrentShaderModelKind(),
/*isPatchConstant=*/false));

StageVar stageVar(sigPoint, /*semaInfo=*/{}, /*builtinAttr=*/nullptr, type,
/*locAndComponentCount=*/{0, 0, false});

stageVar.setIsSpirvBuiltin();
stageVar.setSpirvInstr(var);
stageVars.push_back(stageVar);

// Store in map for re-use
builtinToVarMap[spvBuiltinId] = var;
return var;
}

SpirvVariable *DeclResultIdMapper::getBuiltinVar(spv::BuiltIn builtIn,
QualType type,
SourceLocation loc) {
spv::StorageClass sc = spv::StorageClass::Max;

// Valid builtins supported
switch (builtIn) {
case spv::BuiltIn::HelperInvocation:
case spv::BuiltIn::SubgroupSize:
case spv::BuiltIn::SubgroupLocalInvocationId:
needsLegalization = true;
mayNeedFlatDecoration = true;
LLVM_FALLTHROUGH;
case spv::BuiltIn::HitTNV:
case spv::BuiltIn::RayTmaxNV:
case spv::BuiltIn::RayTminNV:
Expand Down Expand Up @@ -3857,32 +3903,11 @@ SpirvVariable *DeclResultIdMapper::getBuiltinVar(spv::BuiltIn builtIn,
sc = spv::StorageClass::Output;
break;
default:
assert(false && "unsupported SPIR-V builtin");
return nullptr;
}

// Create a dummy StageVar for this builtin variable
auto var = spvBuilder.addStageBuiltinVar(type, sc, builtIn,
/*isPrecise*/ false, loc);
if (mayNeedFlatDecoration && spvContext.isPS()) {
spvBuilder.decorateFlat(var, loc);
assert(false && "cannot infer storage class for SPIR-V builtin");
break;
}

const hlsl::SigPoint *sigPoint =
hlsl::SigPoint::GetSigPoint(hlsl::SigPointFromInputQual(
hlsl::DxilParamInputQual::In, spvContext.getCurrentShaderModelKind(),
/*isPatchConstant=*/false));

StageVar stageVar(sigPoint, /*semaInfo=*/{}, /*builtinAttr=*/nullptr, type,
/*locAndComponentCount=*/{0, 0, false});

stageVar.setIsSpirvBuiltin();
stageVar.setSpirvInstr(var);
stageVars.push_back(stageVar);

// Store in map for re-use
builtinToVarMap[spvBuiltinId] = var;
return var;
return getBuiltinVar(builtIn, type, sc, loc);
}

SpirvVariable *DeclResultIdMapper::createSpirvStageVar(
Expand Down
8 changes: 7 additions & 1 deletion tools/clang/lib/SPIRV/DeclResultIdMapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,13 @@ class DeclResultIdMapper {
FeatureManager &features,
const SpirvCodeGenOptions &spirvOptions);

/// \brief Returns the SPIR-V builtin variable.
/// \brief Returns the SPIR-V builtin variable. Uses sc as default storage
/// class.
SpirvVariable *getBuiltinVar(spv::BuiltIn builtIn, QualType type,
spv::StorageClass sc, SourceLocation);

/// \brief Returns the SPIR-V builtin variable. Tries to infer storage class
/// from the builtin.
SpirvVariable *getBuiltinVar(spv::BuiltIn builtIn, QualType type,
SourceLocation);

Expand Down
65 changes: 65 additions & 0 deletions tools/clang/lib/SPIRV/SpirvEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1645,6 +1645,39 @@ bool SpirvEmitter::validateVKAttributes(const NamedDecl *decl) {
}
}

// a VarDecl should have only one of vk::ext_builtin_input or
// vk::ext_builtin_output
if (decl->hasAttr<VKExtBuiltinInputAttr>() &&
decl->hasAttr<VKExtBuiltinOutputAttr>()) {
emitError("vk::ext_builtin_input cannot be used together with "
"vk::ext_builtin_output",
decl->getAttr<VKExtBuiltinOutputAttr>()->getLocation());
success = false;
}

// vk::ext_builtin_input and vk::ext_builtin_output must only be used for a
// static variable. We only allow them to be attached to variables, so it
// should be fine to cast here.
if ((decl->hasAttr<VKExtBuiltinInputAttr>() ||
decl->hasAttr<VKExtBuiltinOutputAttr>()) &&
cast<VarDecl>(decl)->getStorageClass() != StorageClass::SC_Static) {
emitError("vk::ext_builtin_input and vk::ext_builtin_output can only be "
"applied to a static variable",
decl->getLocation());
success = false;
}

// vk::ext_builtin_input and vk::ext_builtin_output must only be used for a
// static variable. We only allow them to be attached to variables, so it
// should be fine to cast here.
if (decl->hasAttr<VKExtBuiltinInputAttr>() &&
!cast<VarDecl>(decl)->getType().isConstQualified()) {
emitError("vk::ext_builtin_input can only be applied to a const-qualified "
"variable",
decl->getLocation());
success = false;
}

return success;
}

Expand Down Expand Up @@ -1799,6 +1832,38 @@ void SpirvEmitter::doVarDecl(const VarDecl *decl) {
return;
}

// Handle vk::ext_builtin_input and vk::ext_builtin_input by using
// getBuiltinVar to create the builtin and validate the storage class
if (decl->hasAttr<VKExtBuiltinInputAttr>()) {
auto *builtinAttr = decl->getAttr<VKExtBuiltinInputAttr>();
int builtinId = builtinAttr->getBuiltInID();
SpirvVariable *builtinVar =
declIdMapper.getBuiltinVar(spv::BuiltIn(builtinId), decl->getType(),
spv::StorageClass::Input, loc);
if (builtinVar->getStorageClass() != spv::StorageClass::Input) {
emitError("cannot redefine builtin %0 as an input",
builtinAttr->getLocation())
<< builtinId;
emitWarning("previous definition is here",
builtinVar->getSourceLocation());
}
return;
} else if (decl->hasAttr<VKExtBuiltinOutputAttr>()) {
auto *builtinAttr = decl->getAttr<VKExtBuiltinOutputAttr>();
int builtinId = builtinAttr->getBuiltInID();
SpirvVariable *builtinVar =
declIdMapper.getBuiltinVar(spv::BuiltIn(builtinId), decl->getType(),
spv::StorageClass::Output, loc);
if (builtinVar->getStorageClass() != spv::StorageClass::Output) {
emitError("cannot redefine builtin %0 as an output",
builtinAttr->getLocation())
<< builtinId;
emitWarning("previous definition is here",
builtinVar->getSourceLocation());
}
return;
}

// We can have VarDecls inside cbuffer/tbuffer. For those VarDecls, we need
// to emit their cbuffer/tbuffer as a whole and access each individual one
// using access chains.
Expand Down
10 changes: 10 additions & 0 deletions tools/clang/lib/Sema/SemaHLSL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13585,6 +13585,16 @@ void hlsl::HandleDeclAttributeForHLSL(Sema &S, Decl *D, const AttributeList &A,
"DrawIndex,DeviceIndex,ViewportMaskNV"),
A.getAttributeSpellingListIndex());
break;
case AttributeList::AT_VKExtBuiltinInput:
declAttr = ::new (S.Context) VKExtBuiltinInputAttr(
A.getRange(), S.Context, ValidateAttributeIntArg(S, A),
A.getAttributeSpellingListIndex());
break;
case AttributeList::AT_VKExtBuiltinOutput:
declAttr = ::new (S.Context) VKExtBuiltinOutputAttr(
A.getRange(), S.Context, ValidateAttributeIntArg(S, A),
A.getAttributeSpellingListIndex());
break;
case AttributeList::AT_VKLocation:
declAttr = ::new (S.Context)
VKLocationAttr(A.getRange(), S.Context, ValidateAttributeIntArg(S, A),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
// RUN: not %dxc -T ps_6_0 -E main -fcgl %s -spirv 2>&1 | FileCheck %s

// CHECK: error: vk::ext_builtin_input cannot be used together with vk::ext_builtin_output
[[vk::ext_builtin_input(/* NumWorkgroups */ 24)]]
[[vk::ext_builtin_output(/* NumWorkgroups */ 24)]]
static uint3 invalid;

void main() {
}
15 changes: 15 additions & 0 deletions tools/clang/test/CodeGenSPIRV/spv.inline.builtin.input.flat.hlsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
// RUN: %dxc -T ps_6_0 -E main -fcgl %s -spirv | FileCheck %s

// CHECK: OpEntryPoint Fragment %main "main" %gl_SampleID
// CHECK: OpDecorate %gl_SampleID BuiltIn SampleId
// CHECK: OpDecorate %gl_SampleID Flat

// CHECK: %gl_SampleID = OpVariable %_ptr_Input_int Input

[[vk::ext_builtin_input(/* SampleID */ 18)]]
static const int gl_SampleID;

void main() {
// CHECK: {{%[0-9]+}} = OpLoad %int %gl_SampleID
int sID = gl_SampleID;
}
23 changes: 23 additions & 0 deletions tools/clang/test/CodeGenSPIRV/spv.inline.builtin.input.hlsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
// RUN: %dxc -T cs_6_0 -E main -fcgl %s -spirv | FileCheck %s

// CHECK: OpEntryPoint GLCompute %main "main" %gl_NumWorkGroups
// CHECK: OpDecorate %gl_NumWorkGroups BuiltIn NumWorkgroups

// CHECK: %gl_NumWorkGroups = OpVariable %_ptr_Input_v3uint Input

[[vk::ext_builtin_input(/* NumWorkgroups */ 24)]]
static const uint3 gl_NumWorkGroups;

uint square_x(uint3 v) {
return v.x * v.x;
}

[numthreads(32,1,1)]
void main() {
// CHECK: {{%[0-9]+}} = OpLoad %v3uint %gl_NumWorkGroups
uint3 numWorkgroups = gl_NumWorkGroups;
// CHECK: [[nwg:%[0-9]+]] = OpLoad %v3uint %gl_NumWorkGroups
// CHECK: OpStore %param_var_v [[nwg]]
// CHECK: OpFunctionCall %uint %square_x %param_var_v
square_x(gl_NumWorkGroups);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
// RUN: not %dxc -T ps_6_0 -E main -fcgl %s -spirv 2>&1 | FileCheck %s

// CHECK: error: vk::ext_builtin_input can only be applied to a const-qualified variable
[[vk::ext_builtin_input(/* NumWorkgroups */ 24)]]
static uint3 invalid;

void main() {
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
// RUN: not %dxc -T ps_6_0 -E main -fcgl %s -spirv 2>&1 | FileCheck %s

// CHECK: error: vk::ext_builtin_input and vk::ext_builtin_output can only be applied to a static variable
[[vk::ext_builtin_input(/* NumWorkgroups */ 24)]]
uint3 invalid;

void main() {
}
22 changes: 22 additions & 0 deletions tools/clang/test/CodeGenSPIRV/spv.inline.builtin.output.hlsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
// RUN: %dxc -T ps_6_0 -E main -fcgl %s -spirv | FileCheck %s

// CHECK: OpEntryPoint Fragment %main "main" [[fragStencilVar:%[0-9]+]]
// CHECK: OpDecorate [[fragStencilVar]] BuiltIn FragStencilRefEXT

// CHECK: [[fragStencilVar]] = OpVariable %_ptr_Output_int Output

[[vk::ext_extension("SPV_EXT_shader_stencil_export")]]
[[vk::ext_builtin_output(/* FragStencilRefEXT */ 5014)]]
static int gl_FragStencilRefARB;

void assign(out int val) {
val = 123;
}

void main() {
// CHECK: OpStore [[fragStencilVar]] %int_10
gl_FragStencilRefARB = 10;

// CHECK: OpFunctionCall %void %assign [[fragStencilVar]]
assign(gl_FragStencilRefARB);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
// RUN: not %dxc -T ps_6_0 -E main -fcgl %s -spirv 2>&1 | FileCheck %s

// CHECK: error: vk::ext_builtin_input and vk::ext_builtin_output can only be applied to a static variable
[[vk::ext_builtin_output(/* NumWorkgroups */ 24)]]
uint3 invalid;

void main() {
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
// RUN: not %dxc -T ps_6_0 -E main -fcgl %s -spirv 2>&1 | FileCheck %s

[[vk::ext_builtin_input(/* NumWorkgroups */ 24)]]
static const uint3 gl_NumWorkGroups;

// CHECK: error: cannot redefine builtin 24 as an output
// CHECK: warning: previous definition is here
[[vk::ext_builtin_output(/* NumWorkgroups */ 24)]]
static uint3 invalid;

void main() {
}

0 comments on commit 9dacbe3

Please sign in to comment.