diff --git a/compiler/rustc_codegen_ssa/src/mir/block.rs b/compiler/rustc_codegen_ssa/src/mir/block.rs index 75d413dedad3e..98ff8fc83b373 100644 --- a/compiler/rustc_codegen_ssa/src/mir/block.rs +++ b/compiler/rustc_codegen_ssa/src/mir/block.rs @@ -1179,6 +1179,16 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> { } } + pub fn codegen_block_as_unreachable(&mut self, bb: mir::BasicBlock) { + let llbb = match self.try_llbb(bb) { + Some(llbb) => llbb, + None => return, + }; + let bx = &mut Bx::build(self.cx, llbb); + debug!("codegen_block_as_unreachable({:?})", bb); + bx.unreachable(); + } + fn codegen_terminator( &mut self, bx: &mut Bx, diff --git a/compiler/rustc_codegen_ssa/src/mir/mod.rs b/compiler/rustc_codegen_ssa/src/mir/mod.rs index a6fcf1fd38c1f..a8d22ba7d51fe 100644 --- a/compiler/rustc_codegen_ssa/src/mir/mod.rs +++ b/compiler/rustc_codegen_ssa/src/mir/mod.rs @@ -256,13 +256,22 @@ pub fn codegen_mir<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>>( // Apply debuginfo to the newly allocated locals. fx.debug_introduce_locals(&mut start_bx); + let reachable_blocks = mir.reachable_blocks_in_mono(cx.tcx(), instance); + // The builders will be created separately for each basic block at `codegen_block`. // So drop the builder of `start_llbb` to avoid having two at the same time. drop(start_bx); // Codegen the body of each block using reverse postorder for (bb, _) in traversal::reverse_postorder(mir) { - fx.codegen_block(bb); + if reachable_blocks.contains(bb) { + fx.codegen_block(bb); + } else { + // This may have references to things we didn't monomorphize, so we + // don't actually codegen the body. We still create the block so + // terminators in other blocks can reference it without worry. 
+ fx.codegen_block_as_unreachable(bb); + } } } diff --git a/compiler/rustc_middle/src/mir/mod.rs b/compiler/rustc_middle/src/mir/mod.rs index 3017f912ef027..50ecbcac04f8e 100644 --- a/compiler/rustc_middle/src/mir/mod.rs +++ b/compiler/rustc_middle/src/mir/mod.rs @@ -10,7 +10,7 @@ use crate::ty::print::{pretty_print_const, with_no_trimmed_paths}; use crate::ty::print::{FmtPrinter, Printer}; use crate::ty::visit::TypeVisitableExt; use crate::ty::{self, List, Ty, TyCtxt}; -use crate::ty::{AdtDef, InstanceDef, UserTypeAnnotationIndex}; +use crate::ty::{AdtDef, Instance, InstanceDef, UserTypeAnnotationIndex}; use crate::ty::{GenericArg, GenericArgsRef}; use rustc_data_structures::captures::Captures; @@ -29,6 +29,7 @@ pub use rustc_ast::Mutability; use rustc_data_structures::fx::FxHashMap; use rustc_data_structures::fx::FxHashSet; use rustc_data_structures::graph::dominators::Dominators; +use rustc_index::bit_set::BitSet; use rustc_index::{Idx, IndexSlice, IndexVec}; use rustc_serialize::{Decodable, Encodable}; use rustc_span::symbol::Symbol; @@ -642,6 +643,73 @@ impl<'tcx> Body<'tcx> { self.injection_phase.is_some() } + /// Finds which basic blocks are actually reachable for a specific + /// monomorphization of this body. + /// + /// This is allowed to have false positives; just because this says a block + /// is reachable doesn't mean that's necessarily true. It's thus always + /// legal for this to return a filled set. + /// + /// Regardless, the [`BitSet::domain_size`] of the returned set will always + /// exactly match the number of blocks in the body so that `contains` + /// checks can be done without worrying about panicking. + /// + /// The main case this supports is filtering out `if <T as Trait>::CONST` + /// bodies that can't be removed in generic MIR, but *can* be removed once + /// the specific `T` is known. + /// + /// This is used in the monomorphization collector as well as in codegen. 
+ pub fn reachable_blocks_in_mono( + &self, + tcx: TyCtxt<'tcx>, + instance: Instance<'tcx>, + ) -> BitSet<BasicBlock> { + if instance.args.non_erasable_generics(tcx, instance.def_id()).next().is_none() { + // If it's non-generic, then mir-opt const prop has already run, meaning it's + // probably not worth doing any further filtering. So call everything reachable. + return BitSet::new_filled(self.basic_blocks.len()); + } + + let mut set = BitSet::new_empty(self.basic_blocks.len()); + self.reachable_blocks_in_mono_from(tcx, instance, &mut set, START_BLOCK); + set + } + + fn reachable_blocks_in_mono_from( + &self, + tcx: TyCtxt<'tcx>, + instance: Instance<'tcx>, + set: &mut BitSet<BasicBlock>, + bb: BasicBlock, + ) { + if !set.insert(bb) { + return; + } + + let data = &self.basic_blocks[bb]; + + if let TerminatorKind::SwitchInt { discr: Operand::Constant(constant), targets } = + &data.terminator().kind + { + let env = ty::ParamEnv::reveal_all(); + let mono_literal = instance.instantiate_mir_and_normalize_erasing_regions( + tcx, + env, + crate::ty::EarlyBinder::bind(constant.const_), + ); + if let Some(bits) = mono_literal.try_eval_bits(tcx, env) { + let target = targets.target_for_value(bits); + return self.reachable_blocks_in_mono_from(tcx, instance, set, target); + } else { + bug!("Couldn't evaluate constant {:?} in mono {:?}", constant, instance); + } + } + + for target in data.terminator().successors() { + self.reachable_blocks_in_mono_from(tcx, instance, set, target); + } + } + /// For a `Location` in this scope, determine what the "caller location" at that point is. This /// is interesting because of inlining: the `#[track_caller]` attribute of inlined functions /// must be honored. 
Falls back to the `tracked_caller` value for `#[track_caller]` functions, diff --git a/compiler/rustc_mir_transform/src/lib.rs b/compiler/rustc_mir_transform/src/lib.rs index 945c3c662a604..4e73059373a33 100644 --- a/compiler/rustc_mir_transform/src/lib.rs +++ b/compiler/rustc_mir_transform/src/lib.rs @@ -108,6 +108,7 @@ mod check_alignment; pub mod simplify; mod simplify_branches; mod simplify_comparison_integral; +mod simplify_if_const; mod sroa; mod uninhabited_enum_branching; mod unreachable_prop; @@ -616,6 +617,7 @@ fn run_optimization_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { &large_enums::EnumSizeOpt { discrepancy: 128 }, // Some cleanup necessary at least for LLVM and potentially other codegen backends. &add_call_guards::CriticalCallEdges, + &simplify_if_const::SimplifyIfConst, // Cleanup for human readability, off by default. &prettify::ReorderBasicBlocks, &prettify::ReorderLocals, diff --git a/compiler/rustc_mir_transform/src/simplify_if_const.rs b/compiler/rustc_mir_transform/src/simplify_if_const.rs new file mode 100644 index 0000000000000..7adb714ad2825 --- /dev/null +++ b/compiler/rustc_mir_transform/src/simplify_if_const.rs @@ -0,0 +1,76 @@ +//! A pass that simplifies branches when their condition is known. + +use crate::MirPass; +use rustc_middle::mir::*; +use rustc_middle::ty::TyCtxt; + +/// The lowering for `if CONST` produces +/// ``` +/// _1 = Const(...); +/// switchInt (move _1) +/// ``` +/// so this pass replaces that with +/// ``` +/// switchInt (Const(...)) +/// ``` +/// so that further MIR consumers can special-case it more easily. +/// +/// Unlike ConstProp, this supports generic constants too, not just concrete ones. 
+pub struct SimplifyIfConst; + +impl<'tcx> MirPass<'tcx> for SimplifyIfConst { + fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + for block in body.basic_blocks_mut() { + simplify_assign_move_switch(tcx, block); + } + } +} + +fn simplify_assign_move_switch(tcx: TyCtxt<'_>, block: &mut BasicBlockData<'_>) { + let Some(Terminator { kind: TerminatorKind::SwitchInt { discr: switch_desc, .. }, .. }) = + &mut block.terminator + else { + return; + }; + + let &mut Operand::Move(switch_place) = &mut *switch_desc else { return }; + + let Some(switch_local) = switch_place.as_local() else { return }; + + let Some(last_statement) = block.statements.last_mut() else { return }; + + let StatementKind::Assign(boxed_place_rvalue) = &last_statement.kind else { return }; + + let Some(assigned_local) = boxed_place_rvalue.0.as_local() else { return }; + + if switch_local != assigned_local { + return; + } + + if !matches!(boxed_place_rvalue.1, Rvalue::Use(Operand::Constant(_))) { + return; + } + + let should_optimize = tcx.consider_optimizing(|| { + format!( + "SimplifyBranches - Assignment: {:?} SourceInfo: {:?}", + boxed_place_rvalue, last_statement.source_info + ) + }); + + if should_optimize { + let Some(last_statement) = block.statements.pop() else { + bug!("Somehow the statement disappeared?"); + }; + + let StatementKind::Assign(boxed_place_rvalue) = last_statement.kind else { + bug!("Somehow it's not an assignment any more?"); + }; + + let Rvalue::Use(assigned_constant @ Operand::Constant(_)) = boxed_place_rvalue.1 else { + bug!("Somehow it's not a use of a constant any more?"); + }; + + *switch_desc = assigned_constant; + } +} diff --git a/tests/codegen/skip-mono-inside-if-false.rs b/tests/codegen/skip-mono-inside-if-false.rs new file mode 100644 index 0000000000000..2557daa91043c --- /dev/null +++ b/tests/codegen/skip-mono-inside-if-false.rs @@ -0,0 +1,43 @@ +// compile-flags: -O -C no-prepopulate-passes + +#![crate_type = "lib"] + +#[no_mangle] +pub fn 
demo_for_i32() { + generic_impl::<i32>(); +} + +// Two important things here: +// - We replace the "then" block with `unreachable` to avoid linking problems +// - We neither declare nor define the `big_impl` that said block "calls". + +// CHECK-LABEL: ; skip_mono_inside_if_false::generic_impl +// CHECK: start: +// CHECK-NEXT: br i1 false, label %[[THEN_BRANCH:bb[0-9]+]], label %[[ELSE_BRANCH:bb[0-9]+]] +// CHECK: [[ELSE_BRANCH]]: +// CHECK-NEXT: call skip_mono_inside_if_false::small_impl +// CHECK: [[THEN_BRANCH]]: +// CHECK-NEXT: unreachable + +// CHECK-NOT: @_ZN25skip_mono_inside_if_false8big_impl +// CHECK: define internal void @_ZN25skip_mono_inside_if_false10small_impl +// CHECK-NOT: @_ZN25skip_mono_inside_if_false8big_impl + +fn generic_impl<T>() { + trait MagicTrait { + const IS_BIG: bool; + } + impl<T> MagicTrait for T { + const IS_BIG: bool = std::mem::size_of::<T>() > 10; + } + if T::IS_BIG { + big_impl::<T>(); + } else { + small_impl::<T>(); + } +} + +#[inline(never)] +fn small_impl<T>() {} +#[inline(never)] +fn big_impl<T>() {}