Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Very simple revamp of dead code elimination. #28507

Open
wants to merge 1 commit into
base: mainnet
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
The table of contents is too big for display.
Diff view
Diff view
  •  
  •  
  •  
4 changes: 2 additions & 2 deletions compiler/ast/src/common/identifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,12 +67,12 @@ impl Identifier {

impl fmt::Display for Identifier {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{}", self.name)
self.name.fmt(f)
}
}
impl fmt::Debug for Identifier {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{}", self.name)
self.name.fmt(f)
}
}

Expand Down
12 changes: 3 additions & 9 deletions compiler/ast/src/expressions/call.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
use super::*;
use leo_span::Symbol;

use itertools::Itertools as _;

/// A function call expression, e.g.`foo(args)` or `Foo::bar(args)`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct CallExpression {
Expand All @@ -35,15 +37,7 @@ pub struct CallExpression {

impl fmt::Display for CallExpression {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{}(", self.function)?;

for (i, param) in self.arguments.iter().enumerate() {
write!(f, "{param}")?;
if i < self.arguments.len() - 1 {
write!(f, ", ")?;
}
}
write!(f, ")")
write!(f, "{}({})", self.function, self.arguments.iter().format(", "))
}
}

Expand Down
8 changes: 7 additions & 1 deletion compiler/ast/src/expressions/tuple.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

use super::*;

use itertools::Itertools as _;

// TODO: Consider a restricted interface for constructing a tuple expression.

/// A tuple expression, e.g., `(foo, false, 42)`.
Expand All @@ -32,7 +34,11 @@ pub struct TupleExpression {

impl fmt::Display for TupleExpression {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "({})", self.elements.iter().map(|x| x.to_string()).collect::<Vec<_>>().join(","))
if self.elements.len() == 1 {
write!(f, "({},)", self.elements[0])
} else {
write!(f, "({})", self.elements.iter().join(","))
}
}
}

Expand Down
5 changes: 1 addition & 4 deletions compiler/ast/src/passes/reconstructor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -172,10 +172,7 @@ pub trait ExpressionReconstructor {
.into_iter()
.map(|member| StructVariableInitializer {
identifier: member.identifier,
expression: match member.expression {
Some(expression) => Some(self.reconstruct_expression(expression).0),
None => Some(self.reconstruct_expression(Expression::Identifier(member.identifier)).0),
},
expression: member.expression.map(|expr| self.reconstruct_expression(expr).0),
span: member.span,
id: member.id,
})
Expand Down
4 changes: 2 additions & 2 deletions compiler/ast/src/statement/console/console_statement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,13 @@ pub struct ConsoleStatement {

impl fmt::Display for ConsoleStatement {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "console.{};", self.function)
write!(f, "console.{}", self.function)
}
}

impl fmt::Debug for ConsoleStatement {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "console.{};", self.function)
write!(f, "console.{}", self.function)
}
}

Expand Down
2 changes: 1 addition & 1 deletion compiler/ast/src/statement/expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ pub struct ExpressionStatement {

impl fmt::Display for ExpressionStatement {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{};", self.expression)
self.expression.fmt(f)
}
}

Expand Down
8 changes: 6 additions & 2 deletions compiler/ast/src/statement/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,15 +78,19 @@ pub enum Statement {

impl Statement {
/// Returns a dummy statement made from an empty block `{}`.
pub fn dummy(span: Span, id: NodeID) -> Self {
Self::Block(Block { statements: Vec::new(), span, id })
pub fn dummy() -> Self {
Self::Block(Block { statements: Vec::new(), span: Default::default(), id: Default::default() })
}

pub(crate) fn semicolon(&self) -> &'static str {
use Statement::*;

if matches!(self, Block(..) | Conditional(..) | Iteration(..)) { "" } else { ";" }
}

pub fn is_empty(self: &Statement) -> bool {
matches!(self, Statement::Block(block) if block.statements.is_empty())
}
}

impl fmt::Display for Statement {
Expand Down
6 changes: 3 additions & 3 deletions compiler/compiler/src/compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -307,9 +307,9 @@ impl<'a, N: Network> Compiler<'a, N> {
}

/// Runs the dead code elimination pass.
pub fn dead_code_elimination_pass(&mut self) -> Result<()> {
pub fn dead_code_elimination_pass(&mut self, symbol_table: &SymbolTable) -> Result<()> {
if self.compiler_options.build.dce_enabled {
self.ast = DeadCodeEliminator::do_pass((std::mem::take(&mut self.ast), &self.node_builder))?;
self.ast = DeadCodeEliminator::do_pass((std::mem::take(&mut self.ast), symbol_table, &self.type_table))?;
}

if self.compiler_options.output.dce_ast {
Expand Down Expand Up @@ -347,7 +347,7 @@ impl<'a, N: Network> Compiler<'a, N> {

self.function_inlining_pass(&call_graph)?;

self.dead_code_elimination_pass()?;
self.dead_code_elimination_pass(&st)?;

Ok((st, struct_graph, call_graph))
}
Expand Down
2 changes: 1 addition & 1 deletion compiler/compiler/tests/integration/utilities/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ pub fn compile_and_process<'a>(parsed: &'a mut Compiler<'a, CurrentNetwork>) ->

parsed.function_inlining_pass(&call_graph)?;

parsed.dead_code_elimination_pass()?;
parsed.dead_code_elimination_pass(&st)?;

// Compile Leo program to bytecode.
let bytecode = parsed.code_generation_pass(&st, &struct_graph, &call_graph)?;
Expand Down
5 changes: 2 additions & 3 deletions compiler/parser/src/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,9 @@

use crate::{ParserContext, SpannedToken, tokenizer};

use leo_ast::{NodeBuilder, NodeID, Statement};
use leo_ast::{NodeBuilder, Statement};
use leo_errors::{LeoError, emitter::Handler};
use leo_span::{
Span,
source_map::FileName,
symbol::{SessionGlobals, create_session_if_not_set_then},
};
Expand Down Expand Up @@ -120,7 +119,7 @@ impl Namespace for ParseStatementNamespace {
create_session_if_not_set_then(|s| {
let tokenizer = tokenize(test, s)?;
if all_are_comments(&tokenizer) {
return Ok(toml_or_fail(Statement::dummy(Span::default(), NodeID::default())));
return Ok(toml_or_fail(Statement::dummy()));
}
with_handler(tokenizer, |p| p.parse_statement()).map(toml_or_fail)
})
Expand Down
84 changes: 74 additions & 10 deletions compiler/passes/src/dead_code_elimination/dead_code_eliminator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,25 +14,89 @@
// You should have received a copy of the GNU General Public License
// along with the Leo library. If not, see <https://www.gnu.org/licenses/>.

use leo_ast::NodeBuilder;
use leo_span::Symbol;
use leo_ast::{AccessExpression, Expression, Location, Node, Type};
use leo_span::{Symbol, sym};

use indexmap::IndexSet;

use crate::{SymbolTable, TypeTable};

#[derive(Debug)]
pub struct DeadCodeEliminator<'a> {
/// A counter to generate unique node IDs.
pub(crate) node_builder: &'a NodeBuilder,
/// The set of used variables in the current function body.
pub(crate) used_variables: IndexSet<Symbol>,
/// Whether or not the variables are necessary.
pub(crate) is_necessary: bool,
/// Whether or not we are currently traversing an async function.
pub(crate) is_async: bool,

/// The name of the program currently being processed.
pub(crate) program_name: Symbol,

pub(crate) symbol_table: &'a SymbolTable,

pub(crate) type_table: &'a TypeTable,
}

impl<'a> DeadCodeEliminator<'a> {
/// Initializes a new `DeadCodeEliminator`.
pub fn new(node_builder: &'a NodeBuilder) -> Self {
Self { node_builder, used_variables: Default::default(), is_necessary: false, is_async: false }
pub(crate) fn new(symbol_table: &'a SymbolTable, type_table: &'a TypeTable) -> Self {
Self { used_variables: Default::default(), program_name: Symbol::intern(""), symbol_table, type_table }
}

fn contains_future_or_record(&self, ty: &Type) -> bool {
use Type::*;
match ty {
Array(array) => self.contains_future_or_record(array.element_type()),
Composite(composite) => {
let program = composite.program.unwrap_or(self.program_name);
let location = Location::new(program, composite.id.name);
// Struct or record can't contain a record or future, so
// we don't need to check for that.
self.symbol_table.lookup_record(location).is_some()
}

Tuple(tuple) => tuple.elements().iter().any(|ty| self.contains_future_or_record(ty)),

Future(..) => true,

Address | Boolean | Field | Group | Identifier(_) | Integer(_) | Mapping(_) | Scalar | Signature
| String | Unit | Err => false,
}
}

pub(crate) fn side_effect_free(&self, expr: &Expression) -> bool {
use Expression::*;

let sef = |expr| self.side_effect_free(expr);

match expr {
Access(AccessExpression::Array(array)) => sef(&array.array) && sef(&array.index),
Access(AccessExpression::AssociatedConstant(_)) => true,
Access(AccessExpression::AssociatedFunction(func)) => {
func.arguments.iter().all(sef)
&& !matches!(func.variant.name, sym::CheatCode | sym::Mapping | sym::Future)
}
Access(AccessExpression::Member(mem)) => sef(&mem.inner),
Access(AccessExpression::Tuple(tuple)) => sef(&tuple.tuple),
Array(array) => array.elements.iter().all(sef),
Binary(bin) => sef(&bin.left) && sef(&bin.right),
Call(call) => {
// A function call is side effect free if none of its arguments or returns
// contain a future or record, and all its arguments are side effect free.
let ret_ty = self.type_table.get(&call.id()).expect("Type checking should have provided a type.");
if self.contains_future_or_record(&ret_ty) {
false
} else {
call.arguments.iter().all(|arg| {
let ty = self.type_table.get(&arg.id()).expect("Type checking should have provided a type.");
sef(arg) && !self.contains_future_or_record(&ty)
})
}
}
Cast(cast) => sef(&cast.expression),
Struct(struct_) => struct_.members.iter().all(|mem| mem.expression.as_ref().map_or(true, sef)),
Ternary(tern) => [&*tern.condition, &*tern.if_true, &*tern.if_false].into_iter().all(sef),
Tuple(tuple) => tuple.elements.iter().all(sef),
Unary(un) => sef(&un.receiver),
Err(_) => false,
Identifier(_) | Literal(_) | Locator(_) | Unit(_) => true,
}
}
}
75 changes: 3 additions & 72 deletions compiler/passes/src/dead_code_elimination/eliminate_expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,83 +16,14 @@

use crate::DeadCodeEliminator;

use leo_ast::{
AccessExpression,
AssociatedFunction,
Expression,
ExpressionReconstructor,
Identifier,
StructExpression,
StructVariableInitializer,
};
use leo_span::sym;
use leo_ast::{Expression, ExpressionReconstructor, Identifier};

impl ExpressionReconstructor for DeadCodeEliminator<'_> {
type AdditionalOutput = ();

/// Reconstructs the associated function access expression.
fn reconstruct_associated_function(&mut self, input: AssociatedFunction) -> (Expression, Self::AdditionalOutput) {
// If the associated function manipulates a mapping, or a cheat code, mark the statement as necessary.
match (&input.variant.name, input.name.name) {
(&sym::Mapping, sym::remove)
| (&sym::Mapping, sym::set)
| (&sym::Future, sym::Await)
| (&sym::CheatCode, _) => {
self.is_necessary = true;
}
_ => {}
};
// Reconstruct the access expression.
let result = (
Expression::Access(AccessExpression::AssociatedFunction(AssociatedFunction {
variant: input.variant,
name: input.name,
arguments: input.arguments.into_iter().map(|arg| self.reconstruct_expression(arg).0).collect(),
span: input.span,
id: input.id,
})),
Default::default(),
);
// Unset `self.is_necessary`.
self.is_necessary = false;
result
}

/// Reconstruct the components of the struct init expression.
/// This is necessary since the reconstructor does not explicitly visit each component of the expression.
fn reconstruct_struct_init(&mut self, input: StructExpression) -> (Expression, Self::AdditionalOutput) {
(
Expression::Struct(StructExpression {
name: input.name,
// Reconstruct each of the struct members.
members: input
.members
.into_iter()
.map(|member| StructVariableInitializer {
identifier: member.identifier,
expression: match member.expression {
Some(expression) => Some(self.reconstruct_expression(expression).0),
None => unreachable!("Static single assignment ensures that the expression always exists."),
},
span: member.span,
id: member.id,
})
.collect(),
span: input.span,
id: input.id,
}),
Default::default(),
)
}

/// Marks identifiers as used.
/// This is necessary to determine which statements can be eliminated from the program.
// Use and reconstruct an identifier.
fn reconstruct_identifier(&mut self, input: Identifier) -> (Expression, Self::AdditionalOutput) {
// Add the identifier to `self.used_variables`.
if self.is_necessary {
self.used_variables.insert(input.name);
}
// Return the identifier as is.
self.used_variables.insert(input.name);
(Expression::Identifier(input), Default::default())
}
}
28 changes: 11 additions & 17 deletions compiler/passes/src/dead_code_elimination/eliminate_program.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,28 +16,22 @@

use crate::DeadCodeEliminator;

use leo_ast::{Function, ProgramReconstructor, StatementReconstructor};
use leo_ast::{Function, ProgramReconstructor, StatementReconstructor as _};

impl ProgramReconstructor for DeadCodeEliminator<'_> {
fn reconstruct_function(&mut self, input: Function) -> Function {
fn reconstruct_function(&mut self, mut input: Function) -> Function {
// Reset the state of the dead code eliminator.
self.used_variables.clear();
self.is_necessary = false;
self.is_async = input.variant.is_async_function();

// Traverse the function body.
let block = self.reconstruct_block(input.block).0;

Function {
annotations: input.annotations,
variant: input.variant,
identifier: input.identifier,
input: input.input,
output: input.output,
output_type: input.output_type,
block,
span: input.span,
id: input.id,
}
input.block = self.reconstruct_block(input.block).0;

input
}

fn reconstruct_program_scope(&mut self, mut input: leo_ast::ProgramScope) -> leo_ast::ProgramScope {
self.program_name = input.program_id.name.name;
input.functions = input.functions.into_iter().map(|(i, f)| (i, self.reconstruct_function(f))).collect();
input
}
}
Loading
Loading