Skip to content

Commit

Permalink
Merge branch 'main' into srj/mattrs-array
Browse files Browse the repository at this point in the history
  • Loading branch information
steven-johnson committed Dec 8, 2023
2 parents a730c6b + 19c1c81 commit 8f9fd0c
Show file tree
Hide file tree
Showing 18 changed files with 110 additions and 44 deletions.
17 changes: 8 additions & 9 deletions README_webassembly.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@ backend.
As WebAssembly itself is still under active development, Halide's support has
some limitations. Some of the most important:

- Sign-extension operations are enabled by default (but can be avoided via
Target::WasmMvpOnly).
- Non-trapping float-to-int conversions are enabled by default (but can be
avoided via Target::WasmMvpOnly).
- Fixed-width SIMD (128 bit) can be enabled via Target::WasmSimd128.
- Sign-extension operations can be enabled via Target::WasmSignExt.
- Non-trapping float-to-int conversions can be enabled via
Target::WasmSatFloatToInt.
- Threads have very limited support via Target::WasmThreads; see
[below](#using-threads) for more details.
- Halide's JIT for Wasm is extremely limited and really useful only for
Expand Down Expand Up @@ -152,9 +153,8 @@ cmake -DLLVM_ENABLE_PROJECTS="clang;lld" ...
```
- To run the JIT tests, set `HL_JIT_TARGET=wasm-32-wasmrt` (possibly adding
`wasm_simd128`, `wasm_signext`, and/or `wasm_sat_float_to_int`) and run
CMake/CTest normally. Note that wasm testing is only support under CMake
(not via Make).
`wasm_simd128`) and run CMake/CTest normally. Note that wasm testing is
only supported under CMake (not via Make).
## Enabling wasm AOT
Expand All @@ -165,9 +165,8 @@ will), you need to install Emscripten locally.
(https://emscripten.org/docs/getting_started/downloads.html).
- To run the AOT tests, set `HL_TARGET=wasm-32-wasmrt` (possibly adding
`wasm_simd128`, `wasm_signext`, and/or `wasm_sat_float_to_int`) and run
CMake/CTest normally. Note that wasm testing is only support under CMake
(not via Make).
`wasm_simd128`) and run CMake/CTest normally. Note that wasm testing is
only supported under CMake (not via Make).
# Running benchmarks
Expand Down
3 changes: 2 additions & 1 deletion apps/hist/hist_generator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@ class Hist : public Halide::Generator<Hist> {
.compute_at(hist_rows.in(), y)
.vectorize(x, vec);

hist_rows.update(0).unscheduled();
hist_rows.in()
.compute_root()
.vectorize(x, vec)
Expand All @@ -199,7 +200,7 @@ class Hist : public Halide::Generator<Hist> {
.parallel(x)
.reorder(ry, x);

cdf.compute_root();
cdf.compute_root().update().unscheduled();
output.reorder(c, x, y)
.bound(c, 0, 3)
.unroll(c)
Expand Down
2 changes: 2 additions & 0 deletions apps/iir_blur/iir_blur_generator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ Func blur_cols_transpose(Func input, Expr height, Expr alpha, bool skip_schedule
blur.compute_at(transpose, yo);

// Vectorize computations within the strips.
blur.update(0)
.unscheduled();
blur.update(1)
.reorder(x, ry)
.vectorize(x);
Expand Down
3 changes: 1 addition & 2 deletions python_bindings/src/halide/halide_/PyEnums.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -165,9 +165,8 @@ void define_enums(py::module &m) {
.value("HexagonDma", Target::Feature::HexagonDma)
.value("EmbedBitcode", Target::Feature::EmbedBitcode)
.value("EnableLLVMLoopOpt", Target::Feature::EnableLLVMLoopOpt)
.value("WasmMvpOnly", Target::Feature::WasmMvpOnly)
.value("WasmSimd128", Target::Feature::WasmSimd128)
.value("WasmSignExt", Target::Feature::WasmSignExt)
.value("WasmSatFloatToInt", Target::Feature::WasmSatFloatToInt)
.value("WasmThreads", Target::Feature::WasmThreads)
.value("WasmBulkMemory", Target::Feature::WasmBulkMemory)
.value("SVE", Target::Feature::SVE)
Expand Down
6 changes: 2 additions & 4 deletions src/CodeGen_WebAssembly.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -335,15 +335,13 @@ string CodeGen_WebAssembly::mattrs() const {

std::vector<std::string_view> attrs;

if (target.has_feature(Target::WasmSignExt)) {
if (!target.has_feature(Target::WasmMvpOnly)) {
attrs.emplace_back("+sign-ext");
attrs.emplace_back("+nontrapping-fptoint");
}
if (target.has_feature(Target::WasmSimd128)) {
attrs.emplace_back("+simd128");
}
if (target.has_feature(Target::WasmSatFloatToInt)) {
attrs.emplace_back("+nontrapping-fptoint");
}
if (target.has_feature(Target::WasmThreads)) {
// "WasmThreads" doesn't directly affect LLVM codegen,
// but it does end up requiring atomics, so be sure to enable them.
Expand Down
1 change: 0 additions & 1 deletion src/IRPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1109,7 +1109,6 @@ void IRPrinter::visit(const VectorReduce *op) {
stream << "("
<< op->type
<< ")vector_reduce_" << op->op << "("
<< ", "
<< op->value
<< ")";
}
Expand Down
3 changes: 1 addition & 2 deletions src/Target.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -533,8 +533,7 @@ const std::map<std::string, Target::Feature> feature_name_map = {
{"embed_bitcode", Target::EmbedBitcode},
{"enable_llvm_loop_opt", Target::EnableLLVMLoopOpt},
{"wasm_simd128", Target::WasmSimd128},
{"wasm_signext", Target::WasmSignExt},
{"wasm_sat_float_to_int", Target::WasmSatFloatToInt},
{"wasm_mvponly", Target::WasmMvpOnly},
{"wasm_threads", Target::WasmThreads},
{"wasm_bulk_memory", Target::WasmBulkMemory},
{"webgpu", Target::WebGPU},
Expand Down
3 changes: 1 addition & 2 deletions src/Target.h
Original file line number Diff line number Diff line change
Expand Up @@ -143,9 +143,8 @@ struct Target {
CheckUnsafePromises = halide_target_feature_check_unsafe_promises,
EmbedBitcode = halide_target_feature_embed_bitcode,
EnableLLVMLoopOpt = halide_target_feature_enable_llvm_loop_opt,
WasmMvpOnly = halide_target_feature_wasm_mvponly,
WasmSimd128 = halide_target_feature_wasm_simd128,
WasmSignExt = halide_target_feature_wasm_signext,
WasmSatFloatToInt = halide_target_feature_wasm_sat_float_to_int,
WasmThreads = halide_target_feature_wasm_threads,
WasmBulkMemory = halide_target_feature_wasm_bulk_memory,
WebGPU = halide_target_feature_webgpu,
Expand Down
2 changes: 1 addition & 1 deletion src/VectorizeLoops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -816,7 +816,7 @@ class VectorSubs : public IRMutator {
}

Stmt visit(const AssertStmt *op) override {
return (op->condition.type().lanes() > 1) ? scalarize(op) : op;
return (mutate(op->condition).type().lanes() > 1) ? scalarize(op) : op;
}

Stmt visit(const IfThenElse *op) override {
Expand Down
6 changes: 2 additions & 4 deletions src/WasmExecutor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1308,15 +1308,13 @@ wabt::interp::HostFunc::Ptr make_extern_callback(wabt::interp::Store &store,

wabt::Features calc_features(const Target &target) {
wabt::Features f;
if (target.has_feature(Target::WasmSignExt)) {
if (!target.has_feature(Target::WasmMvpOnly)) {
f.enable_sign_extension();
f.enable_sat_float_to_int();
}
if (target.has_feature(Target::WasmSimd128)) {
f.enable_simd();
}
if (target.has_feature(Target::WasmSatFloatToInt)) {
f.enable_sat_float_to_int();
}
return f;
}
#endif // WITH_WABT
Expand Down
2 changes: 1 addition & 1 deletion src/autoschedulers/anderson2021/cost_model_generator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -661,7 +661,7 @@ class CostModel : public Generator<CostModel<training>> {
};

// Pipeline features processing
conv1_stage1.compute_root().vectorize(c);
conv1_stage1.compute_root().vectorize(c).update().vectorize(c);
squashed_head1_filter.compute_root().vectorize(c);

// Schedule features processing. The number of schedule
Expand Down
44 changes: 32 additions & 12 deletions src/autoschedulers/mullapudi2016/AutoSchedule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -837,20 +837,27 @@ struct AutoSchedule {
}
}

for (const auto &m : f.second) {
const int stage = m.first;
const vector<string> &schedules = m.second;
internal_assert(!schedules.empty());
const int num_stages = func.updates().size() + 1;
for (int stage = 0; stage < num_stages; stage++) {
schedule_ss << " " << fname;
if (stage > 0) {
schedule_ss << ".update(" << std::to_string(stage - 1) << ")";
schedule_ss << ".update(" << (stage - 1) << ")";
}
for (const std::string &s : schedules) {
schedule_ss << "\n ." << s;
auto it = f.second.find(stage);
if (it != f.second.end()) {
const vector<string> &schedules = it->second;
internal_assert(!schedules.empty());
for (const std::string &s : schedules) {
internal_assert(!s.empty());
schedule_ss << "\n ." << s;
}
} else {
if (stage > 0) {
schedule_ss << ".unscheduled()";
}
}
schedule_ss << ";\n";
}

schedule_ss << "}\n";
}

Expand Down Expand Up @@ -2472,10 +2479,13 @@ void Partitioner::vectorize_stage(const Group &g, Stage f_handle, int stage_num,
// storage dimension of the func.
//
// TODO: Check if the warning is necessary.
if (vec_dim_index > 0) {
user_warning << "Outer dim vectorization of var \"" << vec_dim_name
<< "\" in function \"" << f_handle.name() << "\"\n";
}
//
// Disabled: this isn't really user actionable, and is just noise.
//
// if (vec_dim_index > 0) {
// user_warning << "Outer dim vectorization of var \"" << vec_dim_name
// << "\" in function \"" << f_handle.name() << "\"\n";
// }
}
}

Expand Down Expand Up @@ -3386,6 +3396,16 @@ string generate_schedules(const vector<Function> &outputs, const Target &target,
debug(2) << "Generating CPU schedule...\n";
part.generate_cpu_schedule(target, sched);

// Ensure that all update stages are "touched" so we get no warnings/errors
for (const auto &f : sched.func_schedules) {
const Function &func = get_element(sched.env, f.first);
const int num_update_stages = func.updates().size();
for (int stage = 0; stage < num_update_stages; stage++) {
Definition def = get_stage_definition(func, stage + 1);
def.schedule().touched() = true;
}
}

std::ostringstream oss;
oss << sched;
string sched_string = oss.str();
Expand Down
3 changes: 1 addition & 2 deletions src/runtime/HalideRuntime.h
Original file line number Diff line number Diff line change
Expand Up @@ -1386,9 +1386,8 @@ typedef enum halide_target_feature_t {
halide_target_feature_hexagon_dma, ///< Enable Hexagon DMA buffers.
halide_target_feature_embed_bitcode, ///< Emulate clang -fembed-bitcode flag.
halide_target_feature_enable_llvm_loop_opt, ///< Enable loop vectorization + unrolling in LLVM. Overrides halide_target_feature_disable_llvm_loop_opt. (Ignored for non-LLVM targets.)
halide_target_feature_wasm_mvponly, ///< Disable all extensions to WebAssembly codegen (including +sign-ext and +nontrapping-fptoint, which are on by default).
halide_target_feature_wasm_simd128, ///< Enable +simd128 instructions for WebAssembly codegen.
halide_target_feature_wasm_signext, ///< Enable +sign-ext instructions for WebAssembly codegen.
halide_target_feature_wasm_sat_float_to_int, ///< Enable saturating (nontrapping) float-to-int instructions for WebAssembly codegen.
halide_target_feature_wasm_threads, ///< Enable use of threads in WebAssembly codegen. Requires the use of a wasm runtime that provides pthread-compatible wrappers (typically, Emscripten with the -pthreads flag). Unsupported under WASI.
halide_target_feature_wasm_bulk_memory, ///< Enable +bulk-memory instructions for WebAssembly codegen.
halide_target_feature_webgpu, ///< Enable the WebGPU runtime.
Expand Down
1 change: 1 addition & 0 deletions test/correctness/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,7 @@ tests(GROUPS correctness
vectorize_mixed_widths.cpp
vectorize_nested.cpp
vectorize_varying_allocation_size.cpp
vectorized_assert.cpp
vectorized_gpu_allocation.cpp
vectorized_initialization.cpp
vectorized_load_from_vectorized_allocation.cpp
Expand Down
4 changes: 4 additions & 0 deletions test/correctness/recursive_box_filters.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ int main(int argc, char **argv) {
// have to pass 'true' to the atomic call to tell it to skip the check.
h.update(2).atomic(true).vectorize(r, 16);

// These stages don't need scheduling
h.update(0).unscheduled();
h.update(1).unscheduled();

Buffer<int> r0(size);
Buffer<int> r1(size);
h.realize({r0, r1});
Expand Down
7 changes: 4 additions & 3 deletions test/correctness/simd_op_check_wasm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ class SimdOpCheckWASM : public SimdOpCheckTest {
SimdOpCheckWASM(Target t, int w = 768, int h = 128)
: SimdOpCheckTest(t, w, h) {
use_wasm_simd128 = target.has_feature(Target::WasmSimd128);
use_wasm_sat_float_to_int = target.has_feature(Target::WasmSatFloatToInt);
use_wasm_sign_ext = target.has_feature(Target::WasmSignExt);
use_wasm_sign_ext = !target.has_feature(Target::WasmMvpOnly);
use_wasm_sat_float_to_int = !target.has_feature(Target::WasmMvpOnly);
}

void add_tests() override {
Expand Down Expand Up @@ -544,6 +544,7 @@ int main(int argc, char **argv) {
argc, argv,
{
Target("wasm-32-wasmrt"),
Target("wasm-32-wasmrt-wasm_simd128-wasm_sat_float_to_int"),
Target("wasm-32-wasmrt-wasm_simd128"),
Target("wasm-32-wasmrt-wasm_mvponly"),
});
}
46 changes: 46 additions & 0 deletions test/correctness/vectorized_assert.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
#include "Halide.h"

using namespace Halide;

// Number of times my_error() has been called so far.
int error_count = 0;
// Error callback that only counts invocations; both arguments (the JIT user
// context and the error message) are ignored. main() installs it as g's
// custom_error JIT handler.
void my_error(JITUserContext *ucon, const char *msg) {
error_count++;
}

// Test driver: builds a pipeline in which g is vectorized and reads f at an
// index that depends on the runtime parameter p, then checks via error_count
// that the first realization reports exactly one error and the second reports
// none. Returns 0 and prints "Success!" when both checks pass, 1 otherwise.
int main(int argc, char **argv) {
Func f("f"), g("g");
Var x("x");
Param<int> p;

// f has a pure definition plus one update stage.
f(x) = x;
f(x) += 1;
// g reads f at x and at 2 * x + p, so the region of f it needs depends on p.
g(x) = f(x) + f(2 * x + p);

g.vectorize(x, 8);
f.bound_storage(x, 32);
// No way to check this at compile time. The size of f depends on both x and
// p. An assert is injected, but the assert is inside g's vectorized loop.

// Route errors raised while running g to my_error so they can be counted.
g.jit_handlers().custom_error = my_error;

g.compile_jit();

// Will trigger the assert
p.set(256);
g.realize({128});
// Exactly one error must have been reported by this point.
if (error_count != 1) {
printf("There should have been an error\n");
return 1;
}

// Will not trigger the assert
p.set(0);
g.realize({8});
// The count must still be 1, i.e. this realization reported no new error.
if (error_count != 1) {
printf("There should not have been an error\n");
return 1;
}

printf("Success!\n");
return 0;
}
1 change: 1 addition & 0 deletions test/error/tuple_output_bounds_check.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ int main(int argc, char **argv) {

Var xo, xi;
h.split(x, xo, xi, 16, TailStrategy::RoundUp);
h.update(0).unscheduled();

Buffer<int> r0(size);
Buffer<int> r1(size);
Expand Down

0 comments on commit 8f9fd0c

Please sign in to comment.