diff --git a/src/ApplySplit.cpp b/src/ApplySplit.cpp index 7bde69a38e94..48d27b1ffc7a 100644 --- a/src/ApplySplit.cpp +++ b/src/ApplySplit.cpp @@ -107,6 +107,21 @@ vector apply_split(const Split &split, bool is_update, const s // non-trivial loop. base = likely_if_innermost(base); base = Min::make(base, old_max + (1 - split.factor)); + } else if (tail == TailStrategy::ShiftInwardsAndBlend) { + Expr old_base = base; + base = likely(base); + base = Min::make(base, old_max + (1 - split.factor)); + // Make a mask which will be a loop invariant if inner gets + // vectorized, and apply it if we're in the tail. + Expr unwanted_elems = (-old_extent) % split.factor; + Expr mask = inner >= unwanted_elems; + mask = select(base == old_base, likely(const_true()), mask); + result.emplace_back(mask, ApplySplitResult::BlendProvides); + } else if (tail == TailStrategy::RoundUpAndBlend) { + Expr unwanted_elems = (-old_extent) % split.factor; + Expr mask = inner < split.factor - unwanted_elems; + mask = select(outer < outer_max, likely(const_true()), mask); + result.emplace_back(mask, ApplySplitResult::BlendProvides); } else { internal_assert(tail == TailStrategy::RoundUp); } diff --git a/src/ApplySplit.h b/src/ApplySplit.h index 61774733b02b..5e646b22f08b 100644 --- a/src/ApplySplit.h +++ b/src/ApplySplit.h @@ -36,7 +36,8 @@ struct ApplySplitResult { LetStmt, PredicateCalls, PredicateProvides, - Predicate }; + Predicate, + BlendProvides }; Type type; ApplySplitResult(const std::string &n, Expr val, Type t) @@ -67,6 +68,9 @@ struct ApplySplitResult { bool is_predicate_provides() const { return (type == PredicateProvides); } + bool is_blend_provides() const { + return (type == BlendProvides); + } }; /** Given a Split schedule on a definition (init or update), return a list of diff --git a/src/Deserialization.cpp b/src/Deserialization.cpp index c0e9f39de7bf..1f8d5f491ad9 100644 --- a/src/Deserialization.cpp +++ b/src/Deserialization.cpp @@ -350,6 +350,10 @@ TailStrategy 
Deserializer::deserialize_tail_strategy(Serialize::TailStrategy tai return TailStrategy::PredicateStores; case Serialize::TailStrategy::ShiftInwards: return TailStrategy::ShiftInwards; + case Serialize::TailStrategy::ShiftInwardsAndBlend: + return TailStrategy::ShiftInwardsAndBlend; + case Serialize::TailStrategy::RoundUpAndBlend: + return TailStrategy::RoundUpAndBlend; case Serialize::TailStrategy::Auto: return TailStrategy::Auto; default: diff --git a/src/Func.cpp b/src/Func.cpp index 37b64df5af5b..8f46e7316531 100644 --- a/src/Func.cpp +++ b/src/Func.cpp @@ -375,6 +375,79 @@ bool is_const_assignment(const string &func_name, const vector &args, cons rhs_checker.has_self_reference || rhs_checker.has_rvar); } + +void check_for_race_conditions_in_split_with_blend(const StageSchedule &sched) { + // Splits with a 'blend' tail strategy do a load and then a store of values + // outside of the region to be computed, so for each split using a 'blend' + // tail strategy, verify that there aren't any parallel vars that stem from + // the same original dimension, so that this load and store doesn't race + // with a true computation of that value happening in some other thread. + + // Note that we only need to check vars in the same dimension, because + // allocation bounds inference is done per-dimension and allocates padding + // based on the values actually accessed by the lowered code (i.e. it covers + // the blend region). So for example, an access beyond the end of a scanline + // can't overflow onto the next scanline. Halide will allocate padding, or + // throw a bounds error if it's an input or output. + + if (sched.allow_race_conditions()) { + return; + } + + std::set<string> parallel; + for (const auto &dim : sched.dims()) { + if (is_unordered_parallel(dim.for_type)) { + parallel.insert(dim.var); + } + } + + // Process the splits in reverse order to figure out which root vars have a + // parallel child. 
+ for (auto it = sched.splits().rbegin(); it != sched.splits().rend(); it++) { + if (it->is_fuse()) { + if (parallel.count(it->old_var)) { + parallel.insert(it->inner); + parallel.insert(it->outer); + } + } else if (it->is_rename() || it->is_purify()) { + if (parallel.count(it->outer)) { + parallel.insert(it->old_var); + } + } else { + if (parallel.count(it->inner) || parallel.count(it->outer)) { + parallel.insert(it->old_var); + } + } + } + + // Now propagate back to all children of the identified root vars, to assert + // that none of them use a blending tail strategy. + for (auto it = sched.splits().begin(); it != sched.splits().end(); it++) { + if (it->is_fuse()) { + if (parallel.count(it->inner) || parallel.count(it->outer)) { + parallel.insert(it->old_var); + } + } else if (it->is_rename() || it->is_purify()) { + if (parallel.count(it->old_var)) { + parallel.insert(it->outer); + } + } else { + if (parallel.count(it->old_var)) { + parallel.insert(it->inner); + parallel.insert(it->outer); + if (it->tail == TailStrategy::ShiftInwardsAndBlend || + it->tail == TailStrategy::RoundUpAndBlend) { + user_error << "Tail strategy " << it->tail + << " may not be used to split " << it->old_var + << " because other vars stemming from the same original " + << "Var or RVar are marked as parallel. "
+ << "This could cause a race condition.\n"; + } + } + } + } +} + } // namespace void Stage::set_dim_type(const VarOrRVar &var, ForType t) { @@ -439,6 +512,10 @@ void Stage::set_dim_type(const VarOrRVar &var, ForType t) { << " in vars for function\n" << dump_argument_list(); } + + if (is_unordered_parallel(t)) { + check_for_race_conditions_in_split_with_blend(definition.schedule()); + } } void Stage::set_dim_device_api(const VarOrRVar &var, DeviceAPI device_api) { @@ -1171,6 +1248,11 @@ void Stage::split(const string &old, const string &outer, const string &inner, c } } + if (tail == TailStrategy::ShiftInwardsAndBlend || + tail == TailStrategy::RoundUpAndBlend) { + check_for_race_conditions_in_split_with_blend(definition.schedule()); + } + if (!definition.is_init()) { user_assert(tail != TailStrategy::ShiftInwards) << "When splitting Var " << old_name diff --git a/src/IRPrinter.cpp b/src/IRPrinter.cpp index cd89e76417c0..dc07d0e0f010 100644 --- a/src/IRPrinter.cpp +++ b/src/IRPrinter.cpp @@ -180,6 +180,12 @@ std::ostream &operator<<(std::ostream &out, const TailStrategy &t) { case TailStrategy::RoundUp: out << "RoundUp"; break; + case TailStrategy::ShiftInwardsAndBlend: + out << "ShiftInwardsAndBlend"; + break; + case TailStrategy::RoundUpAndBlend: + out << "RoundUpAndBlend"; + break; } return out; } diff --git a/src/Schedule.h b/src/Schedule.h index 22908a8425e4..32a654228673 100644 --- a/src/Schedule.h +++ b/src/Schedule.h @@ -100,6 +100,32 @@ enum class TailStrategy { * instead of a multiple of the split factor as with RoundUp. */ ShiftInwards, + /** Equivalent to ShiftInwards, but protects values that would be + * re-evaluated by loading the memory location that would be stored to, + * modifying only the elements not contained within the overlap, and then + * storing the blended result. + * + * This tail strategy is useful when you want to use ShiftInwards to + * vectorize without a scalar tail, but are scheduling a stage where that + * isn't legal (e.g. 
an update definition). + * + * Because this is a read-modify-write, this tail strategy cannot be + * used on any dimension the stage is parallelized over as it would cause a + * race condition. + */ + ShiftInwardsAndBlend, + + /** Equivalent to RoundUp, but protects values that would be written beyond + * the end by loading the memory location that would be stored to, + * modifying only the elements within the region being computed, and then + * storing the blended result. + * + * This tail strategy is useful when vectorizing an update to some sub-region + * of a larger Func. As with ShiftInwardsAndBlend, it can't be combined with + * parallelism. + */ + RoundUpAndBlend, + /** For pure definitions use ShiftInwards. For pure vars in * update definitions use RoundUp. For RVars in update * definitions use GuardWithIf. */ diff --git a/src/ScheduleFunctions.cpp b/src/ScheduleFunctions.cpp index 5c0b63edfe9e..9c5ca9095575 100644 --- a/src/ScheduleFunctions.cpp +++ b/src/ScheduleFunctions.cpp @@ -126,8 +126,8 @@ Stmt substitute_in(const string &name, const Expr &value, bool calls, bool provi class AddPredicates : public IRGraphMutator { const Expr &cond; - bool calls; - bool provides; + const Function &func; + ApplySplitResult::Type type; using IRMutator::visit; @@ -135,7 +135,13 @@ class AddPredicates : public IRGraphMutator { auto [args, changed_args] = mutate_with_changes(p->args); auto [values, changed_values] = mutate_with_changes(p->values); Expr predicate = mutate(p->predicate); - if (provides) { + if (type == ApplySplitResult::BlendProvides) { + int idx = 0; + for (Expr &v : values) { + v = select(cond, v, Call::make(func, args, idx++)); + } + return Provide::make(p->name, values, args, predicate); + } else if (type == ApplySplitResult::PredicateProvides) { + return Provide::make(p->name, values, args, predicate && cond); + } else if (changed_args || changed_values || !predicate.same_as(p->predicate)) { + return Provide::make(p->name, values, args, predicate); @@ 
-146,20 +152,20 @@ class AddPredicates : public IRGraphMutator { Expr visit(const Call *op) override { Expr result = IRMutator::visit(op); - if (calls && op->call_type == Call::Halide) { + if (type == ApplySplitResult::PredicateCalls && op->call_type == Call::Halide) { result = Call::make(op->type, Call::if_then_else, {cond, result}, Call::PureIntrinsic); } return result; } public: - AddPredicates(const Expr &cond, bool calls, bool provides) - : cond(cond), calls(calls), provides(provides) { + AddPredicates(const Expr &cond, const Function &func, ApplySplitResult::Type type) + : cond(cond), func(func), type(type) { } }; -Stmt add_predicates(const Expr &cond, bool calls, bool provides, const Stmt &s) { - return AddPredicates(cond, calls, provides).mutate(s); +Stmt add_predicates(const Expr &cond, const Function &func, ApplySplitResult::Type type, const Stmt &s) { + return AddPredicates(cond, func, type).mutate(s); } // Build a loop nest about a provide node using a schedule @@ -227,10 +233,10 @@ Stmt build_loop_nest( stmt = substitute_in(res.name, res.value, true, false, stmt); } else if (res.is_substitution_in_provides()) { stmt = substitute_in(res.name, res.value, false, true, stmt); - } else if (res.is_predicate_calls()) { - stmt = add_predicates(res.value, true, false, stmt); - } else if (res.is_predicate_provides()) { - stmt = add_predicates(res.value, false, true, stmt); + } else if (res.is_blend_provides() || + res.is_predicate_calls() || + res.is_predicate_provides()) { + stmt = add_predicates(res.value, func, res.type, stmt); } else if (res.is_let()) { stmt = LetStmt::make(res.name, res.value, stmt); } else { diff --git a/src/Serialization.cpp b/src/Serialization.cpp index 2928e3b7ebbf..a945bd1fd0a9 100644 --- a/src/Serialization.cpp +++ b/src/Serialization.cpp @@ -320,6 +320,10 @@ Serialize::TailStrategy Serializer::serialize_tail_strategy(const TailStrategy & return Serialize::TailStrategy::PredicateStores; case TailStrategy::ShiftInwards: return 
Serialize::TailStrategy::ShiftInwards; + case TailStrategy::ShiftInwardsAndBlend: + return Serialize::TailStrategy::ShiftInwardsAndBlend; + case TailStrategy::RoundUpAndBlend: + return Serialize::TailStrategy::RoundUpAndBlend; case TailStrategy::Auto: return Serialize::TailStrategy::Auto; default: diff --git a/src/halide_ir.fbs b/src/halide_ir.fbs index 479e488b6739..e9e5f947f2ed 100644 --- a/src/halide_ir.fbs +++ b/src/halide_ir.fbs @@ -527,6 +527,8 @@ enum TailStrategy: ubyte { PredicateLoads, PredicateStores, ShiftInwards, + ShiftInwardsAndBlend, + RoundUpAndBlend, Auto, } diff --git a/test/correctness/nested_tail_strategies.cpp b/test/correctness/nested_tail_strategies.cpp index 2a0ddc7a6bf8..a1f59d30c0bb 100644 --- a/test/correctness/nested_tail_strategies.cpp +++ b/test/correctness/nested_tail_strategies.cpp @@ -19,10 +19,12 @@ void my_free(JITUserContext *user_context, void *ptr) { void check(Func out, int line, std::vector tails) { bool has_round_up = std::find(tails.begin(), tails.end(), TailStrategy::RoundUp) != tails.end() || + std::find(tails.begin(), tails.end(), TailStrategy::RoundUpAndBlend) != tails.end() || std::find(tails.begin(), tails.end(), TailStrategy::PredicateLoads) != tails.end() || std::find(tails.begin(), tails.end(), TailStrategy::PredicateStores) != tails.end(); bool has_shift_inwards = - std::find(tails.begin(), tails.end(), TailStrategy::ShiftInwards) != tails.end(); + std::find(tails.begin(), tails.end(), TailStrategy::ShiftInwards) != tails.end() || + std::find(tails.begin(), tails.end(), TailStrategy::ShiftInwardsAndBlend) != tails.end(); std::vector sizes_to_try; @@ -68,6 +70,12 @@ int main(int argc, char **argv) { return 0; } + // We'll randomly subsample these tests, because otherwise there are too many of them. + std::mt19937 rng(0); + int seed = argc > 1 ? 
atoi(argv[1]) : time(nullptr); + rng.seed(seed); + std::cout << "Nested tail strategies seed: " << seed << "\n"; + // Test random compositions of tail strategies in simple // producer-consumer pipelines. The bounds being tight sometimes // depends on the simplifier being able to cancel out things. @@ -76,7 +84,8 @@ int main(int argc, char **argv) { TailStrategy::RoundUp, TailStrategy::GuardWithIf, TailStrategy::ShiftInwards, - }; + TailStrategy::RoundUpAndBlend, + TailStrategy::ShiftInwardsAndBlend}; TailStrategy innermost_tails[] = { TailStrategy::RoundUp, @@ -84,7 +93,8 @@ int main(int argc, char **argv) { TailStrategy::PredicateLoads, TailStrategy::PredicateStores, TailStrategy::ShiftInwards, - }; + TailStrategy::RoundUpAndBlend, + TailStrategy::ShiftInwardsAndBlend}; // Two stages. First stage computed at tiles of second. for (auto t1 : innermost_tails) { @@ -110,6 +120,10 @@ int main(int argc, char **argv) { for (auto t1 : innermost_tails) { for (auto t2 : innermost_tails) { for (auto t3 : innermost_tails) { + if ((rng() & 7) != 0) { + continue; + } + Func in("in"), f("f"), g("g"), h("h"); Var x; @@ -134,6 +148,10 @@ int main(int argc, char **argv) { for (auto t1 : tails) { for (auto t2 : innermost_tails) { for (auto t3 : innermost_tails) { + if ((rng() & 7) != 0) { + continue; + } + Func in, f, g, h; Var x; @@ -158,8 +176,12 @@ int main(int argc, char **argv) { // (but can handle smaller outputs). for (auto t1 : innermost_tails) { for (auto t2 : tails) { - for (auto t3 : tails) { // Not innermost_tails because of n^4 complexity here. 
+ for (auto t3 : innermost_tails) { for (auto t4 : tails) { + if ((rng() & 63) != 0) { + continue; + } + Func in("in"), f("f"), g("g"), h("h"); Var x; diff --git a/test/error/CMakeLists.txt b/test/error/CMakeLists.txt index 440851b521cb..ace9247056d1 100644 --- a/test/error/CMakeLists.txt +++ b/test/error/CMakeLists.txt @@ -92,7 +92,9 @@ tests(GROUPS error reuse_var_in_schedule.cpp reused_args.cpp rfactor_inner_dim_non_commutative.cpp + round_up_and_blend_race.cpp run_with_large_stack_throws.cpp + shift_inwards_and_blend_race.cpp specialize_fail.cpp split_inner_wrong_tail_strategy.cpp split_non_innermost_predicated.cpp diff --git a/test/error/round_up_and_blend_race.cpp b/test/error/round_up_and_blend_race.cpp new file mode 100644 index 000000000000..72244c0a6e8b --- /dev/null +++ b/test/error/round_up_and_blend_race.cpp @@ -0,0 +1,23 @@ +#include "Halide.h" +#include + +using namespace Halide; + +int main(int argc, char **argv) { + + Func f; + Var x; + + f(x) = 0; + f(x) += 4; + + // This schedule should be forbidden, because it causes a race condition. + Var xo, xi; + f.update() + .split(x, xo, xi, 8, TailStrategy::RoundUp) + .vectorize(xi, 16, TailStrategy::RoundUpAndBlend) // Access beyond the end of each slice + .parallel(xo); + + printf("Success!\n"); + return 0; +} diff --git a/test/error/shift_inwards_and_blend_race.cpp b/test/error/shift_inwards_and_blend_race.cpp new file mode 100644 index 000000000000..67b4d9a6bcf1 --- /dev/null +++ b/test/error/shift_inwards_and_blend_race.cpp @@ -0,0 +1,19 @@ +#include "Halide.h" +#include + +using namespace Halide; + +int main(int argc, char **argv) { + + Func f; + Var x; + + f(x) = 0; + f(x) += 4; + + // This schedule should be forbidden, because it causes a race condition. 
+ f.update().vectorize(x, 8, TailStrategy::ShiftInwardsAndBlend).parallel(x); + + printf("Success!\n"); + return 0; +} diff --git a/test/performance/CMakeLists.txt b/test/performance/CMakeLists.txt index f47e92d6436b..1fecb06d0195 100644 --- a/test/performance/CMakeLists.txt +++ b/test/performance/CMakeLists.txt @@ -7,6 +7,7 @@ endif() tests(GROUPS performance SOURCES async_gpu.cpp + blend_tail_strategies.cpp block_transpose.cpp boundary_conditions.cpp clamped_vector_load.cpp diff --git a/test/performance/blend_tail_strategies.cpp b/test/performance/blend_tail_strategies.cpp new file mode 100644 index 000000000000..fa6a6f03d8c4 --- /dev/null +++ b/test/performance/blend_tail_strategies.cpp @@ -0,0 +1,93 @@ +#include "Halide.h" +#include "halide_benchmark.h" + +using namespace Halide; +using namespace Halide::Tools; + +int main(int argc, char **argv) { + Var x("x"), y("y"); + + Target t = get_jit_target_from_environment(); + + // Make sure we don't have predicated instructions available + if ((t.arch != Target::X86 && t.arch != Target::ARM) || + t.has_feature(Target::AVX512) || + t.has_feature(Target::SVE)) { + printf("[SKIP] This is a test for architectures without predication. " + "Currently we only test x86 before AVX-512 and ARM without SVE\n"); + return 0; + } + + const int N = t.natural_vector_size() * 2; + const int reps = 1024 * 128; + + Buffer output_buf(N - 1, N - 1); + Buffer correct_output; + + std::map times; + for (auto ts : {TailStrategy::GuardWithIf, + TailStrategy::RoundUp, + TailStrategy::ShiftInwardsAndBlend, + TailStrategy::RoundUpAndBlend}) { + Func f, g; + f(x, y) = cast(x + y); + RDom r(0, reps); + f(x, y) = f(x, y) * 3 + cast(0 * r); + g(x, y) = f(x, y); + + f.compute_root() + .update() + .reorder(x, y, r) + .vectorize(x, N / 2, ts); + + if (ts == TailStrategy::ShiftInwardsAndBlend) { + // Hide the stall from a load that overlaps the previous store by + // doing multiple scanlines at once. 
We expect the tail in y might + // be large, so force partitioning of x even in the loop tail in y. + f.update() + .reorder(y, x) + .unroll(y, 8, TailStrategy::GuardWithIf) + .reorder(x, y) + .partition(x, Partition::Always); + } + + g.compile_jit(); + // Uncomment to see the assembly + // g.compile_to_assembly("/dev/stdout", {}, "f", t); + double t = benchmark([&]() { + g.realize(output_buf); + }); + + // Check correctness + if (ts == TailStrategy::GuardWithIf) { + correct_output = output_buf.copy(); + } else { + for (int y = 0; y < output_buf.height(); y++) { + for (int x = 0; x < output_buf.width(); x++) { + if (output_buf(x, y) != correct_output(x, y)) { + printf("output_buf(%d, %d) = %d instead of %d\n", + x, y, output_buf(x, y), correct_output(x, y)); + } + } + } + } + times[ts] = t; + } + + for (auto p : times) { + std::cout << p.first << " " << p.second << "\n"; + } + + if (times[TailStrategy::GuardWithIf] < times[TailStrategy::ShiftInwardsAndBlend]) { + printf("ShiftInwardsAndBlend is slower than it should be\n"); + return 1; + } + + if (times[TailStrategy::GuardWithIf] < times[TailStrategy::RoundUpAndBlend]) { + printf("RoundUpAndBlend is slower than it should be\n"); + return 1; + } + + printf("Success!\n"); + return 0; +}