Skip to content

Commit

Permalink
Very simple revamp of dead code elimination.
Browse files Browse the repository at this point in the history
Enable DCE by default. Eliminates useless
associated function calls.

Also threw in just a few changes to Display
implementations for some nodes.

Fixes #28504
  • Loading branch information
mikebenfield committed Feb 20, 2025
1 parent c968579 commit e49b6f3
Show file tree
Hide file tree
Showing 682 changed files with 915 additions and 1,015 deletions.
4 changes: 2 additions & 2 deletions compiler/ast/src/common/identifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,12 +67,12 @@ impl Identifier {

impl fmt::Display for Identifier {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{}", self.name)
self.name.fmt(f)
}
}
impl fmt::Debug for Identifier {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{}", self.name)
self.name.fmt(f)
}
}

Expand Down
12 changes: 3 additions & 9 deletions compiler/ast/src/expressions/call.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
use super::*;
use leo_span::Symbol;

use itertools::Itertools as _;

/// A function call expression, e.g.`foo(args)` or `Foo::bar(args)`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct CallExpression {
Expand All @@ -35,15 +37,7 @@ pub struct CallExpression {

impl fmt::Display for CallExpression {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{}(", self.function)?;

for (i, param) in self.arguments.iter().enumerate() {
write!(f, "{param}")?;
if i < self.arguments.len() - 1 {
write!(f, ", ")?;
}
}
write!(f, ")")
write!(f, "{}({})", self.function, self.arguments.iter().format(", "))
}
}

Expand Down
8 changes: 7 additions & 1 deletion compiler/ast/src/expressions/tuple.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

use super::*;

use itertools::Itertools as _;

// TODO: Consider a restricted interface for constructing a tuple expression.

/// A tuple expression, e.g., `(foo, false, 42)`.
Expand All @@ -32,7 +34,11 @@ pub struct TupleExpression {

impl fmt::Display for TupleExpression {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "({})", self.elements.iter().map(|x| x.to_string()).collect::<Vec<_>>().join(","))
if self.elements.len() == 1 {
write!(f, "({},)", self.elements[0])
} else {
write!(f, "({})", self.elements.iter().join(","))
}
}
}

Expand Down
5 changes: 1 addition & 4 deletions compiler/ast/src/passes/reconstructor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -172,10 +172,7 @@ pub trait ExpressionReconstructor {
.into_iter()
.map(|member| StructVariableInitializer {
identifier: member.identifier,
expression: match member.expression {
Some(expression) => Some(self.reconstruct_expression(expression).0),
None => Some(self.reconstruct_expression(Expression::Identifier(member.identifier)).0),
},
expression: member.expression.map(|expr| self.reconstruct_expression(expr).0),
span: member.span,
id: member.id,
})
Expand Down
4 changes: 2 additions & 2 deletions compiler/ast/src/statement/console/console_statement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,13 @@ pub struct ConsoleStatement {

impl fmt::Display for ConsoleStatement {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "console.{};", self.function)
write!(f, "console.{}", self.function)
}
}

impl fmt::Debug for ConsoleStatement {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "console.{};", self.function)
write!(f, "console.{}", self.function)
}
}

Expand Down
2 changes: 1 addition & 1 deletion compiler/ast/src/statement/expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ pub struct ExpressionStatement {

impl fmt::Display for ExpressionStatement {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{};", self.expression)
self.expression.fmt(f)
}
}

Expand Down
8 changes: 6 additions & 2 deletions compiler/ast/src/statement/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,15 +78,19 @@ pub enum Statement {

impl Statement {
/// Returns a dummy statement made from an empty block `{}`.
pub fn dummy(span: Span, id: NodeID) -> Self {
Self::Block(Block { statements: Vec::new(), span, id })
pub fn dummy() -> Self {
Self::Block(Block { statements: Vec::new(), span: Default::default(), id: Default::default() })
}

pub(crate) fn semicolon(&self) -> &'static str {
use Statement::*;

if matches!(self, Block(..) | Conditional(..) | Iteration(..)) { "" } else { ";" }
}

/// Returns `true` only when this statement is a `Block` that holds no statements.
pub fn is_empty(self: &Statement) -> bool {
    match self {
        Statement::Block(block) => block.statements.is_empty(),
        _ => false,
    }
}
}

impl fmt::Display for Statement {
Expand Down
6 changes: 3 additions & 3 deletions compiler/compiler/src/compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -307,9 +307,9 @@ impl<'a, N: Network> Compiler<'a, N> {
}

/// Runs the dead code elimination pass.
pub fn dead_code_elimination_pass(&mut self) -> Result<()> {
pub fn dead_code_elimination_pass(&mut self, symbol_table: &SymbolTable) -> Result<()> {
if self.compiler_options.build.dce_enabled {
self.ast = DeadCodeEliminator::do_pass((std::mem::take(&mut self.ast), &self.node_builder))?;
self.ast = DeadCodeEliminator::do_pass((std::mem::take(&mut self.ast), symbol_table, &self.type_table))?;
}

if self.compiler_options.output.dce_ast {
Expand Down Expand Up @@ -347,7 +347,7 @@ impl<'a, N: Network> Compiler<'a, N> {

self.function_inlining_pass(&call_graph)?;

self.dead_code_elimination_pass()?;
self.dead_code_elimination_pass(&st)?;

Ok((st, struct_graph, call_graph))
}
Expand Down
2 changes: 1 addition & 1 deletion compiler/compiler/tests/integration/utilities/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ pub fn compile_and_process<'a>(parsed: &'a mut Compiler<'a, CurrentNetwork>) ->

parsed.function_inlining_pass(&call_graph)?;

parsed.dead_code_elimination_pass()?;
parsed.dead_code_elimination_pass(&st)?;

// Compile Leo program to bytecode.
let bytecode = parsed.code_generation_pass(&st, &struct_graph, &call_graph)?;
Expand Down
5 changes: 2 additions & 3 deletions compiler/parser/src/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,9 @@

use crate::{ParserContext, SpannedToken, tokenizer};

use leo_ast::{NodeBuilder, NodeID, Statement};
use leo_ast::{NodeBuilder, Statement};
use leo_errors::{LeoError, emitter::Handler};
use leo_span::{
Span,
source_map::FileName,
symbol::{SessionGlobals, create_session_if_not_set_then},
};
Expand Down Expand Up @@ -120,7 +119,7 @@ impl Namespace for ParseStatementNamespace {
create_session_if_not_set_then(|s| {
let tokenizer = tokenize(test, s)?;
if all_are_comments(&tokenizer) {
return Ok(toml_or_fail(Statement::dummy(Span::default(), NodeID::default())));
return Ok(toml_or_fail(Statement::dummy()));
}
with_handler(tokenizer, |p| p.parse_statement()).map(toml_or_fail)
})
Expand Down
84 changes: 74 additions & 10 deletions compiler/passes/src/dead_code_elimination/dead_code_eliminator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,25 +14,89 @@
// You should have received a copy of the GNU General Public License
// along with the Leo library. If not, see <https://www.gnu.org/licenses/>.

use leo_ast::NodeBuilder;
use leo_span::Symbol;
use leo_ast::{AccessExpression, Expression, Location, Node, Type};
use leo_span::{Symbol, sym};

use indexmap::IndexSet;

use crate::{SymbolTable, TypeTable};

#[derive(Debug)]
pub struct DeadCodeEliminator<'a> {
/// A counter to generate unique node IDs.
pub(crate) node_builder: &'a NodeBuilder,
/// The set of used variables in the current function body.
pub(crate) used_variables: IndexSet<Symbol>,
/// Whether or not the variables are necessary.
pub(crate) is_necessary: bool,
/// Whether or not we are currently traversing an async function.
pub(crate) is_async: bool,

/// The name of the program currently being processed.
pub(crate) program_name: Symbol,

pub(crate) symbol_table: &'a SymbolTable,

pub(crate) type_table: &'a TypeTable,
}

impl<'a> DeadCodeEliminator<'a> {
/// Initializes a new `DeadCodeEliminator`.
pub fn new(node_builder: &'a NodeBuilder) -> Self {
Self { node_builder, used_variables: Default::default(), is_necessary: false, is_async: false }
pub(crate) fn new(symbol_table: &'a SymbolTable, type_table: &'a TypeTable) -> Self {
Self { used_variables: Default::default(), program_name: Symbol::intern(""), symbol_table, type_table }
}

fn contains_future_or_record(&self, ty: &Type) -> bool {
use Type::*;
match ty {
Array(array) => self.contains_future_or_record(array.element_type()),
Composite(composite) => {
let program = composite.program.unwrap_or(self.program_name);
let location = Location::new(program, composite.id.name);
// Struct or record can't contain a record or future, so
// we don't need to check for that.
self.symbol_table.lookup_record(location).is_some()
}

Tuple(tuple) => tuple.elements().iter().any(|ty| self.contains_future_or_record(ty)),

Future(..) => true,

Address | Boolean | Field | Group | Identifier(_) | Integer(_) | Mapping(_) | Scalar | Signature
| String | Unit | Err => false,
}
}

pub(crate) fn side_effect_free(&self, expr: &Expression) -> bool {
use Expression::*;

let sef = |expr| self.side_effect_free(expr);

match expr {
Access(AccessExpression::Array(array)) => sef(&array.array) && sef(&array.index),
Access(AccessExpression::AssociatedConstant(_)) => true,
Access(AccessExpression::AssociatedFunction(func)) => {
func.arguments.iter().all(sef)
&& !matches!(func.variant.name, sym::CheatCode | sym::Mapping | sym::Future)
}
Access(AccessExpression::Member(mem)) => sef(&mem.inner),
Access(AccessExpression::Tuple(tuple)) => sef(&tuple.tuple),
Array(array) => array.elements.iter().all(sef),
Binary(bin) => sef(&bin.left) && sef(&bin.right),
Call(call) => {
// A function call is side effect free if none of its arguments or returns
// contain a future or record, and all its arguments are side effect free.
let ret_ty = self.type_table.get(&call.id()).expect("Type checking should have provided a type.");
if self.contains_future_or_record(&ret_ty) {
false
} else {
call.arguments.iter().all(|arg| {
let ty = self.type_table.get(&arg.id()).expect("Type checking should have provided a type.");
sef(arg) && !self.contains_future_or_record(&ty)
})
}
}
Cast(cast) => sef(&cast.expression),
Struct(struct_) => struct_.members.iter().all(|mem| mem.expression.as_ref().map_or(true, sef)),
Ternary(tern) => [&*tern.condition, &*tern.if_true, &*tern.if_false].into_iter().all(sef),
Tuple(tuple) => tuple.elements.iter().all(sef),
Unary(un) => sef(&un.receiver),
Err(_) => false,
Identifier(_) | Literal(_) | Locator(_) | Unit(_) => true,
}
}
}
75 changes: 3 additions & 72 deletions compiler/passes/src/dead_code_elimination/eliminate_expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,83 +16,14 @@

use crate::DeadCodeEliminator;

use leo_ast::{
AccessExpression,
AssociatedFunction,
Expression,
ExpressionReconstructor,
Identifier,
StructExpression,
StructVariableInitializer,
};
use leo_span::sym;
use leo_ast::{Expression, ExpressionReconstructor, Identifier};

impl ExpressionReconstructor for DeadCodeEliminator<'_> {
type AdditionalOutput = ();

/// Reconstructs the associated function access expression.
fn reconstruct_associated_function(&mut self, input: AssociatedFunction) -> (Expression, Self::AdditionalOutput) {
// If the associated function manipulates a mapping, or a cheat code, mark the statement as necessary.
match (&input.variant.name, input.name.name) {
(&sym::Mapping, sym::remove)
| (&sym::Mapping, sym::set)
| (&sym::Future, sym::Await)
| (&sym::CheatCode, _) => {
self.is_necessary = true;
}
_ => {}
};
// Reconstruct the access expression.
let result = (
Expression::Access(AccessExpression::AssociatedFunction(AssociatedFunction {
variant: input.variant,
name: input.name,
arguments: input.arguments.into_iter().map(|arg| self.reconstruct_expression(arg).0).collect(),
span: input.span,
id: input.id,
})),
Default::default(),
);
// Unset `self.is_necessary`.
self.is_necessary = false;
result
}

/// Reconstruct the components of the struct init expression.
/// This is necessary since the reconstructor does not explicitly visit each component of the expression.
fn reconstruct_struct_init(&mut self, input: StructExpression) -> (Expression, Self::AdditionalOutput) {
(
Expression::Struct(StructExpression {
name: input.name,
// Reconstruct each of the struct members.
members: input
.members
.into_iter()
.map(|member| StructVariableInitializer {
identifier: member.identifier,
expression: match member.expression {
Some(expression) => Some(self.reconstruct_expression(expression).0),
None => unreachable!("Static single assignment ensures that the expression always exists."),
},
span: member.span,
id: member.id,
})
.collect(),
span: input.span,
id: input.id,
}),
Default::default(),
)
}

/// Marks identifiers as used.
/// This is necessary to determine which statements can be eliminated from the program.
// Use and reconstruct an identifier.
fn reconstruct_identifier(&mut self, input: Identifier) -> (Expression, Self::AdditionalOutput) {
// Add the identifier to `self.used_variables`.
if self.is_necessary {
self.used_variables.insert(input.name);
}
// Return the identifier as is.
self.used_variables.insert(input.name);
(Expression::Identifier(input), Default::default())
}
}
28 changes: 11 additions & 17 deletions compiler/passes/src/dead_code_elimination/eliminate_program.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,28 +16,22 @@

use crate::DeadCodeEliminator;

use leo_ast::{Function, ProgramReconstructor, StatementReconstructor};
use leo_ast::{Function, ProgramReconstructor, StatementReconstructor as _};

impl ProgramReconstructor for DeadCodeEliminator<'_> {
fn reconstruct_function(&mut self, input: Function) -> Function {
fn reconstruct_function(&mut self, mut input: Function) -> Function {
// Reset the state of the dead code eliminator.
self.used_variables.clear();
self.is_necessary = false;
self.is_async = input.variant.is_async_function();

// Traverse the function body.
let block = self.reconstruct_block(input.block).0;

Function {
annotations: input.annotations,
variant: input.variant,
identifier: input.identifier,
input: input.input,
output: input.output,
output_type: input.output_type,
block,
span: input.span,
id: input.id,
}
input.block = self.reconstruct_block(input.block).0;

input
}

/// Records the name of the program scope being entered, then reconstructs
/// each of its functions in order, keeping every function paired with its
/// original key.
fn reconstruct_program_scope(&mut self, mut input: leo_ast::ProgramScope) -> leo_ast::ProgramScope {
    // Store the program name before visiting any function in this scope.
    self.program_name = input.program_id.name.name;

    let rebuilt = input
        .functions
        .into_iter()
        .map(|(key, function)| (key, self.reconstruct_function(function)))
        .collect();
    input.functions = rebuilt;

    input
}
}
Loading

0 comments on commit e49b6f3

Please sign in to comment.