diff --git a/Makefile b/Makefile index 5c5a3593ee49..b15ddee4147d 100644 --- a/Makefile +++ b/Makefile @@ -470,6 +470,7 @@ SOURCE_FILES = \ BoundaryConditions.cpp \ Bounds.cpp \ BoundsInference.cpp \ + BoundConstantExtentLoops.cpp \ BoundSmallAllocations.cpp \ Buffer.cpp \ Callable.cpp \ @@ -665,6 +666,7 @@ HEADER_FILES = \ BoundaryConditions.h \ Bounds.h \ BoundsInference.h \ + BoundConstantExtentLoops.h \ BoundSmallAllocations.h \ Buffer.h \ Callable.h \ diff --git a/src/BoundConstantExtentLoops.cpp b/src/BoundConstantExtentLoops.cpp new file mode 100644 index 000000000000..d2901854f6eb --- /dev/null +++ b/src/BoundConstantExtentLoops.cpp @@ -0,0 +1,136 @@ +#include "BoundConstantExtentLoops.h" +#include "Bounds.h" +#include "CSE.h" +#include "IRMutator.h" +#include "IROperator.h" +#include "Simplify.h" +#include "SimplifyCorrelatedDifferences.h" +#include "Substitute.h" + +namespace Halide { +namespace Internal { + +namespace { +class BoundLoops : public IRMutator { + using IRMutator::visit; + + std::vector> lets; + + Stmt visit(const LetStmt *op) override { + if (is_pure(op->value)) { + lets.emplace_back(op->name, op->value); + Stmt s = IRMutator::visit(op); + lets.pop_back(); + return s; + } else { + return IRMutator::visit(op); + } + } + + std::vector facts; + Stmt visit(const IfThenElse *op) override { + facts.push_back(op->condition); + Stmt then_case = mutate(op->then_case); + Stmt else_case; + if (op->else_case.defined()) { + facts.back() = simplify(!op->condition); + else_case = mutate(op->else_case); + } + facts.pop_back(); + if (then_case.same_as(op->then_case) && + else_case.same_as(op->else_case)) { + return op; + } else { + return IfThenElse::make(op->condition, then_case, else_case); + } + } + + Stmt visit(const For *op) override { + if (is_const(op->extent)) { + // Nothing needs to be done + return IRMutator::visit(op); + } + + if (op->for_type == ForType::Unrolled || + op->for_type == ForType::Vectorized) { + // Give it one last chance to simplify to an int + Expr extent = simplify(op->extent); + Stmt body = op->body; + const IntImm *e = extent.as(); + + if (e == nullptr) { + // We're about to hard fail. Get really aggressive + // with the simplifier. + for (auto it = lets.rbegin(); it != lets.rend(); it++) { + extent = Let::make(it->first, it->second, extent); + } + extent = remove_likelies(extent); + extent = substitute_in_all_lets(extent); + extent = simplify(extent, + true, + Scope::empty_scope(), + Scope::empty_scope(), + facts); + e = extent.as(); + } + + Expr extent_upper; + if (e == nullptr) { + // Still no luck. Try taking an upper bound and + // injecting an if statement around the body. + extent_upper = find_constant_bound(extent, Direction::Upper, Scope()); + if (extent_upper.defined()) { + e = extent_upper.as(); + body = + IfThenElse::make(likely_if_innermost(Variable::make(Int(32), op->name) < + op->min + op->extent), + body); + } + } + + if (e == nullptr && permit_failed_unroll && op->for_type == ForType::Unrolled) { + // Still no luck, but we're allowed to fail. Rewrite + // to a serial loop. + user_warning << "HL_PERMIT_FAILED_UNROLL is allowing us to unroll a non-constant loop into a serial loop. Did you mean to do this?\n"; + body = mutate(body); + return For::make(op->name, op->min, op->extent, + ForType::Serial, op->partition_policy, op->device_api, std::move(body)); + } + + user_assert(e) + << "Can only " << (op->for_type == ForType::Unrolled ? "unroll" : "vectorize") + << " for loops over a constant extent.\n" + << "Loop over " << op->name << " has extent " << extent << ".\n"; + body = mutate(body); + + return For::make(op->name, op->min, e, + op->for_type, op->partition_policy, op->device_api, std::move(body)); + } else { + return IRMutator::visit(op); + } + } + bool permit_failed_unroll = false; + +public: + BoundLoops() { + // Experimental autoschedulers may want to unroll without + // being totally confident the loop will indeed turn out + // to be constant-sized. If this feature continues to be + // important, we need to expose it in the scheduling + // language somewhere, but how? For now we do something + // ugly and expedient. + + // For the tracking issue to fix this, see + // https://github.com/halide/Halide/issues/3479 + permit_failed_unroll = get_env_variable("HL_PERMIT_FAILED_UNROLL") == "1"; + } +}; + +} // namespace + +Stmt bound_constant_extent_loops(const Stmt &s) { + return BoundLoops().mutate(s); +} + +} // namespace Internal +} // namespace Halide diff --git a/src/BoundConstantExtentLoops.h b/src/BoundConstantExtentLoops.h new file mode 100644 index 000000000000..061064f795f9 --- /dev/null +++ b/src/BoundConstantExtentLoops.h @@ -0,0 +1,24 @@ +#ifndef HALIDE_BOUND_CONSTANT_EXTENT_LOOPS_H +#define HALIDE_BOUND_CONSTANT_EXTENT_LOOPS_H + +/** \file + * Defines the lowering pass that enforces a constant extent on all + * vectorized or unrolled loops. + */ + +#include "Expr.h" + +namespace Halide { +namespace Internal { + +/** Replace all loop extents of unrolled or vectorized loops with constants, by + * substituting and simplifying as needed. If we can't determine a constant + * extent, but can determine a constant upper bound, inject an if statement into + * the body. If we can't even determine a constant upper bound, throw a user + * error. */ +Stmt bound_constant_extent_loops(const Stmt &s); + +} // namespace Internal +} // namespace Halide + +#endif diff --git a/src/BoundsInference.cpp b/src/BoundsInference.cpp index d8a1ff53cc37..31b441ea4251 100644 --- a/src/BoundsInference.cpp +++ b/src/BoundsInference.cpp @@ -1013,11 +1013,11 @@ class BoundsInference : public IRMutator { } // Dump out the region required of each stage for debugging. - /* debug(0) << "Box required of " << producer.name << " by " << consumer.name - << " stage " << consumer.stage << ":\n"; + << " stage " << consumer.stage << ":\n" + << " used: " << b.used << "\n"; for (size_t k = 0; k < b.size(); k++) { debug(0) << " " << b[k].min << " ... " << b[k].max << "\n"; } diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 74e44de3c163..e708d29c11fb 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -21,7 +21,8 @@ set(HEADER_FILES BoundaryConditions.h Bounds.h BoundsInference.h - BoundSmallAllocations.h + BoundConstantExtentLoops.h + BoundSmallAllocations.h Buffer.h Callable.h CanonicalizeGPUVars.h @@ -189,6 +190,7 @@ set(SOURCE_FILES BoundaryConditions.cpp Bounds.cpp BoundsInference.cpp + BoundConstantExtentLoops.cpp BoundSmallAllocations.cpp Buffer.cpp Callable.cpp diff --git a/src/Lower.cpp b/src/Lower.cpp index 67aedde288d0..37c4bac07efb 100644 --- a/src/Lower.cpp +++ b/src/Lower.cpp @@ -11,6 +11,7 @@ #include "AddParameterChecks.h" #include "AllocationBoundsInference.h" #include "AsyncProducers.h" +#include "BoundConstantExtentLoops.h" #include "BoundSmallAllocations.h" #include "Bounds.h" #include "BoundsInference.h" @@ -312,6 +313,10 @@ void lower_impl(const vector &output_funcs, s = simplify_correlated_differences(s); log("Lowering after simplifying correlated differences:", s); + debug(1) << "Bounding constant extent loops...\n"; + s = bound_constant_extent_loops(s); + log("Lowering after bounding constant extent loops:", s); + debug(1) << "Unrolling...\n"; s = unroll_loops(s); log("Lowering after unrolling:", s); diff --git a/src/Simplify.cpp b/src/Simplify.cpp index 7a2cbac5a047..339ef2917c83 100644 --- a/src/Simplify.cpp +++ b/src/Simplify.cpp @@ -355,8 +355,13 @@ Simplify::ScopedFact::~ScopedFact() { Expr simplify(const Expr &e, bool remove_dead_let_stmts, const Scope &bounds, - const Scope &alignment) { + const Scope &alignment, + const std::vector &assumptions) { Simplify m(remove_dead_let_stmts, &bounds, &alignment); + std::vector facts; + for (const Expr &a : assumptions) { + facts.push_back(m.scoped_truth(a)); + } Expr result = m.mutate(e, nullptr); if (m.in_unreachable) { return unreachable(e.type()); @@ -366,8 +371,13 @@ Expr simplify(const Expr &e, bool remove_dead_let_stmts, Stmt simplify(const Stmt &s, bool remove_dead_let_stmts, const Scope &bounds, - const Scope &alignment) { + const Scope &alignment, + const std::vector &assumptions) { Simplify m(remove_dead_let_stmts, &bounds, &alignment); + std::vector facts; + for (const Expr &a : assumptions) { + facts.push_back(m.scoped_truth(a)); + } Stmt result = m.mutate(s); if (m.in_unreachable) { return Evaluate::make(unreachable()); diff --git a/src/Simplify.h b/src/Simplify.h index 14dec65fc025..b9335c0c3de9 100644 --- a/src/Simplify.h +++ b/src/Simplify.h @@ -13,19 +13,22 @@ namespace Halide { namespace Internal { -/** Perform a a wide range of simplifications to expressions and - * statements, including constant folding, substituting in trivial - * values, arithmetic rearranging, etc. Simplifies across let - * statements, so must not be called on stmts with dangling or - * repeated variable names. +/** Perform a wide range of simplifications to expressions and statements, + * including constant folding, substituting in trivial values, arithmetic + * rearranging, etc. Simplifies across let statements, so must not be called on + * stmts with dangling or repeated variable names. Can optionally be passed + * known bounds of any variables, known alignment properties, and any other + * Exprs that should be assumed to be true. */ // @{ Stmt simplify(const Stmt &, bool remove_dead_code = true, const Scope &bounds = Scope::empty_scope(), - const Scope &alignment = Scope::empty_scope()); + const Scope &alignment = Scope::empty_scope(), + const std::vector &assumptions = std::vector()); Expr simplify(const Expr &, bool remove_dead_code = true, const Scope &bounds = Scope::empty_scope(), - const Scope &alignment = Scope::empty_scope()); + const Scope &alignment = Scope::empty_scope(), + const std::vector &assumptions = std::vector()); // @} /** Attempt to statically prove an expression is true using the simplifier. */ diff --git a/src/UnrollLoops.cpp b/src/UnrollLoops.cpp index e1726aa28ceb..2823c8b9ac9f 100644 --- a/src/UnrollLoops.cpp +++ b/src/UnrollLoops.cpp @@ -1,16 +1,10 @@ #include "UnrollLoops.h" -#include "Bounds.h" -#include "CSE.h" #include "IRMutator.h" #include "IROperator.h" #include "Simplify.h" -#include "SimplifyCorrelatedDifferences.h" #include "Substitute.h" #include "UniquifyVariableNames.h" -using std::pair; -using std::vector; - namespace Halide { namespace Internal { @@ -19,62 +13,13 @@ namespace { class UnrollLoops : public IRMutator { using IRMutator::visit; - vector> lets; - - Stmt visit(const LetStmt *op) override { - if (is_pure(op->value)) { - lets.emplace_back(op->name, op->value); - Stmt s = IRMutator::visit(op); - lets.pop_back(); - return s; - } else { - return IRMutator::visit(op); - } - } - Stmt visit(const For *for_loop) override { if (for_loop->for_type == ForType::Unrolled) { - // Give it one last chance to simplify to an int - Expr extent = simplify(for_loop->extent); Stmt body = for_loop->body; - const IntImm *e = extent.as(); - - if (e == nullptr) { - // We're about to hard fail. Get really aggressive - // with the simplifier. - for (auto it = lets.rbegin(); it != lets.rend(); it++) { - extent = Let::make(it->first, it->second, extent); - } - extent = remove_likelies(extent); - extent = substitute_in_all_lets(extent); - extent = simplify(extent); - e = extent.as(); - } + const IntImm *e = for_loop->extent.as(); - Expr extent_upper; - bool use_guard = false; - if (e == nullptr) { - // Still no luck. Try taking an upper bound and - // injecting an if statement around the body. - extent_upper = find_constant_bound(extent, Direction::Upper, Scope()); - if (extent_upper.defined()) { - e = extent_upper.as(); - use_guard = true; - } - } - - if (e == nullptr && permit_failed_unroll) { - // Still no luck, but we're allowed to fail. Rewrite - // to a serial loop. - user_warning << "HL_PERMIT_FAILED_UNROLL is allowing us to unroll a non-constant loop into a serial loop. Did you mean to do this?\n"; - body = mutate(body); - return For::make(for_loop->name, for_loop->min, for_loop->extent, - ForType::Serial, for_loop->partition_policy, for_loop->device_api, std::move(body)); - } - - user_assert(e) - << "Can only unroll for loops over a constant extent.\n" - << "Loop over " << for_loop->name << " has extent " << extent << ".\n"; + internal_assert(e) + << "Loop over " << for_loop->name << " should have had a constant extent\n"; body = mutate(body); if (e->value == 1) { @@ -94,9 +39,6 @@ class UnrollLoops : public IRMutator { } else { iters = Block::make(iter, iters); } - if (use_guard) { - iters = IfThenElse::make(likely_if_innermost(i < for_loop->extent), iters); - } } return iters; @@ -105,21 +47,6 @@ class UnrollLoops : public IRMutator { return IRMutator::visit(for_loop); } } - bool permit_failed_unroll = false; - -public: - UnrollLoops() { - // Experimental autoschedulers may want to unroll without - // being totally confident the loop will indeed turn out - // to be constant-sized. If this feature continues to be - // important, we need to expose it in the scheduling - // language somewhere, but how? For now we do something - // ugly and expedient. - - // For the tracking issue to fix this, see - // https://github.com/halide/Halide/issues/3479 - permit_failed_unroll = get_env_variable("HL_PERMIT_FAILED_UNROLL") == "1"; - } }; } // namespace diff --git a/src/VectorizeLoops.cpp b/src/VectorizeLoops.cpp index df116c841217..79229e33a144 100644 --- a/src/VectorizeLoops.cpp +++ b/src/VectorizeLoops.cpp @@ -951,7 +951,9 @@ class VectorSubs : public IRMutator { if (op->for_type == ForType::Vectorized) { const IntImm *extent_int = extent.as(); - if (!extent_int || extent_int->value <= 1) { + internal_assert(extent_int) + << "Vectorized for loop extent should have been rewritten to a constant\n"; + if (extent_int->value <= 1) { user_error << "Loop over " << op->name << " has extent " << extent << ". Can only vectorize loops over a " diff --git a/test/correctness/CMakeLists.txt b/test/correctness/CMakeLists.txt index 8fc403b298bb..88569236c106 100644 --- a/test/correctness/CMakeLists.txt +++ b/test/correctness/CMakeLists.txt @@ -318,6 +318,7 @@ tests(GROUPS correctness uninitialized_read.cpp unique_func_image.cpp unroll_dynamic_loop.cpp + unroll_loop_with_implied_constant_bounds.cpp unrolled_reduction.cpp unsafe_dedup_lets.cpp unsafe_promises.cpp @@ -335,6 +336,7 @@ tests(GROUPS correctness vectorize_nested.cpp vectorize_varying_allocation_size.cpp vectorized_gpu_allocation.cpp + vectorized_guard_with_if_tail.cpp vectorized_initialization.cpp vectorized_load_from_vectorized_allocation.cpp vectorized_reduction_bug.cpp diff --git a/test/correctness/unroll_loop_with_implied_constant_bounds.cpp b/test/correctness/unroll_loop_with_implied_constant_bounds.cpp new file mode 100644 index 000000000000..c38d59c5214a --- /dev/null +++ b/test/correctness/unroll_loop_with_implied_constant_bounds.cpp @@ -0,0 +1,54 @@ +#include "Halide.h" + +using namespace Halide; + +int main(int argc, char **argv) { + // This test verifies that unrolling/vectorizing is capable of inferring + // constant bounds of loops that are implied by containing if statement + // conditions, e.g the following structure should work: + + /* + let extent = foo + if (foo == 7) { + unrolled for (x from 0 to extent) {...} + } + */ + + for (int i = 0; i < 2; i++) { + Func intermediate("intermediate"); + + Func output1("output1"), output2("output2"); + + Var x("x"), y("y"), c("c"); + + intermediate(x, y, c) = x + y + c; + + output1(x, y, c) = intermediate(x, y, c); + output2(x, y, c) = intermediate(x, y, c); + + Expr three_channels = + (output1.output_buffer().dim(2).extent() == 3 && + output1.output_buffer().dim(2).min() == 0 && + output2.output_buffer().dim(2).extent() == 3 && + output2.output_buffer().dim(2).min() == 0); + + if (i == 0) { + intermediate.compute_root() + .specialize(three_channels) + .unroll(c); + } else { + intermediate.compute_root() + .specialize(three_channels) + .vectorize(c); + } + + Pipeline p{{output1, output2}}; + + // Should not throw an error in loop unrolling or vectorization. + p.compile_jit(); + } + + printf("Success!\n"); + + return 0; +} diff --git a/test/correctness/vectorized_guard_with_if_tail.cpp b/test/correctness/vectorized_guard_with_if_tail.cpp new file mode 100644 index 000000000000..62bf975d93f1 --- /dev/null +++ b/test/correctness/vectorized_guard_with_if_tail.cpp @@ -0,0 +1,42 @@ +#include "Halide.h" + +using namespace Halide; + +int main(int argc, char **argv) { + Var x; + + for (int i = 0; i < 2; i++) { + Func f, g; + f(x) = x; + g(x) = f(x) * 2; + + g.vectorize(x, 8, TailStrategy::GuardWithIf); + + f.compute_at(g, x); + + // A varying amount of f is required depending on if we're in the steady + // state of g or the tail. Nonetheless, the amount required has a constant + // upper bound of 8. Vectorization, unrolling, and variants of store_in that + // require constant extent should all be able to handle this. + if (i == 0) { + f.vectorize(x); + } else { + f.unroll(x); + } + f.store_in(MemoryType::Register); + + Buffer buf = g.realize({37}); + + for (int i = 0; i < buf.width(); i++) { + int correct = i * 2; + if (buf(i) != correct) { + printf("buf(%d) = %d instead of %d\n", + i, buf(i), correct); + return 1; + } + } + } + + printf("Success!\n"); + return 0; +}