Skip to content

Commit

Permalink
Let BDD simplification address priority-selects with at least one bit…
Browse files Browse the repository at this point in the history
… set

This lets us trim at least one case, replacing the default with the first case that will definitely be selected if no higher-priority case is selected.

PiperOrigin-RevId: 648504801
  • Loading branch information
ericastor authored and copybara-github committed Jul 1, 2024
1 parent 07043c7 commit 329dcda
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 0 deletions.
2 changes: 2 additions & 0 deletions xls/passes/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1107,6 +1107,7 @@ cc_library(
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/types:span",
"//xls/common:module_initializer",
"//xls/common/status:status_macros",
"//xls/ir",
Expand Down Expand Up @@ -2421,6 +2422,7 @@ cc_test(
"//xls/ir:function_builder",
"//xls/ir:ir_matcher",
"//xls/ir:ir_test_base",
"//xls/solvers:z3_ir_equivalence_testutils",
"@com_google_googletest//:gtest",
],
)
Expand Down
30 changes: 30 additions & 0 deletions xls/passes/bdd_simplification_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include "absl/log/log.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"
#include "absl/types/span.h"
#include "xls/common/module_initializer.h"
#include "xls/common/status/status_macros.h"
#include "xls/ir/bits.h"
Expand Down Expand Up @@ -333,6 +334,35 @@ absl::StatusOr<bool> SimplifyNode(Node* node, const QueryEngine& query_engine,
return true;
}

// Simplify kPrioritySelect operations where the selector is known to have at
// least one set bit.
if (NarrowingEnabled(opt_level) && node->Is<PrioritySelect>() &&
query_engine.AtLeastOneBitTrue(node->As<PrioritySelect>()->selector())) {
PrioritySelect* sel = node->As<PrioritySelect>();

int64_t last_bit = 0;
std::vector<TreeBitLocation> trailing_bits;
for (; last_bit < sel->selector()->BitCountOrDie() - 1; ++last_bit) {
trailing_bits.push_back(TreeBitLocation(sel->selector(), last_bit));
if (query_engine.AtLeastOneTrue(trailing_bits)) {
break;
}
}
DCHECK(last_bit < sel->selector()->BitCountOrDie() - 1 ||
query_engine.AtLeastOneTrue(trailing_bits));

XLS_ASSIGN_OR_RETURN(Node * new_selector,
node->function_base()->MakeNode<BitSlice>(
node->loc(), sel->selector(),
/*start=*/0, /*width=*/last_bit));
absl::Span<Node* const> new_cases = sel->cases().subspan(0, last_bit);
Node* new_default = sel->get_case(last_bit);
XLS_RETURN_IF_ERROR(node->ReplaceUsesWithNew<PrioritySelect>(
new_selector, new_cases, new_default)
.status());
return true;
}

return false;
}

Expand Down
41 changes: 41 additions & 0 deletions xls/passes/bdd_simplification_pass_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include "xls/ir/package.h"
#include "xls/passes/optimization_pass.h"
#include "xls/passes/pass_base.h"
#include "xls/solvers/z3_ir_equivalence_testutils.h"

namespace m = ::xls::op_matchers;

Expand Down Expand Up @@ -122,6 +123,46 @@ TEST_F(BddSimplificationPassTest, RemoveRedundantOneHot) {
EXPECT_THAT(f->return_value(), m::Concat(m::Eq(), m::Concat()));
}

TEST_F(BddSimplificationPassTest, RemoveRedundantPrioritySelectCases) {
auto p = CreatePackage();
FunctionBuilder fb(TestName(), p.get());
BValue x = fb.Param("x", p->GetBitsType(8));
BValue a = fb.Param("a", p->GetBitsType(24));
BValue b = fb.Param("b", p->GetBitsType(24));
BValue c = fb.Param("c", p->GetBitsType(24));
BValue d = fb.Param("d", p->GetBitsType(24));
BValue x_ge_5 = fb.UGe(x, fb.Literal(UBits(5, 8)));
BValue x_le_42 = fb.ULe(x, fb.Literal(UBits(42, 8)));
BValue x_eq_8 = fb.Eq(x, fb.Literal(UBits(8, 8)));
fb.PrioritySelect(fb.Concat({x_eq_8, x_ge_5, x_le_42}), /*cases=*/{a, b, c},
/*default_value=*/d);
XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.Build());

solvers::z3::ScopedVerifyEquivalence sve{f};
EXPECT_THAT(Run(f), IsOkAndHolds(true));
EXPECT_THAT(f->return_value(),
m::PrioritySelect(m::BitSlice(m::Concat()), {m::Param("a")},
m::Param("b")));
}

TEST_F(BddSimplificationPassTest, PreserveNonRedundantPrioritySelectCases) {
auto p = CreatePackage();
FunctionBuilder fb(TestName(), p.get());
BValue x = fb.Param("x", p->GetBitsType(8));
BValue a = fb.Param("a", p->GetBitsType(24));
BValue b = fb.Param("b", p->GetBitsType(24));
BValue c = fb.Param("c", p->GetBitsType(24));
BValue d = fb.Param("d", p->GetBitsType(24));
BValue x_ge_5 = fb.UGe(x, fb.Literal(UBits(5, 8)));
BValue x_lt_3 = fb.ULt(x, fb.Literal(UBits(3, 8)));
BValue x_eq_3 = fb.Eq(x, fb.Literal(UBits(3, 8)));
fb.PrioritySelect(fb.Concat({x_eq_3, x_ge_5, x_lt_3}), /*cases=*/{a, b, c},
/*default_value=*/d);
XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.Build());

EXPECT_THAT(Run(f), IsOkAndHolds(false));
}

TEST_F(BddSimplificationPassTest, ConvertTwoWayOneHotSelect) {
auto p = CreatePackage();
XLS_ASSERT_OK_AND_ASSIGN(Function * f, ParseFunction(R"(
Expand Down

0 comments on commit 329dcda

Please sign in to comment.