Skip to content

Commit

Permalink
Eliminate the predicate of the consumer of circular buffer in compute…
Browse files Browse the repository at this point in the history
… warp (#3545)

This PR unblocks `MmaOp` from being used in warp specialized kernels.
  • Loading branch information
zasdfgbnm authored Dec 10, 2024
1 parent 1978cf4 commit 829f879
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 10 deletions.
56 changes: 52 additions & 4 deletions csrc/device_lower/analysis/predicate_elimination.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,46 @@ void assertOnWarpOps(const Expr* expr) {

namespace {

// Returns true if `consumer` belongs to the compute warp of a warp
// specialized loop and `id_in_consumer` is parallelized on the same
// parallel type that the warp specialization splits on.
bool isComputeWarp(TensorView* consumer, IterDomain* id_in_consumer) {
  // TODO: This check is conservative and can miss expressions in the
  // compute warp. For example, given:
  //   if (load warp) {
  //     T1 = T0;
  //   } else {
  //     T2 = T1;
  //     T3 = T2;
  //   }
  // T3 is reported as "not compute warp" — a false negative. That is
  // still functionally correct (we simply keep a predicate that could
  // have been removed), so only an optimization is missed.
  // Today warp specialization is only used for matmul, where the
  // circular buffer loop is a reduction loop and mma is the sole expr
  // in the compute warp, so this simple check suffices. Finding every
  // compute-warp expression would need a more sophisticated analysis.
  auto def = consumer->definition();
  if (def == nullptr) {
    // Fusion inputs have no definition, hence no producers to inspect.
    return false;
  }
  auto producer_tvs = ir_utils::filterByType<TensorView>(def->inputs());
  if (producer_tvs.empty()) {
    return false;
  }
  // Every producer must be circular buffered with warp specialization
  // on the same parallel type as id_in_consumer; otherwise bail out.
  for (TensorView* producer_tv : producer_tvs) {
    if (!producer_tv->isCircularBuffered()) {
      return false;
    }
    const auto& cb_type = producer_tv->circularBufferOptions().type;
    const auto* ws = std::get_if<WarpSpecialized>(&cb_type);
    if (ws == nullptr || ws->on != id_in_consumer->getParallelType()) {
      return false;
    }
  }
  return true;
}

// Utility to check if the scheduled domain of the given
// TensorView represent an exact shared mem access, meaning
// that all the thread parallel dimensions on the loop nodes
Expand All @@ -53,7 +93,8 @@ bool isExactParallelSharedMemAccess(TensorView* tv) {
if (id->isThreadDim()) {
// Need to predicate to avoid out of bound access
// because of over-subscribed block size.
if (!lower_utils::isExtentEqualToMaxParallelTypeExtent(id)) {
if (!lower_utils::isExtentEqualToMaxParallelTypeExtent(
id, isComputeWarp(tv, id))) {
return false;
}
}
Expand Down Expand Up @@ -235,7 +276,8 @@ class ProducerConsumerPairAnalyzer : public OptOutDispatch {
alloc_to_loop_groups.insert(
indexing_groups.begin(), indexing_groups.end());
}
ProducerConsumerPairAnalyzer analyzer(c2p, graph, alloc_to_loop_groups);
ProducerConsumerPairAnalyzer analyzer(
consumer, c2p, graph, alloc_to_loop_groups);

for (auto id : consumer->getLoopDomain()) {
if (analyzer.needsPredicate(id)) {
Expand All @@ -248,10 +290,14 @@ class ProducerConsumerPairAnalyzer : public OptOutDispatch {

private:
ProducerConsumerPairAnalyzer(
TensorView* consumer,
const std::unordered_map<IterDomain*, IterDomain*>& c2p,
const ValGraph* graph,
const std::unordered_set<ValGroup> alloc_to_loop_groups)
: c2p_(c2p), graph_(graph), alloc_to_loop_groups_(alloc_to_loop_groups) {}
: consumer_(consumer),
c2p_(c2p),
graph_(graph),
alloc_to_loop_groups_(alloc_to_loop_groups) {}

// Returns true if no out-of-bound accesses could occur with a
// producer
Expand Down Expand Up @@ -286,7 +332,8 @@ class ProducerConsumerPairAnalyzer : public OptOutDispatch {
// consumer ID may be oversubscribed, which may cause
// out-of-bounds accesses in the producer
const auto maybe_oversubscribed = consumer_id->isThread() &&
(!lower_utils::isExtentEqualToMaxParallelTypeExtent(consumer_id));
(!lower_utils::isExtentEqualToMaxParallelTypeExtent(
consumer_id, isComputeWarp(consumer_, consumer_id)));
if (maybe_oversubscribed) {
// If oversubscribed, there must be a mapped producer ID that is
// parallelized in the same way. Otherwise, needs to be
Expand Down Expand Up @@ -354,6 +401,7 @@ class ProducerConsumerPairAnalyzer : public OptOutDispatch {
}

private:
TensorView* consumer_ = nullptr;
//! BestEffort map from consumer IDs to producer IDs
const std::unordered_map<IterDomain*, IterDomain*>& c2p_;
bool needs_predicate_ = false;
Expand Down
11 changes: 9 additions & 2 deletions csrc/device_lower/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -831,9 +831,16 @@ bool isScalarExpr(Expr* expr) {
return true;
}

bool isExtentEqualToMaxParallelTypeExtent(const IterDomain* id) {
bool isExtentEqualToMaxParallelTypeExtent(
const IterDomain* id,
bool in_compute_warp) {
const auto& parallel_dim_map = GpuLower::current()->parallelDimensionMap();
auto* pdm_max_extent = parallel_dim_map.getRaw(id->getParallelType());
Val* pdm_max_extent = nullptr;
if (in_compute_warp) {
pdm_max_extent = parallel_dim_map.getRawCompute(id->getParallelType());
} else {
pdm_max_extent = parallel_dim_map.getRaw(id->getParallelType());
}
if (nullptr == pdm_max_extent) {
return false;
}
Expand Down
8 changes: 6 additions & 2 deletions csrc/device_lower/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -229,8 +229,12 @@ bool isScalarExpr(Expr* expr);

//! Test if provided IterDomain instance has an extent that matches maximum
//! extent stored in parallel dimension map for parallel type of provided
//! IterDomain object.
bool isExtentEqualToMaxParallelTypeExtent(const IterDomain* id);
//! IterDomain object. `in_compute_warp` specifies that we are checking an
//! expression in the compute warp; if so, we need to get the parallel type
//! extent of the compute warp instead of the global parallel type extent.
bool isExtentEqualToMaxParallelTypeExtent(
const IterDomain* id,
bool in_compute_warp = false);

//! Get the uint32_t index of a scalar TensorView. This is usually used for
//! indexing special items in shared memory, like mbarrier.
Expand Down
10 changes: 8 additions & 2 deletions tests/cpp/test_matmul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3657,6 +3657,7 @@ TEST_F(HopperMatmulTest, HSH_NT_128BSwizzle) {
const auto dtype = DataType::Half;

constexpr bool use_smem_epilogue = false;
constexpr bool use_warp_specialization = true;

constexpr int64_t stages = 4;
constexpr int64_t prefetch = 3;
Expand Down Expand Up @@ -3787,8 +3788,13 @@ TEST_F(HopperMatmulTest, HSH_NT_128BSwizzle) {

inlineMost();

tv0c->circularBuffer(stages, prefetch);
tv1c->circularBuffer(stages, prefetch);
if (use_warp_specialization) {
tv0c->circularBuffer(stages, prefetch, WarpSpecialized(ParallelType::TIDy));
tv1c->circularBuffer(stages, prefetch, WarpSpecialized(ParallelType::TIDy));
} else {
tv0c->circularBuffer(stages, prefetch);
tv1c->circularBuffer(stages, prefetch);
}

auto inputs =
matmulAtInput3DHopperSS(M, N, K, layout, data_type_to_aten(dtype));
Expand Down

0 comments on commit 829f879

Please sign in to comment.