Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add two new tail strategies for update definitions #7949

Merged
merged 10 commits into from
Dec 5, 2023
15 changes: 15 additions & 0 deletions src/ApplySplit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,21 @@ vector<ApplySplitResult> apply_split(const Split &split, bool is_update, const s
// non-trivial loop.
base = likely_if_innermost(base);
base = Min::make(base, old_max + (1 - split.factor));
} else if (tail == TailStrategy::ShiftInwardsAndBlend) {
Expr old_base = base;
base = likely(base);
base = Min::make(base, old_max + (1 - split.factor));
// Make a mask which will be a loop invariant if inner gets
// vectorized, and apply it if we're in the tail.
Expr unwanted_elems = (-old_extent) % split.factor;
Expr mask = inner >= unwanted_elems;
mask = select(base == old_base, likely(const_true()), mask);
result.emplace_back(mask, ApplySplitResult::BlendProvides);
} else if (tail == TailStrategy::RoundUpAndBlend) {
Expr unwanted_elems = (-old_extent) % split.factor;
Expr mask = inner < split.factor - unwanted_elems;
mask = select(outer < outer_max, likely(const_true()), mask);
result.emplace_back(mask, ApplySplitResult::BlendProvides);
} else {
internal_assert(tail == TailStrategy::RoundUp);
}
Expand Down
6 changes: 5 additions & 1 deletion src/ApplySplit.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ struct ApplySplitResult {
LetStmt,
PredicateCalls,
PredicateProvides,
Predicate };
Predicate,
BlendProvides };
Type type;

ApplySplitResult(const std::string &n, Expr val, Type t)
Expand Down Expand Up @@ -67,6 +68,9 @@ struct ApplySplitResult {
/** Whether this result is tagged as PredicateProvides. */
bool is_predicate_provides() const {
    return type == PredicateProvides;
}
/** Whether this result is tagged as BlendProvides. */
bool is_blend_provides() const {
    return type == BlendProvides;
}
};

/** Given a Split schedule on a definition (init or update), return a list of
Expand Down
4 changes: 4 additions & 0 deletions src/Deserialization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,10 @@ TailStrategy Deserializer::deserialize_tail_strategy(Serialize::TailStrategy tai
return TailStrategy::PredicateStores;
case Serialize::TailStrategy::ShiftInwards:
return TailStrategy::ShiftInwards;
case Serialize::TailStrategy::ShiftInwardsAndBlend:
return TailStrategy::ShiftInwardsAndBlend;
case Serialize::TailStrategy::RoundUpAndBlend:
return TailStrategy::RoundUpAndBlend;
case Serialize::TailStrategy::Auto:
return TailStrategy::Auto;
default:
Expand Down
82 changes: 82 additions & 0 deletions src/Func.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,79 @@ bool is_const_assignment(const string &func_name, const vector<Expr> &args, cons
rhs_checker.has_self_reference ||
rhs_checker.has_rvar);
}

/** Report a user error if any split in the given stage schedule uses a
 * 'blend' tail strategy (ShiftInwardsAndBlend or RoundUpAndBlend) on a var
 * that this function finds to be connected, through the schedule's splits, to
 * an unordered-parallel loop. Does nothing if the schedule allows race
 * conditions. */
void check_for_race_conditions_in_split_with_blend(const StageSchedule &sched) {
    // Splits with a 'blend' tail strategy do a load and then a store of values
    // outside of the region to be computed, so for each split using a 'blend'
    // tail strategy, verify that there aren't any parallel vars that stem from
    // the same original dimension, so that this load and store doesn't race
    // with a true computation of that value happening in some other thread.

    // Note that we only need to check vars in the same dimension, because
    // allocation bounds inference is done per-dimension and allocates padding
    // based on the values actually accessed by the lowered code (i.e. it covers
    // the blend region). So for example, an access beyond the end of a scanline
    // can't overflow onto the next scanline. Halide will allocate padding, or
    // throw a bounds error if it's an input or output.

    if (sched.allow_race_conditions()) {
        return;
    }

    // Seed the set with every loop var that is directly marked parallel.
    std::set<std::string> parallel;
    for (const auto &dim : sched.dims()) {
        if (is_unordered_parallel(dim.for_type)) {
            parallel.insert(dim.var);
        }
    }

    // Bind the split list once so both passes walk the same sequence.
    const auto &splits = sched.splits();

    // Process the splits in reverse order to figure out which root vars have a
    // parallel child.
    for (auto it = splits.rbegin(); it != splits.rend(); ++it) {
        if (it->is_fuse()) {
            if (parallel.count(it->old_var)) {
                parallel.insert(it->inner);
                // NOTE(review): old_var is already in the set here, so this
                // insert is a no-op. Possibly `outer` was intended - confirm.
                parallel.insert(it->old_var);
            }
        } else if (it->is_rename() || it->is_purify()) {
            if (parallel.count(it->outer)) {
                parallel.insert(it->old_var);
            }
        } else {
            if (parallel.count(it->inner) || parallel.count(it->outer)) {
                parallel.insert(it->old_var);
            }
        }
    }

    // Now propagate back to all children of the identified root vars, to assert
    // that none of them use a blending tail strategy. This pass walks the
    // splits in forward order, i.e. from the root vars towards their children,
    // so each condition and body is the reverse of the corresponding case in
    // the loop above.
    for (const auto &split : splits) {
        if (split.is_fuse()) {
            if (parallel.count(split.inner) || parallel.count(split.outer)) {
                parallel.insert(split.old_var);
            }
        } else if (split.is_rename() || split.is_purify()) {
            if (parallel.count(split.old_var)) {
                parallel.insert(split.outer);
            }
        } else {
            if (parallel.count(split.old_var)) {
                parallel.insert(split.inner);
                // NOTE(review): old_var is already in the set here, so this
                // insert is a no-op. Possibly `outer` was intended - confirm.
                parallel.insert(split.old_var);
                if (split.tail == TailStrategy::ShiftInwardsAndBlend ||
                    split.tail == TailStrategy::RoundUpAndBlend) {
                    user_error << "Tail strategy " << split.tail
                               << " may not be used to split " << split.old_var
                               << " because other vars stemming from the same original "
                               << "Var or RVar are marked as parallel. "
                               << "This could cause a race condition.\n";
                }
            }
        }
    }
}

} // namespace

void Stage::set_dim_type(const VarOrRVar &var, ForType t) {
Expand Down Expand Up @@ -439,6 +512,10 @@ void Stage::set_dim_type(const VarOrRVar &var, ForType t) {
<< " in vars for function\n"
<< dump_argument_list();
}

if (is_unordered_parallel(t)) {
check_for_race_conditions_in_split_with_blend(definition.schedule());
}
}

void Stage::set_dim_device_api(const VarOrRVar &var, DeviceAPI device_api) {
Expand Down Expand Up @@ -1171,6 +1248,11 @@ void Stage::split(const string &old, const string &outer, const string &inner, c
}
}

if (tail == TailStrategy::ShiftInwardsAndBlend ||
tail == TailStrategy::RoundUpAndBlend) {
check_for_race_conditions_in_split_with_blend(definition.schedule());
}

if (!definition.is_init()) {
user_assert(tail != TailStrategy::ShiftInwards)
<< "When splitting Var " << old_name
Expand Down
6 changes: 6 additions & 0 deletions src/IRPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,12 @@ std::ostream &operator<<(std::ostream &out, const TailStrategy &t) {
case TailStrategy::RoundUp:
out << "RoundUp";
break;
case TailStrategy::ShiftInwardsAndBlend:
out << "ShiftInwardsAndBlend";
break;
case TailStrategy::RoundUpAndBlend:
out << "RoundUpAndBlend";
break;
}
return out;
}
Expand Down
26 changes: 26 additions & 0 deletions src/Schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,32 @@ enum class TailStrategy {
* instead of a multiple of the split factor as with RoundUp. */
ShiftInwards,

/** Equivalent to ShiftInwards, but protects values that would be
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I really think we need to have some kind of tutorial for all of the different variants of TailStrategy: there are a lot of them and some of them are pretty subtle, so it's not really obvious how exactly they work.

No need to block the PR on this, though; it might be good to create an issue for it.

* re-evaluated by loading the memory location that would be stored to,
* modifying only the elements not contained within the overlap, and then
* storing the blended result.
*
* This tail strategy is useful when you want to use ShiftInwards to
* vectorize without a scalar tail, but are scheduling a stage where that
* isn't legal (e.g. an update definition).
*
* Because this is a read - modify - write, this tail strategy cannot be
* used on any dimension the stage is parallelized over as it would cause a
* race condition.
*/
ShiftInwardsAndBlend,

/** Equivalent to RoundUp, but protects values that would be written beyond
* the end by loading the memory location that would be stored to,
* modifying only the elements within the region being computed, and then
* storing the blended result.
*
* This tail strategy is useful when vectorizing an update to some sub-region
* of a larger Func. As with ShiftInwardsAndBlend, it can't be combined with
* parallelism.
*/
RoundUpAndBlend,

/** For pure definitions use ShiftInwards. For pure vars in
* update definitions use RoundUp. For RVars in update
* definitions use GuardWithIf. */
Expand Down
30 changes: 18 additions & 12 deletions src/ScheduleFunctions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -126,16 +126,22 @@ Stmt substitute_in(const string &name, const Expr &value, bool calls, bool provi

class AddPredicates : public IRGraphMutator {
const Expr &cond;
bool calls;
bool provides;
const Function &func;
ApplySplitResult::Type type;

using IRMutator::visit;

Stmt visit(const Provide *p) override {
auto [args, changed_args] = mutate_with_changes(p->args);
auto [values, changed_values] = mutate_with_changes(p->values);
Expr predicate = mutate(p->predicate);
if (provides) {
if (type == ApplySplitResult::BlendProvides) {
int idx = 0;
for (Expr &v : values) {
v = select(cond, v, Call::make(func, args, idx++));
}
return Provide::make(p->name, values, args, predicate);
} else if (type == ApplySplitResult::PredicateProvides) {
return Provide::make(p->name, values, args, predicate && cond);
} else if (changed_args || changed_values || !predicate.same_as(p->predicate)) {
return Provide::make(p->name, values, args, predicate);
Expand All @@ -146,20 +152,20 @@ class AddPredicates : public IRGraphMutator {

Expr visit(const Call *op) override {
Expr result = IRMutator::visit(op);
if (calls && op->call_type == Call::Halide) {
if (type == ApplySplitResult::PredicateCalls && op->call_type == Call::Halide) {
result = Call::make(op->type, Call::if_then_else, {cond, result}, Call::PureIntrinsic);
}
return result;
}

public:
AddPredicates(const Expr &cond, bool calls, bool provides)
: cond(cond), calls(calls), provides(provides) {
AddPredicates(const Expr &cond, const Function &func, ApplySplitResult::Type type)
: cond(cond), func(func), type(type) {
}
};

Stmt add_predicates(const Expr &cond, bool calls, bool provides, const Stmt &s) {
return AddPredicates(cond, calls, provides).mutate(s);
// Run an AddPredicates mutator, built from the given condition, function and
// result type, over the statement s and return the mutated statement.
Stmt add_predicates(const Expr &cond, const Function &func, ApplySplitResult::Type type, const Stmt &s) {
    AddPredicates mutator(cond, func, type);
    return mutator.mutate(s);
}

// Build a loop nest about a provide node using a schedule
Expand Down Expand Up @@ -227,10 +233,10 @@ Stmt build_loop_nest(
stmt = substitute_in(res.name, res.value, true, false, stmt);
} else if (res.is_substitution_in_provides()) {
stmt = substitute_in(res.name, res.value, false, true, stmt);
} else if (res.is_predicate_calls()) {
stmt = add_predicates(res.value, true, false, stmt);
} else if (res.is_predicate_provides()) {
stmt = add_predicates(res.value, false, true, stmt);
} else if (res.is_blend_provides() ||
res.is_predicate_calls() ||
res.is_predicate_provides()) {
stmt = add_predicates(res.value, func, res.type, stmt);
} else if (res.is_let()) {
stmt = LetStmt::make(res.name, res.value, stmt);
} else {
Expand Down
4 changes: 4 additions & 0 deletions src/Serialization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,10 @@ Serialize::TailStrategy Serializer::serialize_tail_strategy(const TailStrategy &
return Serialize::TailStrategy::PredicateStores;
case TailStrategy::ShiftInwards:
return Serialize::TailStrategy::ShiftInwards;
case TailStrategy::ShiftInwardsAndBlend:
return Serialize::TailStrategy::ShiftInwardsAndBlend;
case TailStrategy::RoundUpAndBlend:
return Serialize::TailStrategy::RoundUpAndBlend;
case TailStrategy::Auto:
return Serialize::TailStrategy::Auto;
default:
Expand Down
2 changes: 2 additions & 0 deletions src/halide_ir.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -527,6 +527,8 @@ enum TailStrategy: ubyte {
PredicateLoads,
PredicateStores,
ShiftInwards,
ShiftInwardsAndBlend,
RoundUpAndBlend,
Auto,
}

Expand Down
30 changes: 26 additions & 4 deletions test/correctness/nested_tail_strategies.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,12 @@ void my_free(JITUserContext *user_context, void *ptr) {
void check(Func out, int line, std::vector<TailStrategy> tails) {
bool has_round_up =
std::find(tails.begin(), tails.end(), TailStrategy::RoundUp) != tails.end() ||
std::find(tails.begin(), tails.end(), TailStrategy::RoundUpAndBlend) != tails.end() ||
std::find(tails.begin(), tails.end(), TailStrategy::PredicateLoads) != tails.end() ||
std::find(tails.begin(), tails.end(), TailStrategy::PredicateStores) != tails.end();
bool has_shift_inwards =
std::find(tails.begin(), tails.end(), TailStrategy::ShiftInwards) != tails.end();
std::find(tails.begin(), tails.end(), TailStrategy::ShiftInwards) != tails.end() ||
std::find(tails.begin(), tails.end(), TailStrategy::ShiftInwardsAndBlend) != tails.end();

std::vector<int> sizes_to_try;

Expand Down Expand Up @@ -68,6 +70,12 @@ int main(int argc, char **argv) {
return 0;
}

// We'll randomly subsample these tests, because otherwise there are too many of them.
std::mt19937 rng(0);
int seed = argc > 1 ? atoi(argv[1]) : time(nullptr);
rng.seed(seed);
std::cout << "Nested tail strategies seed: " << seed << "\n";

// Test random compositions of tail strategies in simple
// producer-consumer pipelines. The bounds being tight sometimes
// depends on the simplifier being able to cancel out things.
Expand All @@ -76,15 +84,17 @@ int main(int argc, char **argv) {
TailStrategy::RoundUp,
TailStrategy::GuardWithIf,
TailStrategy::ShiftInwards,
};
TailStrategy::RoundUpAndBlend,
TailStrategy::ShiftInwardsAndBlend};

TailStrategy innermost_tails[] = {
TailStrategy::RoundUp,
TailStrategy::GuardWithIf,
TailStrategy::PredicateLoads,
TailStrategy::PredicateStores,
TailStrategy::ShiftInwards,
};
TailStrategy::RoundUpAndBlend,
TailStrategy::ShiftInwardsAndBlend};

// Two stages. First stage computed at tiles of second.
for (auto t1 : innermost_tails) {
Expand All @@ -110,6 +120,10 @@ int main(int argc, char **argv) {
for (auto t1 : innermost_tails) {
for (auto t2 : innermost_tails) {
for (auto t3 : innermost_tails) {
if ((rng() & 7) != 0) {
continue;
}

Func in("in"), f("f"), g("g"), h("h");
Var x;

Expand All @@ -134,6 +148,10 @@ int main(int argc, char **argv) {
for (auto t1 : tails) {
for (auto t2 : innermost_tails) {
for (auto t3 : innermost_tails) {
if ((rng() & 7) != 0) {
continue;
}

Func in, f, g, h;
Var x;

Expand All @@ -158,8 +176,12 @@ int main(int argc, char **argv) {
// (but can handle smaller outputs).
for (auto t1 : innermost_tails) {
for (auto t2 : tails) {
for (auto t3 : tails) { // Not innermost_tails because of n^4 complexity here.
for (auto t3 : innermost_tails) {
for (auto t4 : tails) {
if ((rng() & 63) != 0) {
continue;
}

Func in("in"), f("f"), g("g"), h("h");
Var x;

Expand Down
2 changes: 2 additions & 0 deletions test/error/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,9 @@ tests(GROUPS error
reuse_var_in_schedule.cpp
reused_args.cpp
rfactor_inner_dim_non_commutative.cpp
round_up_and_blend_race.cpp
run_with_large_stack_throws.cpp
shift_inwards_and_blend_race.cpp
specialize_fail.cpp
split_inner_wrong_tail_strategy.cpp
split_non_innermost_predicated.cpp
Expand Down
Loading