Skip to content

Commit

Permalink
Dequalify names when constructing RVars in rfactor (#8560)
Browse files Browse the repository at this point in the history
Also includes two drive-by fixes:

* Don't apply C++-specific warnings to C sources

GCC complains about certain warning flags being passed to C sources.
This will break -Werror builds. The flags in question are:

  * -Woverloaded-virtual
  * -Wsuggest-override
  * -Wno-old-style-cast

* Use set_error on pybind11 2.12+

In pybind/pybind11#4772, the py::exception<>::operator() functions were
deprecated because the static destructor could run after interpreter
finalization and lead to undefined behavior.

The old code has been modified to use PyErr_SetString directly, which
also avoids lifetime issues.
  • Loading branch information
alexreinking authored Jan 29, 2025
1 parent 53afee7 commit 9500c05
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 10 deletions.
7 changes: 4 additions & 3 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -126,10 +126,12 @@ function(set_halide_compiler_warnings NAME)

$<$<CXX_COMPILER_ID:GNU,Clang,AppleClang>:-Wcast-qual>
$<$<CXX_COMPILER_ID:GNU,Clang,AppleClang>:-Wignored-qualifiers>
$<$<CXX_COMPILER_ID:GNU,Clang,AppleClang>:-Woverloaded-virtual>
$<$<CXX_COMPILER_ID:GNU,Clang,AppleClang>:-Wimplicit-fallthrough>

$<$<CXX_COMPILER_ID:GNU>:-Wsuggest-override>
# GCC warns when these warnings are given to plain-C sources
$<$<COMPILE_LANG_AND_ID:CXX,GNU,Clang,AppleClang>:-Woverloaded-virtual>
$<$<COMPILE_LANG_AND_ID:CXX,GNU>:-Wsuggest-override>
$<$<COMPILE_LANG_AND_ID:CXX,GNU,Clang,AppleClang>:-Wno-old-style-cast>

$<$<CXX_COMPILER_ID:Clang,AppleClang>:-Winconsistent-missing-destructor-override>
$<$<CXX_COMPILER_ID:Clang,AppleClang>:-Winconsistent-missing-override>
Expand All @@ -139,7 +141,6 @@ function(set_halide_compiler_warnings NAME)
$<$<CXX_COMPILER_ID:GNU,Clang,AppleClang>:-Wno-float-conversion>
$<$<CXX_COMPILER_ID:GNU,Clang,AppleClang>:-Wno-float-equal>
$<$<CXX_COMPILER_ID:GNU,Clang,AppleClang>:-Wno-missing-field-initializers>
$<$<CXX_COMPILER_ID:GNU,Clang,AppleClang>:-Wno-old-style-cast>
$<$<CXX_COMPILER_ID:GNU,Clang,AppleClang>:-Wno-shadow>
$<$<CXX_COMPILER_ID:GNU,Clang,AppleClang>:-Wno-sign-conversion>
$<$<CXX_COMPILER_ID:GNU,Clang,AppleClang>:-Wno-switch-enum>
Expand Down
7 changes: 6 additions & 1 deletion python_bindings/src/halide/halide_/PyError.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,12 @@ void define_error(py::module &m) {
std::rethrow_exception(p);
}
} catch (const Error &e) {
halide_error(e.what());
#if PYBIND11_VERSION_HEX >= 0x020C0000 // 2.12
set_error(halide_error, e.what());
#else
// TODO: remove this branch when upgrading pybind11 past 2.12.0
PyErr_SetString(halide_error.ptr(), e.what());
#endif
}
});
}
Expand Down
33 changes: 27 additions & 6 deletions src/Func.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -671,6 +671,25 @@ void add_let(SubstitutionMap &subst, const string &name, const Expr &value) {
subst.emplace(name, value);
}

string dequalify(string name) {
if (const auto it = name.rfind('.'); it != string::npos) {
return name.substr(it + 1);
}
return name;
}

vector<Dim> subst_dims(const SubstitutionMap &substitution_map, const vector<Dim> &dims) {
auto new_dims = dims;
for (auto &dim : new_dims) {
if (const auto it = substitution_map.find(dim.var); it != substitution_map.end()) {
const Variable *new_var = it->second.as<Variable>();
internal_assert(new_var);
dim.var = new_var->name;
}
}
return new_dims;
}

pair<ReductionDomain, SubstitutionMap> project_rdom(const vector<Dim> &dims, const ReductionDomain &rdom, const vector<Split> &splits) {
// The bounds projections maps expressions that reference the old RDom
// bounds to expressions that reference the new RDom bounds (from dims).
Expand All @@ -694,7 +713,7 @@ pair<ReductionDomain, SubstitutionMap> project_rdom(const vector<Dim> &dims, con
for (const Dim &dim : dims) {
const Expr new_min = simplify(bounds_projection.at(dim.var + ".loop_min"));
const Expr new_extent = simplify(bounds_projection.at(dim.var + ".loop_extent"));
new_rvars.push_back(ReductionVariable{dim.var, new_min, new_extent});
new_rvars.push_back(ReductionVariable{dequalify(dim.var), new_min, new_extent});
}
ReductionDomain new_rdom{new_rvars};
new_rdom.where(rdom.predicate());
Expand Down Expand Up @@ -730,8 +749,8 @@ pair<ReductionDomain, SubstitutionMap> project_rdom(const vector<Dim> &dims, con
}
}
}
for (const auto &rv : new_rdom.domain()) {
add_let(dim_projection, rv.var, Variable::make(Int(32), rv.var, new_rdom));
for (size_t i = 0; i < new_rdom.domain().size(); i++) {
add_let(dim_projection, dims[i].var, RVar(new_rdom, i));
}
return {new_rdom, dim_projection};
}
Expand Down Expand Up @@ -902,7 +921,9 @@ Func Stage::rfactor(const vector<pair<RVar, Var>> &preserved) {
// Preserved
std::tie(preserved_rdom, preserved_map) = project_rdom(preserved_rdims, rdom, rvar_splits);
Scope<Interval> intm_rdom;
for (const auto &[var, min, extent] : intermediate_rdom.domain()) {
for (size_t i = 0; i < intermediate_rdom.domain().size(); i++) {
const auto &var = intermediate_rdims[i].var;
const auto &[_, min, extent] = intermediate_rdom.domain()[i];
intm_rdom.push(var, Interval{min, min + extent - 1});
}
preserved_rdom.set_predicate(or_condition_over_domain(substitute(preserved_map, preserved_rdom.predicate()), intm_rdom));
Expand Down Expand Up @@ -960,7 +981,7 @@ Func Stage::rfactor(const vector<pair<RVar, Var>> &preserved) {
}

intm.function().update(0).schedule() = definition.schedule().get_copy();
intm.function().update(0).schedule().dims() = std::move(intm_dims);
intm.function().update(0).schedule().dims() = subst_dims(intermediate_map, intm_dims);
intm.function().update(0).schedule().rvars() = intermediate_rdom.domain();
intm.function().update(0).schedule().splits() = var_splits;
}
Expand Down Expand Up @@ -1022,7 +1043,7 @@ Func Stage::rfactor(const vector<pair<RVar, Var>> &preserved) {
definition.args() = dim_vars_exprs;
definition.values() = substitute(preserved_map, prover_result.pattern.ops);
definition.predicate() = preserved_rdom.predicate();
definition.schedule().dims() = std::move(reducing_dims);
definition.schedule().dims() = subst_dims(preserved_map, reducing_dims);
definition.schedule().rvars() = preserved_rdom.domain();
definition.schedule().splits() = var_splits;
}
Expand Down
5 changes: 5 additions & 0 deletions src/Reduction.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,12 @@ class IRMutator;

/** A single named dimension of a reduction domain */
struct ReductionVariable {
/**
* A variable name for the reduction variable. This name must be a
* valid Var name, i.e. it must not contain a <tt>.</tt> character.
*/
std::string var;

Expr min, extent;

/** This lets you use a ReductionVariable as a key in a map of the form
Expand Down

0 comments on commit 9500c05

Please sign in to comment.