Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement wave size range #6167

Merged
merged 15 commits into from
Jan 20, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/DXIL.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3261,7 +3261,9 @@ SM.TRIOUTPUTPRIMITIVEMISMATCH Hull Shader declared with Tri Domain m
SM.UNDEFINEDOUTPUT Not all elements of output %0 were written.
SM.VALIDDOMAIN Invalid Tessellator Domain specified. Must be isoline, tri or quad.
SM.VIEWIDNEEDSSLOT ViewID requires compatible space in pixel shader input signature
SM.WAVESIZEMINGEQMAX Declared Minimum WaveSize %0 greater or equal to declared Maximum Wavesize %1
SM.WAVESIZENEEDSDXIL16PLUS WaveSize is valid only for DXIL version 1.6 and higher.
SM.WAVESIZEPREFERREDOUTOFRANGE Preferred WaveSize %0 outside valid range [%1..%2]
SM.WAVESIZEVALUE Declared WaveSize %0 outside valid range [%1..%2], or not a power of 2.
SM.ZEROHSINPUTCONTROLPOINTWITHINPUT When HS input control point count is 0, no input signature should exist.
TYPES.DEFINED Type must be defined based on DXIL primitives
Expand Down
19 changes: 16 additions & 3 deletions include/dxc/DXIL/DxilConstants.h
Original file line number Diff line number Diff line change
Expand Up @@ -448,10 +448,23 @@ inline bool IsFeedbackTexture(DXIL::ResourceKind ResourceKind) {
ResourceKind == DXIL::ResourceKind::FeedbackTexture2DArray;
}

inline bool IsValidWaveSizeValue(unsigned size) {
inline bool isPowerOf2(unsigned x) { return (x & (x - 1)) == 0; }

inline bool IsValidWaveSizeValue(unsigned min_wave, unsigned max_wave,
unsigned pref_wave) {
// must be power of 2 between 4 and 128
return size >= kMinWaveSize && size <= kMaxWaveSize &&
(size & (size - 1)) == 0;
bool minIsValid = min_wave >= kMinWaveSize && min_wave <= kMaxWaveSize &&
isPowerOf2(min_wave);
if (max_wave == 0)
return true;
bool maxIsValid = max_wave >= kMinWaveSize && max_wave <= kMaxWaveSize &&
isPowerOf2(max_wave);
// 0 is a valid value for the preferred wave size
bool prefIsValid =
pref_wave == 0 || (pref_wave >= kMinWaveSize &&
pref_wave <= kMaxWaveSize && isPowerOf2(pref_wave));

return minIsValid && maxIsValid && prefIsValid;
}

// TODO: change opcodes.
Expand Down
12 changes: 8 additions & 4 deletions include/dxc/DXIL/DxilFunctionProps.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,9 @@ struct DxilFunctionProps {
memset(&Node, 0, sizeof(Node));
Node.LaunchType = DXIL::NodeLaunchType::Invalid;
Node.LocalRootArgumentsTableIndex = -1;
waveSize = 0;
waveMinSize = 0;
waveMaxSize = 0;
wavePreferredSize = 0;
}
union {
// Geometry shader.
Expand Down Expand Up @@ -107,9 +109,11 @@ struct DxilFunctionProps {
std::vector<NodeIOProperties> InputNodes;
std::vector<NodeIOProperties> OutputNodes;

// WaveSize is currently allowed only on compute shaders, but could be
// supported on other shader types in the future
unsigned waveSize;
// SM 6.6 allows WaveSize specification for only a single required size. SM
// 6.8+ allows specification of WaveSize as a min, max and preferred value.
unsigned waveMinSize;
unsigned waveMaxSize;
unsigned wavePreferredSize;
// Save root signature for lib profile entry.
std::vector<uint8_t> serializedRootSignature;
void SetSerializedRootSignature(const uint8_t *pData, unsigned size) {
Expand Down
1 change: 1 addition & 0 deletions include/dxc/DXIL/DxilMetadataHelper.h
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,7 @@ class DxilMDHelper {
static const unsigned kDxilNodeInputsTag = 20;
static const unsigned kDxilNodeOutputsTag = 21;
static const unsigned kDxilNodeMaxDispatchGridTag = 22;
static const unsigned kDxilRangedWaveSizeTag = 23;

// Node Input/Output State.
static const unsigned kDxilNodeOutputIDTag = 0;
Expand Down
39 changes: 30 additions & 9 deletions lib/DXIL/DxilMetadataHelper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1596,10 +1596,17 @@ MDTuple *DxilMDHelper::EmitDxilEntryProperties(uint64_t rawShaderFlag,
NumThreadVals.emplace_back(Uint32ToConstMD(props.numThreads[2]));
MDVals.emplace_back(MDNode::get(m_Ctx, NumThreadVals));

if (props.waveSize != 0) {
MDVals.emplace_back(Uint32ToConstMD(DxilMDHelper::kDxilWaveSizeTag));
vector<Metadata *> WaveSizeVal;
WaveSizeVal.emplace_back(Uint32ToConstMD(props.waveSize));
if (props.waveMinSize != 0) {
bool UseRange = props.waveMaxSize != 0;
MDVals.emplace_back(
Uint32ToConstMD(UseRange ? DxilMDHelper::kDxilRangedWaveSizeTag
: DxilMDHelper::kDxilWaveSizeTag));
SmallVector<Metadata *, 3> WaveSizeVal;
WaveSizeVal.emplace_back(Uint32ToConstMD(props.waveMinSize));
if (UseRange) {
WaveSizeVal.emplace_back(Uint32ToConstMD(props.waveMaxSize));
WaveSizeVal.emplace_back(Uint32ToConstMD(props.wavePreferredSize));
}
MDVals.emplace_back(MDNode::get(m_Ctx, WaveSizeVal));
}
} break;
Expand Down Expand Up @@ -1823,7 +1830,14 @@ void DxilMDHelper::LoadDxilEntryProperties(const MDOperand &MDO,
case DxilMDHelper::kDxilWaveSizeTag: {
DXASSERT(props.IsCS() || props.IsNode(), "else invalid shader kind");
MDNode *pNode = cast<MDNode>(MDO.get());
props.waveSize = ConstMDToUint32(pNode->getOperand(0));
props.waveMinSize = ConstMDToUint32(pNode->getOperand(0));
} break;
case DxilMDHelper::kDxilRangedWaveSizeTag: {
DXASSERT(props.IsCS() || props.IsNode(), "else invalid shader kind");
MDNode *pNode = cast<MDNode>(MDO.get());
props.waveMinSize = ConstMDToUint32(pNode->getOperand(0));
props.waveMaxSize = ConstMDToUint32(pNode->getOperand(1));
props.wavePreferredSize = ConstMDToUint32(pNode->getOperand(2));
} break;
case DxilMDHelper::kDxilEntryRootSigTag: {
MDNode *pNode = cast<MDNode>(MDO.get());
Expand Down Expand Up @@ -2644,10 +2658,17 @@ void DxilMDHelper::EmitDxilNodeState(std::vector<llvm::Metadata *> &MDVals,

// Optional Fields

if (props.waveSize != 0) {
MDVals.emplace_back(Uint32ToConstMD(DxilMDHelper::kDxilWaveSizeTag));
vector<Metadata *> WaveSizeVal;
WaveSizeVal.emplace_back(Uint32ToConstMD(props.waveSize));
if (props.waveMinSize != 0) {
bool UseRange = props.waveMaxSize != 0;
MDVals.emplace_back(
Uint32ToConstMD(UseRange ? DxilMDHelper::kDxilRangedWaveSizeTag
: DxilMDHelper::kDxilWaveSizeTag));
SmallVector<Metadata *, 3> WaveSizeVal;
WaveSizeVal.emplace_back(Uint32ToConstMD(props.waveMinSize));
if (UseRange) {
WaveSizeVal.emplace_back(Uint32ToConstMD(props.waveMaxSize));
WaveSizeVal.emplace_back(Uint32ToConstMD(props.wavePreferredSize));
}
MDVals.emplace_back(MDNode::get(m_Ctx, WaveSizeVal));
}

Expand Down
4 changes: 2 additions & 2 deletions lib/DXIL/DxilModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -400,7 +400,7 @@ void DxilModule::SetWaveSize(unsigned size) {
"only works for CS profile");
DxilFunctionProps &props = m_DxilEntryPropsMap.begin()->second->props;
DXASSERT_NOMSG(m_pSM->GetKind() == props.shaderKind);
props.waveSize = size;
props.waveMinSize = size;
}

unsigned DxilModule::GetWaveSize() const {
Expand All @@ -410,7 +410,7 @@ unsigned DxilModule::GetWaveSize() const {
return 0;
const DxilFunctionProps &props = m_DxilEntryPropsMap.begin()->second->props;
DXASSERT_NOMSG(m_pSM->GetKind() == props.shaderKind);
return props.waveSize;
return props.waveMinSize;
}

DXIL::InputPrimitive DxilModule::GetInputPrimitive() const {
Expand Down
9 changes: 4 additions & 5 deletions lib/DxilContainer/DxilContainerAssembler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1812,11 +1812,10 @@ class DxilRDATWriter : public DxilPartWriter {
shaderKind = (uint32_t)props.shaderKind;
if (pInfo2 && DM.HasDxilEntryProps(&function)) {
const auto &entryProps = DM.GetDxilEntryProps(&function);
unsigned waveSize = entryProps.props.waveSize;
if (waveSize) {
pInfo2->MinimumExpectedWaveLaneCount = waveSize;
pInfo2->MaximumExpectedWaveLaneCount = waveSize;
}
pInfo2->MinimumExpectedWaveLaneCount = entryProps.props.waveMinSize;
pInfo2->MaximumExpectedWaveLaneCount =
entryProps.props.waveMaxSize > 0 ? entryProps.props.waveMaxSize
: entryProps.props.waveMinSize;
pInfo2->ShaderFlags = 0;
if (entryProps.props.IsNode()) {
shaderInfo = AddShaderNodeInfo(DM, function, entryProps, *pInfo2,
Expand Down
24 changes: 21 additions & 3 deletions lib/HLSL/DxilValidation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5251,18 +5251,36 @@ static void ValidateEntryProps(ValidationContext &ValCtx,

// validate wave size (currently allowed only on CS but might be supported on
// other shader types in the future)
if (props.waveSize != 0) {
if (props.waveMinSize != 0) {
if (DXIL::CompareVersions(ValCtx.m_DxilMajor, ValCtx.m_DxilMinor, 1, 6) <
0) {
ValCtx.EmitFnFormatError(F, ValidationRule::SmWaveSizeNeedsDxil16Plus,
{});
}
if (!DXIL::IsValidWaveSizeValue(props.waveSize)) {
if (!DXIL::IsValidWaveSizeValue(props.waveMinSize, props.waveMaxSize,
props.wavePreferredSize)) {
ValCtx.EmitFnFormatError(F, ValidationRule::SmWaveSizeValue,
{std::to_string(props.waveSize),
{std::to_string(props.waveMinSize),
std::to_string(DXIL::kMinWaveSize),
std::to_string(DXIL::kMaxWaveSize)});
}

bool prefInRange = props.wavePreferredSize == 0
? true
: props.wavePreferredSize >= props.waveMinSize &&
props.wavePreferredSize <= props.waveMaxSize;
if (!prefInRange) {
ValCtx.EmitFnFormatError(F, ValidationRule::SmWaveSizePreferredOutOfRange,
{std::to_string(props.wavePreferredSize),
std::to_string(props.waveMinSize),
std::to_string(props.waveMaxSize)});
}

if (props.waveMaxSize != 0 && props.waveMinSize >= props.waveMaxSize) {
ValCtx.EmitFnFormatError(F, ValidationRule::SmWaveSizeMinGEQMax,
{std::to_string(props.waveMinSize),
std::to_string(props.waveMaxSize)});
}
}

if (ShaderType == DXIL::ShaderKind::Compute || props.IsNode()) {
Expand Down
2 changes: 1 addition & 1 deletion tools/clang/include/clang/Basic/Attr.td
Original file line number Diff line number Diff line change
Expand Up @@ -908,7 +908,7 @@ def HLSLWaveSensitive : InheritableAttr {

def HLSLWaveSize : InheritableAttr {
let Spellings = [CXX11<"", "wavesize", 2017>];
let Args = [IntArgument<"Size">];
let Args = [IntArgument<"Min">, DefaultIntArgument<"Max", 0>, DefaultIntArgument<"Preferred", 0>];
let Documentation = [Undocumented];
}

Expand Down
7 changes: 5 additions & 2 deletions tools/clang/lib/CodeGen/CGHLSLMS.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1809,8 +1809,11 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
funcProps->ShaderProps.PS.EarlyDepthStencil = true;
}

if (const HLSLWaveSizeAttr *Attr = FD->getAttr<HLSLWaveSizeAttr>())
funcProps->waveSize = Attr->getSize();
if (const HLSLWaveSizeAttr *Attr = FD->getAttr<HLSLWaveSizeAttr>()) {
funcProps->waveMinSize = Attr->getMin();
funcProps->waveMaxSize = Attr->getMax();
funcProps->wavePreferredSize = Attr->getPreferred();
}

// Node shader
if (isNode) {
Expand Down
38 changes: 31 additions & 7 deletions tools/clang/lib/Sema/SemaHLSL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12620,21 +12620,39 @@ HLSLWaveSizeAttr *ValidateWaveSizeAttributes(Sema &S, Decl *D,
const AttributeList &A) {
// validate that the wavesize argument is a power of 2 between 4 and 128
// inclusive
HLSLWaveSizeAttr *pAttr = ::new (S.Context)
HLSLWaveSizeAttr(A.getRange(), S.Context, ValidateAttributeIntArg(S, A),
A.getAttributeSpellingListIndex());
HLSLWaveSizeAttr *pAttr = ::new (S.Context) HLSLWaveSizeAttr(
A.getRange(), S.Context, ValidateAttributeIntArg(S, A, 0),
ValidateAttributeIntArg(S, A, 1), ValidateAttributeIntArg(S, A, 2),
A.getAttributeSpellingListIndex());

int minWave = pAttr->getMin();
int maxWave = pAttr->getMax();
int prefWave = pAttr->getPreferred();

unsigned waveSize = pAttr->getSize();
if (!DXIL::IsValidWaveSizeValue(waveSize)) {
if (!DXIL::IsValidWaveSizeValue(minWave, maxWave, prefWave)) {
S.Diag(A.getLoc(), diag::err_hlsl_wavesize_size)
<< DXIL::kMinWaveSize << DXIL::kMaxWaveSize;
}

bool prefInRange =
prefWave == 0 ? true : prefWave >= minWave && prefWave <= maxWave;
if (!prefInRange) {
S.Diag(A.getLoc(), diag::err_hlsl_wavesize_pref_size_out_of_range)
<< (unsigned)prefWave << (unsigned)minWave << (unsigned)maxWave;
}

if (maxWave != 0 && minWave >= maxWave) {
S.Diag(A.getLoc(), diag::err_hlsl_wavesize_min_geq_max)
<< (unsigned)minWave << (unsigned)maxWave;
}

// make sure there is not already an existing conflicting
// wavesize attribute on the decl
HLSLWaveSizeAttr *waveSizeAttr = D->getAttr<HLSLWaveSizeAttr>();
if (waveSizeAttr) {
if (waveSizeAttr->getSize() != pAttr->getSize()) {
if (waveSizeAttr->getMin() != pAttr->getMin() ||
waveSizeAttr->getMax() != pAttr->getMax() ||
waveSizeAttr->getPreferred() != pAttr->getPreferred()) {
S.Diag(A.getLoc(), diag::err_hlsl_conflicting_shader_attribute)
<< pAttr->getSpelling() << waveSizeAttr->getSpelling();
S.Diag(waveSizeAttr->getLocation(), diag::note_conflicting_attribute);
Expand Down Expand Up @@ -14609,7 +14627,13 @@ void hlsl::CustomPrintHLSLAttr(const clang::Attr *A, llvm::raw_ostream &Out,
Attr *noconst = const_cast<Attr *>(A);
HLSLWaveSizeAttr *ACast = static_cast<HLSLWaveSizeAttr *>(noconst);
Indent(Indentation, Out);
Out << "[wavesize(" << ACast->getSize() << ")]\n";
Out << "[wavesize(" << ACast->getMin();
if (ACast->getMax() > 0) {
Out << ", " << ACast->getMax();
if (ACast->getPreferred() > 0)
Out << ", " << ACast->getPreferred();
}
Out << ")]\n";
break;
}

Expand Down
Loading
Loading