From 95f00d03cb392a30804ad121c31d6e3984ac19dd Mon Sep 17 00:00:00 2001 From: Xiang Li Date: Fri, 5 Jan 2024 13:23:23 -0800 Subject: [PATCH] [NFC] Move hlsl::DiagnoseTranslationUnit to separate file. (#6126) Since DiagnoseTranslationUnit is after Sema, it doesn't need to be part of SemaHLSL. This is in preparation for moving DiagnoseHLSLMethodCall to hlsl::DiagnoseTranslationUnit for access to the call graph. Only move code to new file, not other changes except replace OutputDebugFormat with llvm::dbgs for CallGraphWithRecurseGuard::dump. For https://github.com/microsoft/DirectXShaderCompiler/issues/5855 --- tools/clang/include/clang/Sema/SemaHLSL.h | 5 - tools/clang/lib/Sema/CMakeLists.txt | 1 + tools/clang/lib/Sema/SemaHLSL.cpp | 414 ------------------- tools/clang/lib/Sema/SemaHLSLDiagnoseTU.cpp | 426 ++++++++++++++++++++ 4 files changed, 427 insertions(+), 419 deletions(-) create mode 100644 tools/clang/lib/Sema/SemaHLSLDiagnoseTU.cpp diff --git a/tools/clang/include/clang/Sema/SemaHLSL.h b/tools/clang/include/clang/Sema/SemaHLSL.h index 99eed69c8f..40b030b430 100644 --- a/tools/clang/include/clang/Sema/SemaHLSL.h +++ b/tools/clang/include/clang/Sema/SemaHLSL.h @@ -69,11 +69,6 @@ void DiagnosePackingOffset(clang::Sema *self, clang::SourceLocation loc, void DiagnoseRegisterType(clang::Sema *self, clang::SourceLocation loc, clang::QualType type, char registerType); -clang::FunctionDecl *ValidateNoRecursion(clang::Sema *self, - clang::FunctionDecl *FD); - -void ValidateNoRecursionInTranslationUnit(clang::Sema *self); - void DiagnoseTranslationUnit(clang::Sema *self); void DiagnoseUnusualAnnotationsForHLSL( diff --git a/tools/clang/lib/Sema/CMakeLists.txt b/tools/clang/lib/Sema/CMakeLists.txt index 3fd53759ed..92cfcf3772 100644 --- a/tools/clang/lib/Sema/CMakeLists.txt +++ b/tools/clang/lib/Sema/CMakeLists.txt @@ -41,6 +41,7 @@ add_clang_library(clangSema SemaExprObjC.cpp SemaFixItUtils.cpp SemaHLSL.cpp # HLSL Change + SemaHLSLDiagnoseTU.cpp # HLSL Change SemaInit.cpp SemaLambda.cpp SemaLookup.cpp diff --git a/tools/clang/lib/Sema/SemaHLSL.cpp b/tools/clang/lib/Sema/SemaHLSL.cpp index ff52b2b897..7dc479166d 100644 --- a/tools/clang/lib/Sema/SemaHLSL.cpp +++ b/tools/clang/lib/Sema/SemaHLSL.cpp @@ -17,7 +17,6 @@ #include "dxc/HlslIntrinsicOp.h" #include "dxc/Support/Global.h" #include "dxc/Support/WinIncludes.h" -#include "dxc/WinAdapter.h" #include "dxc/dxcapi.internal.h" #include "gen_intrin_main_tables_15.h" #include "clang/AST/ASTContext.h" @@ -29,7 +28,6 @@ #include "clang/AST/ExprCXX.h" #include "clang/AST/ExternalASTSource.h" #include "clang/AST/HlslTypes.h" -#include "clang/AST/RecursiveASTVisitor.h" #include "clang/AST/TypeLoc.h" #include "clang/Basic/Diagnostic.h" #include "clang/Sema/ExternalSemaSource.h" @@ -2892,175 +2890,6 @@ CreateSubobjectProceduralPrimitiveHitGroup(ASTContext &context) { return decl; } -// -// This is similar to clang/Analysis/CallGraph, but the following differences -// motivate this: -// -// - track traversed vs. observed nodes explicitly -// - fully visit all reachable functions -// - merge graph visiting with checking for recursion -// - track global variables and types used (NYI) -// -namespace hlsl { -struct CallNode { - FunctionDecl *CallerFn; - ::llvm::SmallPtrSet CalleeFns; -}; -typedef ::llvm::DenseMap CallNodes; -typedef ::llvm::SmallPtrSet FnCallStack; -typedef ::llvm::SmallPtrSet FunctionSet; -typedef ::llvm::SmallVector PendingFunctions; -typedef ::llvm::DenseMap FunctionMap; - -// Returns the definition of a function. -// This serves two purposes - ignore built-in functions, and pick -// a single Decl * to be used in maps and sets. -static FunctionDecl *getFunctionWithBody(FunctionDecl *F) { - if (!F) - return nullptr; - if (F->doesThisDeclarationHaveABody()) - return F; - F = F->getFirstDecl(); - for (auto &&Candidate : F->redecls()) { - if (Candidate->doesThisDeclarationHaveABody()) { - return Candidate; - } - } - return nullptr; -} - -// AST visitor that maintains visited and pending collections, as well -// as recording nodes of caller/callees. -class FnReferenceVisitor : public RecursiveASTVisitor { -private: - CallNodes &m_callNodes; - FunctionSet &m_visitedFunctions; - PendingFunctions &m_pendingFunctions; - FunctionDecl *m_source; - CallNodes::iterator m_sourceIt; - -public: - FnReferenceVisitor(FunctionSet &visitedFunctions, - PendingFunctions &pendingFunctions, CallNodes &callNodes) - : m_callNodes(callNodes), m_visitedFunctions(visitedFunctions), - m_pendingFunctions(pendingFunctions) {} - - void setSourceFn(FunctionDecl *F) { - F = getFunctionWithBody(F); - m_source = F; - m_sourceIt = m_callNodes.find(F); - } - - bool VisitDeclRefExpr(DeclRefExpr *ref) { - ValueDecl *valueDecl = ref->getDecl(); - RecordFunctionDecl(dyn_cast_or_null(valueDecl)); - return true; - } - - bool VisitCXXMemberCallExpr(CXXMemberCallExpr *callExpr) { - RecordFunctionDecl(callExpr->getMethodDecl()); - return true; - } - - void RecordFunctionDecl(FunctionDecl *funcDecl) { - funcDecl = getFunctionWithBody(funcDecl); - if (funcDecl) { - if (m_sourceIt == m_callNodes.end()) { - auto result = m_callNodes.insert( - std::make_pair(m_source, CallNode{m_source, {}})); - DXASSERT(result.second == true, - "else setSourceFn didn't assign m_sourceIt"); - m_sourceIt = result.first; - } - m_sourceIt->second.CalleeFns.insert(funcDecl); - if (!m_visitedFunctions.count(funcDecl)) { - m_pendingFunctions.push_back(funcDecl); - } - } - } -}; - -// A call graph that can check for reachability and recursion efficiently. -class CallGraphWithRecurseGuard { -private: - CallNodes m_callNodes; - FunctionSet m_visitedFunctions; - FunctionMap m_functionsCheckedForRecursion; - - FunctionDecl *CheckRecursion(FnCallStack &CallStack, FunctionDecl *D) { - auto it = m_functionsCheckedForRecursion.find(D); - if (it != m_functionsCheckedForRecursion.end()) - return it->second; - if (CallStack.insert(D).second == false) - return D; - auto node = m_callNodes.find(D); - if (node != m_callNodes.end()) { - for (FunctionDecl *Callee : node->second.CalleeFns) { - FunctionDecl *pResult = CheckRecursion(CallStack, Callee); - if (pResult) { - m_functionsCheckedForRecursion[D] = pResult; - return pResult; - } - } - } - CallStack.erase(D); - m_functionsCheckedForRecursion[D] = nullptr; - return nullptr; - } - -public: - void BuildForEntry(FunctionDecl *EntryFnDecl) { - DXASSERT_NOMSG(EntryFnDecl); - EntryFnDecl = getFunctionWithBody(EntryFnDecl); - PendingFunctions pendingFunctions; - FnReferenceVisitor visitor(m_visitedFunctions, pendingFunctions, - m_callNodes); - pendingFunctions.push_back(EntryFnDecl); - while (!pendingFunctions.empty()) { - FunctionDecl *pendingDecl = pendingFunctions.pop_back_val(); - if (m_visitedFunctions.insert(pendingDecl).second == true) { - visitor.setSourceFn(pendingDecl); - visitor.TraverseDecl(pendingDecl); - } - } - } - - // return true if FD2 is reachable from FD1 - bool CheckReachability(FunctionDecl *FD1, FunctionDecl *FD2) { - if (FD1 == FD2) - return true; - auto node = m_callNodes.find(FD1); - if (node != m_callNodes.end()) { - for (FunctionDecl *Callee : node->second.CalleeFns) { - if (CheckReachability(Callee, FD2)) - return true; - } - } - return false; - } - - FunctionDecl *CheckRecursion(FunctionDecl *EntryFnDecl) { - FnCallStack CallStack; - EntryFnDecl = getFunctionWithBody(EntryFnDecl); - return CheckRecursion(CallStack, EntryFnDecl); - } - - const CallNodes &GetCallGraph() { return m_callNodes; } - - void dump() const { - OutputDebugStringW(L"Call Nodes:\r\n"); - for (auto &node : m_callNodes) { - OutputDebugFormatA("%s [%p]:\r\n", node.first->getName().str().c_str(), - (void *)node.first); - for (auto callee : node.second.CalleeFns) { - OutputDebugFormatA(" %s [%p]\r\n", callee->getName().str().c_str(), - (void *)callee); - } - } - } -}; -} // namespace hlsl - /// Creates a Typedef in the specified ASTContext. static TypedefDecl *CreateGlobalTypedef(ASTContext *context, const char *ident, QualType baseType) { @@ -3150,8 +2979,6 @@ class HLSLExternalSource : public ExternalSemaSource { UsedIntrinsicStore m_usedIntrinsics; - CallGraphWithRecurseGuard m_callGraph; - /// Add all base QualTypes for each hlsl scalar types. void AddBaseTypes(); @@ -5796,8 +5623,6 @@ class HLSLExternalSource : public ExternalSemaSource { return method; } - CallGraphWithRecurseGuard &getCallGraph() { return m_callGraph; } - // Overload support. UINT64 ScoreCast(QualType leftType, QualType rightType); UINT64 ScoreFunction(OverloadCandidateSet::iterator &Cand); @@ -11291,31 +11116,6 @@ void hlsl::DiagnoseRegisterType(clang::Sema *self, clang::SourceLocation loc, } } -struct NameLookup { - FunctionDecl *Found; - FunctionDecl *Other; -}; - -static NameLookup GetSingleFunctionDeclByName(clang::Sema *self, StringRef Name, - bool checkPatch) { - auto DN = DeclarationName(&self->getASTContext().Idents.get(Name)); - FunctionDecl *pFoundDecl = nullptr; - for (auto idIter = self->IdResolver.begin(DN), idEnd = self->IdResolver.end(); - idIter != idEnd; ++idIter) { - FunctionDecl *pFnDecl = dyn_cast(*idIter); - if (!pFnDecl) - continue; - if (checkPatch && - !self->getASTContext().IsPatchConstantFunctionDecl(pFnDecl)) - continue; - if (pFoundDecl) { - return NameLookup{pFoundDecl, pFnDecl}; - } - pFoundDecl = pFnDecl; - } - return NameLookup{pFoundDecl, nullptr}; -} - // Check HLSL member call constraints bool Sema::DiagnoseHLSLMethodCall(const CXXMethodDecl *MD, SourceLocation Loc) { if (MD->hasAttr()) { @@ -11405,42 +11205,6 @@ bool hlsl::DiagnoseNodeStructArgument(Sema *self, TemplateArgumentLoc ArgLoc, } } -static bool IsTargetProfileLib6x(Sema &S) { - // Remaining functions are exported only if target is 'lib_6_x'. - const hlsl::ShaderModel *SM = - hlsl::ShaderModel::GetByName(S.getLangOpts().HLSLProfile.c_str()); - bool isLib6x = - SM->IsLib() && SM->GetMinor() == hlsl::ShaderModel::kOfflineMinor; - return isLib6x; -} - -bool IsExported(Sema *self, clang::FunctionDecl *FD, - bool isDefaultLinkageExternal) { - // Entry points are exported. - if (FD->hasAttr()) - return true; - - // Internal linkage functions include functions marked 'static'. - if (FD->getLinkageAndVisibility().getLinkage() == InternalLinkage) - return false; - - // Explicit 'export' functions are exported. - if (FD->hasAttr()) - return true; - - return isDefaultLinkageExternal; -} - -bool getDefaultLinkageExternal(clang::Sema *self) { - const LangOptions &opts = self->getLangOpts(); - bool isDefaultLinkageExternal = - opts.DefaultLinkage == DXIL::DefaultLinkage::External; - if (opts.DefaultLinkage == DXIL::DefaultLinkage::Default && - !opts.ExportShadersOnly && IsTargetProfileLib6x(*self)) - isDefaultLinkageExternal = true; - return isDefaultLinkageExternal; -} - // This function diagnoses whether or not all entry-point attributes // should exist on this shader stage void DiagnoseEntryAttrAllowedOnStage(clang::Sema *self, @@ -11484,34 +11248,6 @@ void DiagnoseEntryAttrAllowedOnStage(clang::Sema *self, } } -std::vector GetAllExportedFDecls(clang::Sema *self) { - // Add to the end, process from the beginning, to ensure AllExportedFDecls - // will contain functions in decl order. - std::vector AllExportedFDecls; - - std::deque Worklist; - Worklist.push_back(self->getASTContext().getTranslationUnitDecl()); - while (Worklist.size()) { - DeclContext *DC = Worklist.front(); - Worklist.pop_front(); - if (auto *FD = dyn_cast(DC)) { - AllExportedFDecls.push_back(FD); - } else { - for (auto *D : DC->decls()) { - if (auto *FD = dyn_cast(D)) { - if (FD->hasBody() && - IsExported(self, FD, getDefaultLinkageExternal(self))) - Worklist.push_back(FD); - } else if (auto *DC2 = dyn_cast(D)) { - Worklist.push_back(DC2); - } - } - } - } - - return AllExportedFDecls; -} - std::string getFQFunctionName(FunctionDecl *FD) { std::string name = ""; if (!FD) { @@ -11543,137 +11279,6 @@ std::string getFQFunctionName(FunctionDecl *FD) { return name; } -void hlsl::DiagnoseTranslationUnit(clang::Sema *self) { - DXASSERT_NOMSG(self != nullptr); - - // Don't bother with global validation if compilation has already failed. - if (self->getDiagnostics().hasErrorOccurred()) { - return; - } - - // Check RT shader if available for their payload use and match payload access - // against availiable payload modifiers. - // We have to do it late because we could have payload access in a called - // function and have to check the callgraph if the root shader has the right - // access rights to the payload structure. - if (self->getLangOpts().IsHLSLLibrary) { - if (self->getLangOpts().EnablePayloadAccessQualifiers) { - ASTContext &ctx = self->getASTContext(); - TranslationUnitDecl *TU = ctx.getTranslationUnitDecl(); - DiagnoseRaytracingPayloadAccess(*self, TU); - } - } - - // Now check for recursion, and check for patch constant function - // reachabililty Validation methods differ depending on whether this is a - // library shader or not. - - // TODO: make these error 'real' errors rather than on-the-fly things - // Validate that the entry point is available. - DiagnosticsEngine &Diags = self->getDiagnostics(); - FunctionDecl *pEntryPointDecl = nullptr; - std::vector FDeclsToCheck; - if (self->getLangOpts().IsHLSLLibrary) { - FDeclsToCheck = GetAllExportedFDecls(self); - } else { - const std::string &EntryPointName = self->getLangOpts().HLSLEntryFunction; - if (!EntryPointName.empty()) { - NameLookup NL = GetSingleFunctionDeclByName(self, EntryPointName, - /*checkPatch*/ false); - if (NL.Found && NL.Other) { - // NOTE: currently we cannot hit this codepath when CodeGen is enabled, - // because CodeGenModule::getMangledName will mangle the entry point - // name into the bare string, and so ambiguous points will produce an - // error earlier on. - unsigned id = - Diags.getCustomDiagID(clang::DiagnosticsEngine::Level::Error, - "ambiguous entry point function"); - Diags.Report(NL.Found->getSourceRange().getBegin(), id); - Diags.Report(NL.Other->getLocation(), diag::note_previous_definition); - return; - } - pEntryPointDecl = NL.Found; - if (!pEntryPointDecl || !pEntryPointDecl->hasBody()) { - unsigned id = - Diags.getCustomDiagID(clang::DiagnosticsEngine::Level::Error, - "missing entry point definition"); - Diags.Report(id); - return; - } - FDeclsToCheck.push_back(NL.Found); - } - } - - std::set DiagnosedDecls; - // for each FDecl, check for recursion - for (FunctionDecl *FDecl : FDeclsToCheck) { - FunctionDecl *result = ValidateNoRecursion(self, FDecl); - - if (result) { - // don't emit duplicate diagnostics for the same recursive function - // if A and B call recursive function C, only emit 1 diagnostic for C. - if (DiagnosedDecls.find(result) == DiagnosedDecls.end()) { - DiagnosedDecls.insert(result); - self->Diag(result->getSourceRange().getBegin(), - diag::err_hlsl_no_recursion) - << FDecl->getQualifiedNameAsString() - << result->getQualifiedNameAsString(); - self->Diag(result->getSourceRange().getBegin(), - diag::note_hlsl_no_recursion); - } - } - - FunctionDecl *pPatchFnDecl = nullptr; - if (const HLSLPatchConstantFuncAttr *attr = - FDecl->getAttr()) { - NameLookup NL = GetSingleFunctionDeclByName(self, attr->getFunctionName(), - /*checkPatch*/ true); - if (!NL.Found || !NL.Found->hasBody()) { - self->Diag(attr->getLocation(), - diag::err_hlsl_missing_patch_constant_function) - << attr->getFunctionName(); - } - pPatchFnDecl = NL.Found; - } - - if (pPatchFnDecl) { - FunctionDecl *patchResult = ValidateNoRecursion(self, pPatchFnDecl); - - // In this case, recursion was detected in the patch-constant function - if (patchResult) { - if (DiagnosedDecls.find(patchResult) == DiagnosedDecls.end()) { - DiagnosedDecls.insert(patchResult); - self->Diag(patchResult->getSourceRange().getBegin(), - diag::err_hlsl_no_recursion) - << pPatchFnDecl->getQualifiedNameAsString() - << patchResult->getQualifiedNameAsString(); - self->Diag(patchResult->getSourceRange().getBegin(), - diag::note_hlsl_no_recursion); - } - } - - // The patch function decl and the entry function decl should be - // disconnected with respect to the call graph. - // Only check this if neither function decl is recursive - if (!result && !patchResult) { - hlsl::CallGraphWithRecurseGuard CG; - CG.BuildForEntry(pPatchFnDecl); - if (CG.CheckReachability(pPatchFnDecl, FDecl)) { - self->Diag(FDecl->getSourceRange().getBegin(), - diag::err_hlsl_patch_reachability_not_allowed) - << 1 << FDecl->getName() << 0 << pPatchFnDecl->getName(); - } - CG.BuildForEntry(FDecl); - if (CG.CheckReachability(FDecl, pPatchFnDecl)) { - self->Diag(FDecl->getSourceRange().getBegin(), - diag::err_hlsl_patch_reachability_not_allowed) - << 0 << pPatchFnDecl->getName() << 1 << FDecl->getName(); - } - } - } - } -} - void hlsl::DiagnosePayloadAccessQualifierAnnotations( Sema &S, Declarator &D, const QualType &T, const std::vector &annotations) { @@ -15933,25 +15538,6 @@ void TryAddShaderAttrFromTargetProfile(Sema &S, FunctionDecl *FD, return; } -// in the non-library case, this function will be run only once, -// but in the library case, this function will be run for each -// viable top-level function declaration by -// ValidateNoRecursionInTranslationUnit. -// (viable as in, is exported) -clang::FunctionDecl *ValidateNoRecursion(clang::Sema *self, - clang::FunctionDecl *FD) { - // Validate that there is no recursion reachable by this function declaration - // NOTE: the information gathered here could be used to bypass code generation - // on functions that are unreachable (as an early form of dead code - // elimination). - if (FD) { - HLSLExternalSource *hlslSource = HLSLExternalSource::FromSema(self); - hlslSource->getCallGraph().BuildForEntry(FD); - return hlslSource->getCallGraph().CheckRecursion(FD); - } - return nullptr; -} - // The DiagnoseEntry function does 2 things: // 1. Determine whether this function is the current entry point for a // non-library compilation, add an implicit shader attribute if so. diff --git a/tools/clang/lib/Sema/SemaHLSLDiagnoseTU.cpp b/tools/clang/lib/Sema/SemaHLSLDiagnoseTU.cpp new file mode 100644 index 0000000000..935032e165 --- /dev/null +++ b/tools/clang/lib/Sema/SemaHLSLDiagnoseTU.cpp @@ -0,0 +1,426 @@ +/////////////////////////////////////////////////////////////////////////////// +// // +// SemaHLSLDiagnoseTU.cpp // +// Copyright (C) Microsoft Corporation. All rights reserved. // +// This file is distributed under the University of Illinois Open Source // +// License. See LICENSE.TXT for details. // +// // +// This file implements the Translation Unit Diagnose for HLSL. // +// // +/////////////////////////////////////////////////////////////////////////////// + +#include "dxc/DXIL/DxilShaderModel.h" +#include "dxc/Support/Global.h" +#include "clang/AST/ASTContext.h" +#include "clang/AST/Decl.h" +#include "clang/AST/RecursiveASTVisitor.h" +#include "clang/Sema/SemaHLSL.h" +#include "llvm/Support/Debug.h" + +using namespace clang; +using namespace hlsl; + +// +// This is similar to clang/Analysis/CallGraph, but the following differences +// motivate this: +// +// - track traversed vs. observed nodes explicitly +// - fully visit all reachable functions +// - merge graph visiting with checking for recursion +// - track global variables and types used (NYI) +// +namespace { +struct CallNode { + FunctionDecl *CallerFn; + ::llvm::SmallPtrSet CalleeFns; +}; +typedef ::llvm::DenseMap CallNodes; +typedef ::llvm::SmallPtrSet FnCallStack; +typedef ::llvm::SmallPtrSet FunctionSet; +typedef ::llvm::SmallVector PendingFunctions; +typedef ::llvm::DenseMap FunctionMap; + +// Returns the definition of a function. +// This serves two purposes - ignore built-in functions, and pick +// a single Decl * to be used in maps and sets. +FunctionDecl *getFunctionWithBody(FunctionDecl *F) { + if (!F) + return nullptr; + if (F->doesThisDeclarationHaveABody()) + return F; + F = F->getFirstDecl(); + for (auto &&Candidate : F->redecls()) { + if (Candidate->doesThisDeclarationHaveABody()) { + return Candidate; + } + } + return nullptr; +} + +// AST visitor that maintains visited and pending collections, as well +// as recording nodes of caller/callees. +class FnReferenceVisitor : public RecursiveASTVisitor { +private: + CallNodes &m_callNodes; + FunctionSet &m_visitedFunctions; + PendingFunctions &m_pendingFunctions; + FunctionDecl *m_source; + CallNodes::iterator m_sourceIt; + +public: + FnReferenceVisitor(FunctionSet &visitedFunctions, + PendingFunctions &pendingFunctions, CallNodes &callNodes) + : m_callNodes(callNodes), m_visitedFunctions(visitedFunctions), + m_pendingFunctions(pendingFunctions) {} + + void setSourceFn(FunctionDecl *F) { + F = getFunctionWithBody(F); + m_source = F; + m_sourceIt = m_callNodes.find(F); + } + + bool VisitDeclRefExpr(DeclRefExpr *ref) { + ValueDecl *valueDecl = ref->getDecl(); + RecordFunctionDecl(dyn_cast_or_null(valueDecl)); + return true; + } + + bool VisitCXXMemberCallExpr(CXXMemberCallExpr *callExpr) { + RecordFunctionDecl(callExpr->getMethodDecl()); + return true; + } + + void RecordFunctionDecl(FunctionDecl *funcDecl) { + funcDecl = getFunctionWithBody(funcDecl); + if (funcDecl) { + if (m_sourceIt == m_callNodes.end()) { + auto result = m_callNodes.insert( + std::make_pair(m_source, CallNode{m_source, {}})); + DXASSERT(result.second == true, + "else setSourceFn didn't assign m_sourceIt"); + m_sourceIt = result.first; + } + m_sourceIt->second.CalleeFns.insert(funcDecl); + if (!m_visitedFunctions.count(funcDecl)) { + m_pendingFunctions.push_back(funcDecl); + } + } + } +}; + +// A call graph that can check for reachability and recursion efficiently. +class CallGraphWithRecurseGuard { +private: + CallNodes m_callNodes; + FunctionSet m_visitedFunctions; + FunctionMap m_functionsCheckedForRecursion; + + FunctionDecl *CheckRecursion(FnCallStack &CallStack, FunctionDecl *D) { + auto it = m_functionsCheckedForRecursion.find(D); + if (it != m_functionsCheckedForRecursion.end()) + return it->second; + if (CallStack.insert(D).second == false) + return D; + auto node = m_callNodes.find(D); + if (node != m_callNodes.end()) { + for (FunctionDecl *Callee : node->second.CalleeFns) { + FunctionDecl *pResult = CheckRecursion(CallStack, Callee); + if (pResult) { + m_functionsCheckedForRecursion[D] = pResult; + return pResult; + } + } + } + CallStack.erase(D); + m_functionsCheckedForRecursion[D] = nullptr; + return nullptr; + } + +public: + void BuildForEntry(FunctionDecl *EntryFnDecl) { + DXASSERT_NOMSG(EntryFnDecl); + EntryFnDecl = getFunctionWithBody(EntryFnDecl); + PendingFunctions pendingFunctions; + FnReferenceVisitor visitor(m_visitedFunctions, pendingFunctions, + m_callNodes); + pendingFunctions.push_back(EntryFnDecl); + while (!pendingFunctions.empty()) { + FunctionDecl *pendingDecl = pendingFunctions.pop_back_val(); + if (m_visitedFunctions.insert(pendingDecl).second == true) { + visitor.setSourceFn(pendingDecl); + visitor.TraverseDecl(pendingDecl); + } + } + } + + // return true if FD2 is reachable from FD1 + bool CheckReachability(FunctionDecl *FD1, FunctionDecl *FD2) { + if (FD1 == FD2) + return true; + auto node = m_callNodes.find(FD1); + if (node != m_callNodes.end()) { + for (FunctionDecl *Callee : node->second.CalleeFns) { + if (CheckReachability(Callee, FD2)) + return true; + } + } + return false; + } + + FunctionDecl *CheckRecursion(FunctionDecl *EntryFnDecl) { + FnCallStack CallStack; + EntryFnDecl = getFunctionWithBody(EntryFnDecl); + return CheckRecursion(CallStack, EntryFnDecl); + } + + const CallNodes &GetCallGraph() { return m_callNodes; } + + void dump() const { + llvm::dbgs() << "Call Nodes:\n"; + for (auto &node : m_callNodes) { + llvm::dbgs() << node.first->getName().str().c_str() << " [" + << (void *)node.first << "]:\n"; + for (auto callee : node.second.CalleeFns) { + llvm::dbgs() << " " << callee->getName().str().c_str() << " [" + << (void *)callee << "]\n"; + } + } + } +}; + +struct NameLookup { + FunctionDecl *Found; + FunctionDecl *Other; +}; + +NameLookup GetSingleFunctionDeclByName(clang::Sema *self, StringRef Name, + bool checkPatch) { + auto DN = DeclarationName(&self->getASTContext().Idents.get(Name)); + FunctionDecl *pFoundDecl = nullptr; + for (auto idIter = self->IdResolver.begin(DN), idEnd = self->IdResolver.end(); + idIter != idEnd; ++idIter) { + FunctionDecl *pFnDecl = dyn_cast(*idIter); + if (!pFnDecl) + continue; + if (checkPatch && + !self->getASTContext().IsPatchConstantFunctionDecl(pFnDecl)) + continue; + if (pFoundDecl) { + return NameLookup{pFoundDecl, pFnDecl}; + } + pFoundDecl = pFnDecl; + } + return NameLookup{pFoundDecl, nullptr}; +} + +bool IsTargetProfileLib6x(Sema &S) { + // Remaining functions are exported only if target is 'lib_6_x'. + const hlsl::ShaderModel *SM = + hlsl::ShaderModel::GetByName(S.getLangOpts().HLSLProfile.c_str()); + bool isLib6x = + SM->IsLib() && SM->GetMinor() == hlsl::ShaderModel::kOfflineMinor; + return isLib6x; +} + +bool IsExported(Sema *self, clang::FunctionDecl *FD, + bool isDefaultLinkageExternal) { + // Entry points are exported. + if (FD->hasAttr()) + return true; + + // Internal linkage functions include functions marked 'static'. + if (FD->getLinkageAndVisibility().getLinkage() == InternalLinkage) + return false; + + // Explicit 'export' functions are exported. + if (FD->hasAttr()) + return true; + + return isDefaultLinkageExternal; +} + +bool getDefaultLinkageExternal(clang::Sema *self) { + const LangOptions &opts = self->getLangOpts(); + bool isDefaultLinkageExternal = + opts.DefaultLinkage == DXIL::DefaultLinkage::External; + if (opts.DefaultLinkage == DXIL::DefaultLinkage::Default && + !opts.ExportShadersOnly && IsTargetProfileLib6x(*self)) + isDefaultLinkageExternal = true; + return isDefaultLinkageExternal; +} + +std::vector GetAllExportedFDecls(clang::Sema *self) { + // Add to the end, process from the beginning, to ensure AllExportedFDecls + // will contain functions in decl order. + std::vector AllExportedFDecls; + + std::deque Worklist; + Worklist.push_back(self->getASTContext().getTranslationUnitDecl()); + while (Worklist.size()) { + DeclContext *DC = Worklist.front(); + Worklist.pop_front(); + if (auto *FD = dyn_cast(DC)) { + AllExportedFDecls.push_back(FD); + } else { + for (auto *D : DC->decls()) { + if (auto *FD = dyn_cast(D)) { + if (FD->hasBody() && + IsExported(self, FD, getDefaultLinkageExternal(self))) + Worklist.push_back(FD); + } else if (auto *DC2 = dyn_cast(D)) { + Worklist.push_back(DC2); + } + } + } + } + + return AllExportedFDecls; +} + +// in the non-library case, this function will be run only once, +// but in the library case, this function will be run for each +// viable top-level function declaration by +// ValidateNoRecursionInTranslationUnit. +// (viable as in, is exported) +clang::FunctionDecl *ValidateNoRecursion(CallGraphWithRecurseGuard &callGraph, + clang::FunctionDecl *FD) { + // Validate that there is no recursion reachable by this function declaration + // NOTE: the information gathered here could be used to bypass code generation + // on functions that are unreachable (as an early form of dead code + // elimination). + if (FD) { + callGraph.BuildForEntry(FD); + return callGraph.CheckRecursion(FD); + } + return nullptr; +} + +} // namespace + +void hlsl::DiagnoseTranslationUnit(clang::Sema *self) { + DXASSERT_NOMSG(self != nullptr); + + // Don't bother with global validation if compilation has already failed. + if (self->getDiagnostics().hasErrorOccurred()) { + return; + } + + // Check RT shader if available for their payload use and match payload access + // against availiable payload modifiers. + // We have to do it late because we could have payload access in a called + // function and have to check the callgraph if the root shader has the right + // access rights to the payload structure. + if (self->getLangOpts().IsHLSLLibrary) { + if (self->getLangOpts().EnablePayloadAccessQualifiers) { + ASTContext &ctx = self->getASTContext(); + TranslationUnitDecl *TU = ctx.getTranslationUnitDecl(); + DiagnoseRaytracingPayloadAccess(*self, TU); + } + } + + // TODO: make these error 'real' errors rather than on-the-fly things + // Validate that the entry point is available. + DiagnosticsEngine &Diags = self->getDiagnostics(); + FunctionDecl *pEntryPointDecl = nullptr; + std::vector FDeclsToCheck; + if (self->getLangOpts().IsHLSLLibrary) { + FDeclsToCheck = GetAllExportedFDecls(self); + } else { + const std::string &EntryPointName = self->getLangOpts().HLSLEntryFunction; + if (!EntryPointName.empty()) { + NameLookup NL = GetSingleFunctionDeclByName(self, EntryPointName, + /*checkPatch*/ false); + if (NL.Found && NL.Other) { + // NOTE: currently we cannot hit this codepath when CodeGen is enabled, + // because CodeGenModule::getMangledName will mangle the entry point + // name into the bare string, and so ambiguous points will produce an + // error earlier on. + unsigned id = + Diags.getCustomDiagID(clang::DiagnosticsEngine::Level::Error, + "ambiguous entry point function"); + Diags.Report(NL.Found->getSourceRange().getBegin(), id); + Diags.Report(NL.Other->getLocation(), diag::note_previous_definition); + return; + } + pEntryPointDecl = NL.Found; + if (!pEntryPointDecl || !pEntryPointDecl->hasBody()) { + unsigned id = + Diags.getCustomDiagID(clang::DiagnosticsEngine::Level::Error, + "missing entry point definition"); + Diags.Report(id); + return; + } + FDeclsToCheck.push_back(NL.Found); + } + } + + CallGraphWithRecurseGuard callGraph; + std::set DiagnosedDecls; + // for each FDecl, check for recursion + for (FunctionDecl *FDecl : FDeclsToCheck) { + FunctionDecl *result = ValidateNoRecursion(callGraph, FDecl); + + if (result) { + // don't emit duplicate diagnostics for the same recursive function + // if A and B call recursive function C, only emit 1 diagnostic for C. + if (DiagnosedDecls.find(result) == DiagnosedDecls.end()) { + DiagnosedDecls.insert(result); + self->Diag(result->getSourceRange().getBegin(), + diag::err_hlsl_no_recursion) + << FDecl->getQualifiedNameAsString() + << result->getQualifiedNameAsString(); + self->Diag(result->getSourceRange().getBegin(), + diag::note_hlsl_no_recursion); + } + } + + FunctionDecl *pPatchFnDecl = nullptr; + if (const HLSLPatchConstantFuncAttr *attr = + FDecl->getAttr()) { + NameLookup NL = GetSingleFunctionDeclByName(self, attr->getFunctionName(), + /*checkPatch*/ true); + if (!NL.Found || !NL.Found->hasBody()) { + self->Diag(attr->getLocation(), + diag::err_hlsl_missing_patch_constant_function) + << attr->getFunctionName(); + } + pPatchFnDecl = NL.Found; + } + + if (pPatchFnDecl) { + FunctionDecl *patchResult = ValidateNoRecursion(callGraph, pPatchFnDecl); + + // In this case, recursion was detected in the patch-constant function + if (patchResult) { + if (DiagnosedDecls.find(patchResult) == DiagnosedDecls.end()) { + DiagnosedDecls.insert(patchResult); + self->Diag(patchResult->getSourceRange().getBegin(), + diag::err_hlsl_no_recursion) + << pPatchFnDecl->getQualifiedNameAsString() + << patchResult->getQualifiedNameAsString(); + self->Diag(patchResult->getSourceRange().getBegin(), + diag::note_hlsl_no_recursion); + } + } + + // The patch function decl and the entry function decl should be + // disconnected with respect to the call graph. + // Only check this if neither function decl is recursive + if (!result && !patchResult) { + CallGraphWithRecurseGuard CG; + CG.BuildForEntry(pPatchFnDecl); + if (CG.CheckReachability(pPatchFnDecl, FDecl)) { + self->Diag(FDecl->getSourceRange().getBegin(), + diag::err_hlsl_patch_reachability_not_allowed) + << 1 << FDecl->getName() << 0 << pPatchFnDecl->getName(); + } + CG.BuildForEntry(FDecl); + if (CG.CheckReachability(FDecl, pPatchFnDecl)) { + self->Diag(FDecl->getSourceRange().getBegin(), + diag::err_hlsl_patch_reachability_not_allowed) + << 0 << pPatchFnDecl->getName() << 1 << FDecl->getName(); + } + } + } + } +}