Scheduling directive to support ring buffering (#7967)
* Half-plumbed

* Revert "Half-plumbed"

This reverts commit eb9dd02.

* Interface for double buffer

* Update Provides, Calls and Realizes for double buffering

* Proper sync for double buffering

* Use proper name for the semaphore and use correct initial value

* Rename the class

* Pass expression for index

* Adds storage for double buffering index

* Use a separate index to go through the double buffer

* Failing test

* Better handling of hoisted storage in all of the async-related passes

* New test and clean-up the generated IR

* More tests

* Allow double buffering without async and add corresponding test

* Filter out incorrect double_buffer schedules

* Add tests to the cmake files

* Clean up

* Update the comment

* Clean up

* Clean up

* Update serialization

* complete_x86_target() should enable F16C and FMA when AVX2 is present (#7971)

All known AVX2-enabled architectures definitely have these features.

* Add two new tail strategies for update definitions (#7949)

* Add two new tail strategies for update definitions

* Stop printing asm

* Update expected number of partitions for Partition::Always

* Add a comment explaining why the blend safety check is per dimension

* Add serialization support for the new tail strategies

* trigger buildbots

* Add comment

---------

Co-authored-by: Steven Johnson <srj@google.com>

* Add appropriate mattrs for arm-32 extensions (#7978)

* Add appropriate mattrs for arm-32 extensions

Fixes #7976

* Pull clauses out of if

* Move canonical version numbers into source, not build system (#7980) (#7981)

* Move canonical version numbers into source, not build system (#7980)

* Fixes

* Silence useless "Insufficient parallelism" autoscheduler warning (#7990)

* Add a notebook with a visualization of the approx_* functions and their errors (#7974)

* Add a notebook with a visualization of the approx_* functions and their errors

* Fix spelling error

* Make narrowing float->int casts on wasm go via wider ints (#7973)

Fixes #7972

* Fix handling of assert statements whose conditions get vectorized (#7989)

* Fix handling of assert statements whose conditions get vectorized

* Fix test name

* Fix all "unscheduled update()" warnings in our code (#7991)

* Fix all "unscheduled update()" warnings in our code

And also fix the Mullapudi scheduler to explicitly touch all update stages. This allows us to mark this warning as an error if we so choose.

* fixes

* fixes

* Update recursive_box_filters.cpp

* Silence useless 'Outer dim vectorization of var' warning in Mullapudi scheduler (#7992)

Silence useless 'Outer dim vectorization of var' warning in Mullapudi scheduler

* Add a tutorial for async and double_buffer

* Renamed double_buffer to ring_buffer

* ring_buffer() now expects an extent Expr

* Actually use extent for ring_buffer()

* Address some of the comments

* Provide an example of the code structure for producer-consumer async example

* Comments updates

* Fix clang-format and clang-tidy

* Add Python binding for Func::ring_buffer()

* Don't use a separate index for ring buffer + add a new test

* Rename the tests

* Clean up the old name

* Add &

* Move test to the right folder

* Move expr

* Add comments for InjectRingBuffering

* Improve ring_buffer doc

* Fix comments

* Comments

* A better error message

* Mention that extent is expected to be a positive integer

* Add another code structure and explain how the indices for ring buffer are computed

* Expand test comments

* Fix spelling

---------

Co-authored-by: Steven Johnson <srj@google.com>
Co-authored-by: Andrew Adams <andrew.b.adams@gmail.com>
3 people authored Dec 19, 2023
1 parent 6bcb695 commit 61b8d38
Showing 17 changed files with 1,045 additions and 48 deletions.
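
For context, here is a minimal sketch of how the new directive composes with existing scheduling calls. The Funcs and the exact schedule are illustrative, not taken from this commit; ring_buffer() requires hoisted storage and, per the commits above, works with or without async():

Func producer("producer"), consumer("consumer");
Var x, y;
producer(x, y) = x + y;
consumer(x, y) = producer(x, y) * 2;

// Compute the producer one row of the consumer at a time, but hoist its
// storage to the top of the pipeline and keep two rows in flight:
// ring_buffer(2) adds one storage dimension of extent 2, and producer and
// consumer hand slices back and forth through a semaphore initialized to 2.
producer.compute_at(consumer, y)
        .hoist_storage(consumer, Var::outermost())
        .ring_buffer(2)
        .async();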
1 change: 1 addition & 0 deletions python_bindings/src/halide/halide_/PyFunc.cpp
@@ -213,6 +213,7 @@ void define_func(py::module &m) {
.def("store_at", (Func & (Func::*)(LoopLevel)) & Func::store_at, py::arg("loop_level"))

.def("async_", &Func::async)
.def("ring_buffer", &Func::ring_buffer)
.def("bound_storage", &Func::bound_storage)
.def("memoize", &Func::memoize)
.def("compute_inline", &Func::compute_inline)
297 changes: 252 additions & 45 deletions src/AsyncProducers.cpp
@@ -73,6 +73,15 @@ class NoOpCollapsingMutator : public IRMutator {
}
}

Stmt visit(const HoistedStorage *op) override {
Stmt body = mutate(op->body);
if (is_no_op(body)) {
return body;
} else {
return HoistedStorage::make(op->name, body);
}
}

Stmt visit(const Allocate *op) override {
Stmt body = mutate(op->body);
if (is_no_op(body)) {
@@ -198,6 +207,9 @@ class GenerateProducerBody : public NoOpCollapsingMutator {
if (starts_with(op->name, func + ".folding_semaphore.") && ends_with(op->name, ".head")) {
// This is a counter associated with the producer side of a storage-folding semaphore. Keep it.
return op;
} else if (starts_with(op->name, func + ".ring_buffer.")) {
// This is a counter associated with the producer side of a ring buffer. Keep it.
return op;
} else {
return Evaluate::make(0);
}
@@ -243,8 +255,42 @@ class GenerateProducerBody : public NoOpCollapsingMutator {
return op;
}

Stmt visit(const Allocate *op) override {
Stmt body = mutate(op->body);
if (is_no_op(body)) {
return body;
} else {
return Allocate::make(op->name, op->type, op->memory_type,
op->extents, op->condition, body,
op->new_expr, op->free_function, op->padding);
}
}

Stmt visit(const Realize *op) override {
Stmt body = mutate(op->body);
if (is_no_op(body)) {
return body;
} else {
inner_realizes.insert(op->name);
return Realize::make(op->name, op->types, op->memory_type,
op->bounds, op->condition, body);
}
}

Stmt visit(const HoistedStorage *op) override {
Stmt body = mutate(op->body);
if (is_no_op(body)) {
return body;
} else if (inner_realizes.count(op->name) == 0) {
return body;
} else {
return HoistedStorage::make(op->name, body);
}
}

map<string, vector<string>> &cloned_acquires;
set<string> inner_semaphores;
set<string> inner_realizes;

public:
GenerateProducerBody(const string &f, const vector<Expr> &s, map<string, vector<string>> &a)
@@ -363,57 +409,78 @@ class ForkAsyncProducers : public IRMutator {
const map<string, Function> &env;

map<string, vector<string>> cloned_acquires;

Stmt visit(const Realize *op) override {
auto it = env.find(op->name);
internal_assert(it != env.end());
Function f = it->second;
if (f.schedule().async()) {
Stmt body = op->body;

// Make two copies of the body, one which only does the
// producer, and one which only does the consumer. Inject
// synchronization to preserve dependencies. Put them in a
// task-parallel block.

// Make a semaphore per consume node
CountConsumeNodes consumes(op->name);
body.accept(&consumes);

vector<string> sema_names;
vector<Expr> sema_vars;
for (int i = 0; i < consumes.count; i++) {
sema_names.push_back(op->name + ".semaphore_" + std::to_string(i));
sema_vars.push_back(Variable::make(type_of<halide_semaphore_t *>(), sema_names.back()));
std::set<string> hoisted_storages;

Stmt process_body(const string &name, Stmt body) {
// Make two copies of the body, one which only does the
// producer, and one which only does the consumer. Inject
// synchronization to preserve dependencies. Put them in a
// task-parallel block.

// Make a semaphore per consume node
CountConsumeNodes consumes(name);
body.accept(&consumes);

vector<string> sema_names;
vector<Expr> sema_vars;
for (int i = 0; i < consumes.count; i++) {
sema_names.push_back(name + ".semaphore_" + std::to_string(i));
sema_vars.push_back(Variable::make(type_of<halide_semaphore_t *>(), sema_names.back()));
}

Stmt producer = GenerateProducerBody(name, sema_vars, cloned_acquires).mutate(body);
Stmt consumer = GenerateConsumerBody(name, sema_vars).mutate(body);

// Recurse on both sides
producer = mutate(producer);
consumer = mutate(consumer);

// Run them concurrently
body = Fork::make(producer, consumer);

for (const string &sema_name : sema_names) {
// Make a semaphore on the stack
Expr sema_space = Call::make(type_of<halide_semaphore_t *>(), "halide_make_semaphore",
{0}, Call::Extern);

// If there's a nested async producer, we may have
// recursively cloned this semaphore inside the mutation
// of the producer and consumer.
const vector<string> &clones = cloned_acquires[sema_name];
for (const auto &i : clones) {
body = CloneAcquire(sema_name, i).mutate(body);
body = LetStmt::make(i, sema_space, body);
}

Stmt producer = GenerateProducerBody(op->name, sema_vars, cloned_acquires).mutate(body);
Stmt consumer = GenerateConsumerBody(op->name, sema_vars).mutate(body);

// Recurse on both sides
producer = mutate(producer);
consumer = mutate(consumer);

// Run them concurrently
body = Fork::make(producer, consumer);
body = LetStmt::make(sema_name, sema_space, body);
}

for (const string &sema_name : sema_names) {
// Make a semaphore on the stack
Expr sema_space = Call::make(type_of<halide_semaphore_t *>(), "halide_make_semaphore",
{0}, Call::Extern);
return body;
}

// If there's a nested async producer, we may have
// recursively cloned this semaphore inside the mutation
// of the producer and consumer.
const vector<string> &clones = cloned_acquires[sema_name];
for (const auto &i : clones) {
body = CloneAcquire(sema_name, i).mutate(body);
body = LetStmt::make(i, sema_space, body);
}
Stmt visit(const HoistedStorage *op) override {
hoisted_storages.insert(op->name);
Stmt body = op->body;

body = LetStmt::make(sema_name, sema_space, body);
}
auto it = env.find(op->name);
internal_assert(it != env.end());
Function f = it->second;
if (f.schedule().async() && f.schedule().ring_buffer().defined()) {
body = process_body(op->name, body);
} else {
body = mutate(body);
}
hoisted_storages.erase(op->name);
return HoistedStorage::make(op->name, body);
}

Stmt visit(const Realize *op) override {
auto it = env.find(op->name);
internal_assert(it != env.end());
Function f = it->second;
if (f.schedule().async() && hoisted_storages.count(op->name) == 0) {
Stmt body = op->body;
body = process_body(op->name, body);
return Realize::make(op->name, op->types, op->memory_type,
op->bounds, op->condition, body);
} else {
@@ -592,6 +659,117 @@ class TightenProducerConsumerNodes : public IRMutator {
}
};

// Update indices of Provides and Calls to access the ring buffer.
class UpdateIndices : public IRMutator {
using IRMutator::visit;

Stmt visit(const Provide *op) override {
if (op->name == func_name) {
std::vector<Expr> args = op->args;
args.push_back(ring_buffer_index);
return Provide::make(op->name, op->values, args, op->predicate);
}
return IRMutator::visit(op);
}

Expr visit(const Call *op) override {
if (op->call_type == Call::Halide && op->name == func_name) {
std::vector<Expr> args = op->args;
args.push_back(ring_buffer_index);
return Call::make(op->type, op->name, args, op->call_type, op->func, op->value_index, op->image, op->param);
}
return IRMutator::visit(op);
}

std::string func_name;
Expr ring_buffer_index;

public:
UpdateIndices(const string &fn, Expr di)
: func_name(fn), ring_buffer_index(std::move(di)) {
}
};
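
// Editorial illustration (not part of this file): given ring_buffer_index r,
// a store such as f(x, y) = ... becomes f(x, y, r) = ..., and a load f(x, y)
// becomes f(x, y, r), matching the extra storage dimension appended by
// InjectRingBuffering below.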

// Inject ring buffering.
class InjectRingBuffering : public IRMutator {
using IRMutator::visit;

struct Loop {
std::string name;
Expr min;
Expr extent;

Loop(std::string n, Expr m, Expr e)
: name(std::move(n)), min(std::move(m)), extent(std::move(e)) {
}
};

const map<string, Function> &env;
std::vector<Loop> loops;
std::map<std::string, int> hoist_storage_loop_index;

Stmt visit(const Realize *op) override {
Stmt body = mutate(op->body);
Function f = env.find(op->name)->second;
Region bounds = op->bounds;
if (f.schedule().ring_buffer().defined()) {
// For ring buffering, we expand the storage by adding another dimension with
// the range [0, ring_buffer.extent).
bounds.emplace_back(0, f.schedule().ring_buffer());
// Build an index for accessing the ring buffer as a linear combination of all
// loop variables between the storage location (defined by the HoistedStorage
// loop level) and the corresponding Realize node.
int loop_index = hoist_storage_loop_index[op->name] + 1;
Expr current_index = Variable::make(Int(32), loops[loop_index].name);
while (++loop_index < (int)loops.size()) {
current_index = current_index *
(loops[loop_index].extent - loops[loop_index].min) +
Variable::make(Int(32), loops[loop_index].name);
}
current_index = current_index % f.schedule().ring_buffer();
// Add an extra index to all of the references of f.
body = UpdateIndices(op->name, current_index).mutate(body);
Expr sema_var = Variable::make(type_of<halide_semaphore_t *>(), f.name() + ".folding_semaphore.ring_buffer");
Expr release_producer = Call::make(Int(32), "halide_semaphore_release", {sema_var, 1}, Call::Extern);
Stmt release = Evaluate::make(release_producer);
body = Block::make(body, release);
body = Acquire::make(sema_var, 1, body);
}

return Realize::make(op->name, op->types, op->memory_type, bounds, op->condition, body);
}
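
// Editorial illustration (not part of this file): if the storage is hoisted
// above two loops j (min 0, extent J) and k (min 0, extent K), the index built
// above is (j * K + k) % extent, i.e. a row-major flattening of the loop nest
// between the HoistedStorage node and this Realize, wrapped around by the
// ring buffer's extent.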

Stmt visit(const HoistedStorage *op) override {
// Store the index of the last loop we encountered.
hoist_storage_loop_index[op->name] = loops.size() - 1;
Function f = env.find(op->name)->second;

Stmt mutated = mutate(op->body);
mutated = HoistedStorage::make(op->name, mutated);

if (f.schedule().ring_buffer().defined()) {
// Make a semaphore on the stack
Expr sema_space = Call::make(type_of<halide_semaphore_t *>(), "halide_make_semaphore",
{2}, Call::Extern);
mutated = LetStmt::make(f.name() + std::string(".folding_semaphore.ring_buffer"), sema_space, mutated);
}
hoist_storage_loop_index.erase(op->name);
return mutated;
}

Stmt visit(const For *op) override {
loops.emplace_back(op->name, op->min, op->extent);
Stmt mutated = IRMutator::visit(op);
loops.pop_back();
return mutated;
}

public:
InjectRingBuffering(const map<string, Function> &e)
: env(e) {
}
};

// Broaden the scope of acquire nodes to pack trailing work into the
// same task and to potentially reduce the nesting depth of tasks.
class ExpandAcquireNodes : public IRMutator {
@@ -639,6 +817,18 @@ class ExpandAcquireNodes : public IRMutator {
}
}

Stmt visit(const HoistedStorage *op) override {
Stmt body = mutate(op->body);
if (const Acquire *a = body.as<Acquire>()) {
// Don't do the allocation until we have the
// semaphore. Reduces peak memory use.
return Acquire::make(a->semaphore, a->count,
mutate(HoistedStorage::make(op->name, a->body)));
} else {
return HoistedStorage::make(op->name, body);
}
}

Stmt visit(const LetStmt *op) override {
Stmt orig = op;
Stmt body;
@@ -693,6 +883,9 @@ class TightenForkNodes : public IRMutator {
const LetStmt *lr = rest.as<LetStmt>();
const Realize *rf = first.as<Realize>();
const Realize *rr = rest.as<Realize>();
const HoistedStorage *hf = first.as<HoistedStorage>();
const HoistedStorage *hr = rest.as<HoistedStorage>();

if (lf && lr &&
lf->name == lr->name &&
equal(lf->value, lr->value)) {
@@ -707,6 +900,10 @@
} else if (rr && !stmt_uses_var(first, rr->name)) {
return Realize::make(rr->name, rr->types, rr->memory_type,
rr->bounds, rr->condition, make_fork(first, rr->body));
} else if (hf && !stmt_uses_var(rest, hf->name)) {
return HoistedStorage::make(hf->name, make_fork(hf->body, rest));
} else if (hr && !stmt_uses_var(first, hr->name)) {
return HoistedStorage::make(hr->name, make_fork(first, hr->body));
} else {
return Fork::make(first, rest);
}
@@ -740,6 +937,15 @@
}
}

Stmt visit(const HoistedStorage *op) override {
Stmt body = mutate(op->body);
if (in_fork && !stmt_uses_var(body, op->name)) {
return body;
} else {
return HoistedStorage::make(op->name, body);
}
}

Stmt visit(const LetStmt *op) override {
Stmt body = mutate(op->body);
if (in_fork && !stmt_uses_var(body, op->name)) {
@@ -758,6 +964,7 @@

Stmt fork_async_producers(Stmt s, const map<string, Function> &env) {
s = TightenProducerConsumerNodes(env).mutate(s);
s = InjectRingBuffering(env).mutate(s);
s = ForkAsyncProducers(env).mutate(s);
s = ExpandAcquireNodes().mutate(s);
s = TightenForkNodes().mutate(s);
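The synchronization this pass builds around the producer can be pictured with a small, self-contained sketch in plain C++20. This is an illustration of the acquire/release pattern above, not Halide runtime code; the slice count and workload are made up:

#include <array>
#include <cstdio>
#include <semaphore>
#include <thread>

int main() {
    constexpr int ring_extent = 2;  // as set by f.ring_buffer(2)
    constexpr int iterations = 8;

    // Counts free slices; initialized to the ring extent, like the
    // halide_make_semaphore(2) call injected above.
    std::counting_semaphore<ring_extent> free_slices(ring_extent);
    // Counts produced-but-unconsumed slices.
    std::counting_semaphore<iterations> filled_slices(0);
    std::array<int, ring_extent> ring{};

    std::thread producer([&] {
        for (int i = 0; i < iterations; i++) {
            free_slices.acquire();          // the Acquire wrapped around the body
            ring[i % ring_extent] = i * i;  // produce into slice i % ring_extent
            filled_slices.release();
        }
    });

    for (int i = 0; i < iterations; i++) {
        filled_slices.acquire();
        std::printf("slice %d -> %d\n", i % ring_extent, ring[i % ring_extent]);
        free_slices.release();  // the halide_semaphore_release after consumption
    }
    producer.join();
}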
2 changes: 2 additions & 0 deletions src/Deserialization.cpp
@@ -1017,6 +1017,7 @@ FuncSchedule Deserializer::deserialize_func_schedule(const Serialize::FuncSchedu
const auto memory_type = deserialize_memory_type(func_schedule->memory_type());
const auto memoized = func_schedule->memoized();
const auto async = func_schedule->async();
const auto ring_buffer = deserialize_expr(func_schedule->ring_buffer_type(), func_schedule->ring_buffer());
const auto memoize_eviction_key = deserialize_expr(func_schedule->memoize_eviction_key_type(), func_schedule->memoize_eviction_key());
auto hl_func_schedule = FuncSchedule();
hl_func_schedule.store_level() = store_level;
@@ -1029,6 +1030,7 @@
hl_func_schedule.memory_type() = memory_type;
hl_func_schedule.memoized() = memoized;
hl_func_schedule.async() = async;
hl_func_schedule.ring_buffer() = ring_buffer;
hl_func_schedule.memoize_eviction_key() = memoize_eviction_key;
return hl_func_schedule;
}