diff --git a/src/ast.rs b/src/ast.rs index 37b845c..5c44dc9 100644 --- a/src/ast.rs +++ b/src/ast.rs @@ -1,20 +1,4 @@ /// The AST maps closely to assembly for simplicity. -#[derive(Debug, Eq, PartialEq, Default, Clone)] -pub struct Node { - pub statement: Statement, - pub child: Option, - pub next: Option, -} -impl Node { - pub fn new(statement: Statement) -> Self { - Self { - statement, - child: None, - next: None, - } - } -} - use std::ptr::NonNull; #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -23,6 +7,7 @@ pub enum Preceding { Previous(NonNull), } +#[derive(Debug)] pub struct NewNode { pub statement: Statement, pub preceding: Option, diff --git a/src/main.rs b/src/main.rs index eb43821..3edda44 100644 --- a/src/main.rs +++ b/src/main.rs @@ -12,7 +12,6 @@ extern crate test; use std::io::Read; mod ast; -#[allow(unused_imports)] use ast::*; mod frontend; use frontend::*; @@ -31,7 +30,8 @@ fn main() { let reader = std::io::BufReader::new(empty); let mut iter = reader.bytes().peekable(); let nodes = get_nodes(&mut iter).unwrap(); - let path = explore(nodes); + let roots = roots(nodes); + let path = explore(&roots); let optimized_nodes = optimize(path); // let nodes = optimize_nodes(&nodes); let _assembly = assembly_from_node(optimized_nodes); @@ -96,28 +96,49 @@ mod tests { nodes } - fn test_exploration( - nodes: NonNull, - expected: &[TypeValueState], - ) -> NonNull { - let states = unsafe { explore(nodes) }; - - let mut index = 0; - let mut stack = vec![states]; + unsafe fn check_states(actual: NonNull, expected: &[TypeValueState]) { + let mut stack = vec![actual]; + let mut set = Vec::new(); while let Some(current) = stack.pop() { - let node = unsafe { current.as_ref() }; + let node = current.as_ref(); if let Some(next_one) = node.next.0 { stack.push(next_one); } if let Some(next_two) = node.next.1 { stack.push(next_two); } - - assert_eq!(Some(&node.state), expected.get(index)); - index += 1; + set.push(node.state.clone()); } - states + assert_eq!(set, expected); + } + + fn test_exploration( + nodes: NonNull, + expected_roots: &[TypeValueState], + expected_states: &[&[TypeValueState]], + expected_path: &[TypeValueState], + ) -> NonNull { + unsafe { + let roots = roots(nodes); + for (actual, expected) in roots.iter().zip(expected_roots.iter()) { + assert_eq!(&actual.as_ref().state, expected); + } + + let mut explorer = Explorer::new(&roots); + let mut expected_iter = expected_states.iter(); + let finished = loop { + match explorer.next() { + Explore::Current(current) => { + check_states(current, expected_iter.next().unwrap()) + } + Explore::Finished(finished) => break finished, + } + }; + println!("finished"); + check_states(finished, expected_path); + finished + } } fn test_optimization(nodes: NonNull, expected: &[Statement]) -> NonNull { @@ -215,7 +236,12 @@ mod tests { ); // Exploration - let path = test_exploration(nodes, &[TypeValueState::new()]); + let path = test_exploration( + nodes, + &[TypeValueState::new()], + &[&[TypeValueState::new()]], + &[TypeValueState::new()], + ); // Optimization let optimized = test_optimization( @@ -256,7 +282,12 @@ mod tests { ); // Exploration - let path = test_exploration(nodes, &[TypeValueState::new()]); + let path = test_exploration( + nodes, + &[TypeValueState::new()], + &[&[TypeValueState::new()]], + &[TypeValueState::new()], + ); // Optimization let optimized = test_optimization( @@ -297,7 +328,12 @@ mod tests { ); // Exploration - let path = test_exploration(nodes, &[TypeValueState::new()]); + let path = test_exploration( + nodes, + &[TypeValueState::new()], + &[&[TypeValueState::new()]], + &[TypeValueState::new()], + ); // Optimization let optimized = test_optimization( @@ -345,8 +381,12 @@ mod tests { ); // Exploration - let path = test_exploration(nodes, &[TypeValueState::new()]); - + let path = test_exploration( + nodes, + &[TypeValueState::new()], + &[&[TypeValueState::new()]], + &[TypeValueState::new()], + ); // Optimization let optimized = test_optimization( path, @@ -402,6 +442,106 @@ mod tests { // Exploration let path = test_exploration( nodes, + &[ + TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::I64(MyRange::from(1))), + )]), + TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::I32(MyRange::from(1))), + )]), + TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::I16(MyRange::from(1))), + )]), + TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::I8(MyRange::from(1))), + )]), + TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::U64(MyRange::from(1))), + )]), + TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::U32(MyRange::from(1))), + )]), + TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::U16(MyRange::from(1))), + )]), + TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::U8(MyRange::from(1))), + )]), + ], + &[ + &[TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::U8(MyRange::from(1))), + )])], + &[TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::U8(MyRange::from(1))), + )])], + &[TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::U16(MyRange::from(1))), + )])], + &[TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::U16(MyRange::from(1))), + )])], + &[TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::U32(MyRange::from(1))), + )])], + &[TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::U32(MyRange::from(1))), + )])], + &[TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::U64(MyRange::from(1))), + )])], + &[TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::U64(MyRange::from(1))), + )])], + &[TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::I8(MyRange::from(1))), + )])], + &[TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::I8(MyRange::from(1))), + )])], + &[TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::I16(MyRange::from(1))), + )])], + &[TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::I16(MyRange::from(1))), + )])], + &[TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::I32(MyRange::from(1))), + )])], + &[TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::I32(MyRange::from(1))), + )])], + &[TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::I64(MyRange::from(1))), + )])], + &[TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::I64(MyRange::from(1))), + )])], + ], &[ TypeValueState::from([( ident("x"), @@ -465,6 +605,106 @@ mod tests { // Exploration let path = test_exploration( nodes, + &[ + TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::I64(MyRange::from(1))), + )]), + TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::I32(MyRange::from(1))), + )]), + TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::I16(MyRange::from(1))), + )]), + TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::I8(MyRange::from(1))), + )]), + TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::U64(MyRange::from(1))), + )]), + TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::U32(MyRange::from(1))), + )]), + TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::U16(MyRange::from(1))), + )]), + TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::U8(MyRange::from(1))), + )]), + ], + &[ + &[TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::U8(MyRange::from(1))), + )])], + &[TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::U8(MyRange::from(1))), + )])], + &[TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::U16(MyRange::from(1))), + )])], + &[TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::U16(MyRange::from(1))), + )])], + &[TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::U32(MyRange::from(1))), + )])], + &[TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::U32(MyRange::from(1))), + )])], + &[TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::U64(MyRange::from(1))), + )])], + &[TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::U64(MyRange::from(1))), + )])], + &[TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::I8(MyRange::from(1))), + )])], + &[TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::I8(MyRange::from(1))), + )])], + &[TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::I16(MyRange::from(1))), + )])], + &[TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::I16(MyRange::from(1))), + )])], + &[TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::I32(MyRange::from(1))), + )])], + &[TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::I32(MyRange::from(1))), + )])], + &[TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::I64(MyRange::from(1))), + )])], + &[TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::I64(MyRange::from(1))), + )])], + ], &[ TypeValueState::from([( ident("x"), @@ -536,6 +776,138 @@ mod tests { // Exploration let path = test_exploration( nodes, + &[ + TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::I64(MyRange::from(1))), + )]), + TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::I32(MyRange::from(1))), + )]), + TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::I16(MyRange::from(1))), + )]), + TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::I8(MyRange::from(1))), + )]), + TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::U64(MyRange::from(1))), + )]), + TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::U32(MyRange::from(1))), + )]), + TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::U16(MyRange::from(1))), + )]), + TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::U8(MyRange::from(1))), + )]), + ], + &[ + &[TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::U8(MyRange::from(1))), + )])], + &[TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::U8(MyRange::from(2))), + )])], + &[TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::U8(MyRange::from(2))), + )])], + &[TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::U16(MyRange::from(1))), + )])], + &[TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::U16(MyRange::from(2))), + )])], + &[TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::U16(MyRange::from(2))), + )])], + &[TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::U32(MyRange::from(1))), + )])], + &[TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::U32(MyRange::from(2))), + )])], + &[TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::U32(MyRange::from(2))), + )])], + &[TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::U64(MyRange::from(1))), + )])], + &[TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::U64(MyRange::from(2))), + )])], + &[TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::U64(MyRange::from(2))), + )])], + &[TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::I8(MyRange::from(1))), + )])], + &[TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::I8(MyRange::from(2))), + )])], + &[TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::I8(MyRange::from(2))), + )])], + &[TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::I16(MyRange::from(1))), + )])], + &[TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::I16(MyRange::from(2))), + )])], + &[TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::I16(MyRange::from(2))), + )])], + &[TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::I32(MyRange::from(1))), + )])], + &[TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::I32(MyRange::from(2))), + )])], + &[TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::I32(MyRange::from(2))), + )])], + &[TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::I64(MyRange::from(1))), + )])], + &[TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::I64(MyRange::from(2))), + )])], + &[TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::I64(MyRange::from(2))), + )])], + ], &[ TypeValueState::from([( ident("x"), @@ -617,10 +989,138 @@ mod tests { let path = test_exploration( nodes, &[ + TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::I64(MyRange::from(1))), + )]), + TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::I32(MyRange::from(1))), + )]), + TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::I16(MyRange::from(1))), + )]), + TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::I8(MyRange::from(1))), + )]), + TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::U64(MyRange::from(1))), + )]), + TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::U32(MyRange::from(1))), + )]), + TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::U16(MyRange::from(1))), + )]), TypeValueState::from([( ident("x"), TypeValue::Integer(TypeValueInteger::U8(MyRange::from(1))), )]), + ], + &[ + &[TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::U8(MyRange::from(1))), + )])], + &[TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::U8(MyRange::from(1))), + )])], + &[TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::U8(MyRange::from(1))), + )])], + &[TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::U16(MyRange::from(1))), + )])], + &[TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::U16(MyRange::from(1))), + )])], + &[TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::U16(MyRange::from(1))), + )])], + &[TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::U32(MyRange::from(1))), + )])], + &[TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::U32(MyRange::from(1))), + )])], + &[TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::U32(MyRange::from(1))), + )])], + &[TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::U64(MyRange::from(1))), + )])], + &[TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::U64(MyRange::from(1))), + )])], + &[TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::U64(MyRange::from(1))), + )])], + &[TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::I8(MyRange::from(1))), + )])], + &[TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::I8(MyRange::from(1))), + )])], + &[TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::I8(MyRange::from(1))), + )])], + &[TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::I16(MyRange::from(1))), + )])], + &[TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::I16(MyRange::from(1))), + )])], + &[TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::I16(MyRange::from(1))), + )])], + &[TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::I32(MyRange::from(1))), + )])], + &[TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::I32(MyRange::from(1))), + )])], + &[TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::I32(MyRange::from(1))), + )])], + &[TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::I64(MyRange::from(1))), + )])], + &[TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::I64(MyRange::from(1))), + )])], + &[TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::I64(MyRange::from(1))), + )])], + ], + &[ TypeValueState::from([( ident("x"), TypeValue::Integer(TypeValueInteger::U8(MyRange::from(1))), @@ -639,22 +1139,11 @@ mod tests { // Optimization let optimized = test_optimization( path, - &[ - Statement { - comptime: false, - op: Op::Special(Special::Type), - arg: vec![ - Value::Variable(Variable::new("x")), - Value::Type(Type::U8), - Value::Literal(Literal::Integer(1)), - ], - }, - Statement { - comptime: false, - op: Op::Syscall(Syscall::Exit), - arg: vec![Value::Literal(Literal::Integer(0))], - }, - ], + &[Statement { + comptime: false, + op: Op::Syscall(Syscall::Exit), + arg: vec![Value::Literal(Literal::Integer(0))], + }], ); // Assembly @@ -666,103 +1155,221 @@ mod tests { mov x8, #93\n\ mov x0, #0\n\ svc #0\n\ - .data\n\ - x: .byte 1\n\ ", - 2, + 0, ); } - #[cfg(feature = "false")] + #[test] fn ten() { - const TEN: &str = "x := 2\nif x = 2\n exit 1\nexit 0"; + const SOURCE: &str = "x := 2\nif x = 2\n exit 1\nexit 0"; - let nodes = parse(TEN); - assert_eq!( - nodes, - [ - Node { - statement: Statement { - comptime: false, - op: Op::Intrinsic(Intrinsic::Assign), - arg: vec![ - Value::Variable(Variable::new("x")), - Value::Literal(Literal::Integer(2)) - ] - }, - child: None, - next: Some(1), + // Parsing + let nodes = test_parsing( + SOURCE, + &[ + Statement { + comptime: false, + op: Op::Intrinsic(Intrinsic::Assign), + arg: vec![ + Value::Variable(Variable::new("x")), + Value::Literal(Literal::Integer(2)), + ], }, - Node { - statement: Statement { - comptime: false, - op: Op::Intrinsic(Intrinsic::If(Cmp::Eq)), - arg: vec![ - Value::Variable(Variable::new("x")), - Value::Literal(Literal::Integer(2)) - ] - }, - child: Some(2), - next: Some(3), + Statement { + comptime: false, + op: Op::Intrinsic(Intrinsic::If(Cmp::Eq)), + arg: vec![ + Value::Variable(Variable::new("x")), + Value::Literal(Literal::Integer(2)), + ], }, - Node { - statement: Statement { - comptime: false, - op: Op::Syscall(Syscall::Exit), - arg: vec![Value::Literal(Literal::Integer(1))] - }, - child: None, - next: None, + Statement { + comptime: false, + op: Op::Syscall(Syscall::Exit), + arg: vec![Value::Literal(Literal::Integer(1))], }, - Node { - statement: Statement { - comptime: false, - op: Op::Syscall(Syscall::Exit), - arg: vec![Value::Literal(Literal::Integer(0))] - }, - child: None, - next: None, + Statement { + comptime: false, + op: Op::Syscall(Syscall::Exit), + arg: vec![Value::Literal(Literal::Integer(0))], }, - ] + ], ); - let optimized_nodes = optimize(&nodes); - assert_eq!( - optimized_nodes, - [ - Node { - statement: Statement { - comptime: false, - op: Op::Special(Special::Type), - arg: vec![ - Value::Variable(Variable::new("x")), - Value::Type(Type::U8), - Value::Literal(Literal::Integer(2)) - ] - }, - child: None, - next: Some(1), - }, - Node { - statement: Statement { - comptime: false, - op: Op::Syscall(Syscall::Exit), - arg: vec![Value::Literal(Literal::Integer(1))] - }, - child: None, - next: None, - }, - ] + + // Exploration + let path = test_exploration( + nodes, + &[ + TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::I64(MyRange::from(2))), + )]), + TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::I32(MyRange::from(2))), + )]), + TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::I16(MyRange::from(2))), + )]), + TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::I8(MyRange::from(2))), + )]), + TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::U64(MyRange::from(2))), + )]), + TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::U32(MyRange::from(2))), + )]), + TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::U16(MyRange::from(2))), + )]), + TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::U8(MyRange::from(2))), + )]), + ], + &[ + &[TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::U8(MyRange::from(2))), + )])], + &[TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::U8(MyRange::from(2))), + )])], + &[TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::U8(MyRange::from(2))), + )])], + &[TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::U16(MyRange::from(2))), + )])], + &[TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::U16(MyRange::from(2))), + )])], + &[TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::U16(MyRange::from(2))), + )])], + &[TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::U32(MyRange::from(2))), + )])], + &[TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::U32(MyRange::from(2))), + )])], + &[TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::U32(MyRange::from(2))), + )])], + &[TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::U64(MyRange::from(2))), + )])], + &[TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::U64(MyRange::from(2))), + )])], + &[TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::U64(MyRange::from(2))), + )])], + &[TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::I8(MyRange::from(2))), + )])], + &[TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::I8(MyRange::from(2))), + )])], + &[TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::I8(MyRange::from(2))), + )])], + &[TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::I16(MyRange::from(2))), + )])], + &[TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::I16(MyRange::from(2))), + )])], + &[TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::I16(MyRange::from(2))), + )])], + &[TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::I32(MyRange::from(2))), + )])], + &[TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::I32(MyRange::from(2))), + )])], + &[TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::I32(MyRange::from(2))), + )])], + &[TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::I64(MyRange::from(2))), + )])], + &[TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::I64(MyRange::from(2))), + )])], + &[TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::I64(MyRange::from(2))), + )])], + ], + &[ + TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::U8(MyRange::from(2))), + )]), + TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::U8(MyRange::from(2))), + )]), + TypeValueState::from([( + ident("x"), + TypeValue::Integer(TypeValueInteger::U8(MyRange::from(2))), + )]), + ], ); - let expected_assembly = "\ + + // Optimization + let optimized = test_optimization( + path, + &[Statement { + comptime: false, + op: Op::Syscall(Syscall::Exit), + arg: vec![Value::Literal(Literal::Integer(1))], + }], + ); + + // Assembly + test_assembling( + optimized, + "\ .global _start\n\ _start:\n\ mov x8, #93\n\ mov x0, #1\n\ svc #0\n\ - .data\n\ - x: .byte 2\n\ - "; - assemble(&optimized_nodes, expected_assembly, 1); + ", + 1, + ); } #[cfg(feature = "false")] diff --git a/src/middle.rs b/src/middle.rs index 30cd155..ead833e 100644 --- a/src/middle.rs +++ b/src/middle.rs @@ -2,60 +2,20 @@ use crate::ast::*; use num_traits::bounds::Bounded; use num_traits::identities::One; use num_traits::identities::Zero; +use std::alloc; use std::cmp::Ordering; use std::collections::HashMap; use std::collections::HashSet; +use std::ptr; +use std::ptr::NonNull; const DEFAULT_LOOP_LIMIT: usize = 100; - -pub unsafe fn optimize_temp(nodes: NonNull) -> NonNull { - // Get all possible paths. - let graph = explore(nodes); - - let mut type_state = TypeState::new(); - let mut stack = vec![graph]; - while let Some(current) = stack.pop() { - let node = current.as_ref(); - if let Some(a) = node.next.0 { - stack.push(a); - } - if let Some(b) = node.next.1 { - stack.push(b); - } - for (ident, value_type) in node.state.iter() { - type_state.insert(ident.clone(), Type::from(value_type.clone())); - } - } - - todo!() - - // Use the best path to apply optimizations to the nodes. - // optimize(graph, type_state) -} - +#[allow(dead_code)] const UNROLL_LIMIT: usize = 4096; -// Applies typical optimizations. E.g. removing unused variables, unreachable code, etc. -pub unsafe fn optimize(graph: NonNull) -> NonNull { - // Remove unreachable nodes. - // let mut reachable_nodes = remove_unreachable_nodes(nodes, start, graph_nodes); - - // // unroll_loops(&mut reachable_nodes, start, graph_nodes); - - // // TODO `inline_ifs` doesn't currently update `graph_nodes` so it cannot be used after. It may - // // even be bugged currently within `inline_ifs`. Change `inline_ifs` to update `graph_nodes` so - // // its consistent. - // inline_ifs(&mut reachable_nodes, start, graph_nodes); - - // remove_requires(&mut reachable_nodes); - - // // Find assignments that can be promoted to allocations. - // promote_assignments(&mut reachable_nodes, &type_state); - - // reachable_nodes - - // Construct new optimized abstract syntax tree. - +pub unsafe fn get_optimized_tree( + graph: NonNull, +) -> (NonNull, HashSet) { let mut types = HashMap::new(); let mut read = HashSet::new(); @@ -109,7 +69,7 @@ pub unsafe fn optimize(graph: NonNull) -> NonNull { current_new_node.as_mut().statement.arg = vec![Value::Literal(Literal::Integer(i128::from(exact)))]; } else { - read.insert(identifier); + read.insert(identifier.clone()); } } _ => todo!(), @@ -162,6 +122,45 @@ pub unsafe fn optimize(graph: NonNull) -> NonNull { _ => todo!(), } } + Op::Intrinsic(Intrinsic::If(Cmp::Eq)) => { + match node.next { + (Some(_), None) | (None, Some(_)) => { + // Update AST + let old_node = current_new_node; + match current_new_node.as_ref().preceding { + Some(Preceding::Parent(mut parent)) => { + parent.as_mut().child = None; + current_new_node = parent; + } + Some(Preceding::Previous(mut previous)) => { + previous.as_mut().next = None; + current_new_node = previous; + } + None => unreachable!(), + } + + // Update state graph + if let Some(mut next_one) = current.as_mut().next.0 { + next_one.as_mut().prev = current.as_ref().prev; + } + if let Some(mut next_two) = current.as_mut().next.1 { + next_two.as_mut().prev = current.as_ref().prev; + } + let mut prev = current.as_ref().prev.unwrap(); + prev.as_mut().next = current.as_ref().next; + let old_state = current; + current = prev; + + // Deallocate nodes + alloc::dealloc( + old_state.as_ptr().cast(), + alloc::Layout::new::(), + ); + alloc::dealloc(old_node.as_ptr().cast(), alloc::Layout::new::()); + } + _ => todo!(), + } + } _ => todo!(), } @@ -169,35 +168,60 @@ pub unsafe fn optimize(graph: NonNull) -> NonNull { match node.next { (Some(next), None) | (None, Some(next)) => { - let statement_ref = next.as_ref().statement.as_ref(); - let mut temp = { - let ptr = alloc::alloc(alloc::Layout::new::()).cast::(); - ptr::write(ptr, NewNode::new(statement_ref.statement.clone())); - NonNull::new(ptr).unwrap() - }; + match node.statement.as_ref().statement.op { + Op::Intrinsic(Intrinsic::If(_)) | Op::Intrinsic(Intrinsic::Loop) => { + let statement_ref = next.as_ref().statement.as_ref(); + let mut temp = { + let ptr = + alloc::alloc(alloc::Layout::new::()).cast::(); + ptr::write(ptr, NewNode::new(statement_ref.statement.clone())); + NonNull::new(ptr).unwrap() + }; - temp.as_mut().preceding = statement_ref.preceding; - match statement_ref.preceding.unwrap() { - Preceding::Parent(_) => { - current_new_node.as_mut().child = Some(temp); + // Succedding item may be child or next + temp.as_mut().preceding = statement_ref.preceding; + match statement_ref.preceding.unwrap() { + Preceding::Parent(_) => { + current_new_node.as_mut().child = Some(temp); + } + Preceding::Previous(_) => { + current_new_node.as_mut().next = Some(temp); + } + } + current_new_node = temp; + + stack.push(next); } - Preceding::Previous(_) => { + _ => { + let statement_ref = next.as_ref().statement.as_ref(); + let mut temp = { + let ptr = + alloc::alloc(alloc::Layout::new::()).cast::(); + ptr::write(ptr, NewNode::new(statement_ref.statement.clone())); + NonNull::new(ptr).unwrap() + }; + + // Succedding item can only be next + temp.as_mut().preceding = Some(Preceding::Previous(current_new_node)); + current_new_node.as_mut().child = None; current_new_node.as_mut().next = Some(temp); + current_new_node = temp; + + stack.push(next); } } - current_new_node = temp; - - stack.push(next); } (Some(_), Some(_)) => todo!(), (None, None) => break, } } + (new_nodes, read) +} - // TODO This doesn't dealloc anything in `graph` which may be very very big. Do this deallocation. - - dbg!(&read); - +pub unsafe fn strip_optimized_tree( + new_nodes: NonNull, + read: HashSet, +) -> NonNull { // After iterating through the full AST we now know which variables are used so can remove // unused variables let mut first = new_nodes; @@ -248,14 +272,29 @@ pub unsafe fn optimize(graph: NonNull) -> NonNull { first } -unsafe fn new_passback_end(mut node: NonNull, _end: GraphNodeEnd) { +// Applies typical optimizations. E.g. removing unused variables, unreachable code, etc. +pub unsafe fn optimize(graph: NonNull) -> NonNull { + // Construct new optimized abstract syntax tree. + // TODO This doesn't dealloc anything in `graph` which may be very very big. Do this deallocation. + let (new_nodes, read) = get_optimized_tree(graph); + + dbg!(&read); + + strip_optimized_tree(new_nodes, read) +} + +unsafe fn new_passback_end(mut node: NonNull, end: GraphNodeEnd) { debug_assert!(node.as_ref().unexplored.0.is_empty()); debug_assert!(node.as_ref().unexplored.1.is_empty()); debug_assert!(node.as_ref().cost.is_none()); + dbg!(&end); + println!("{:?}", &node.as_ref().state); + dbg!(&node.as_ref().statement.as_ref().statement.op); + // TODO Simplify this. - let cost = TypeState::from(node.as_ref().state.clone()).cost(); - node.as_mut().cost = Some(cost); + let type_state_cost = TypeState::from(node.as_ref().state.clone()).cost(); + node.as_mut().cost = Some(type_state_cost.saturating_add(end.cost())); // Backpropagate the cost to `node.prev`, deallocating nodes not in the lowest cost path. while let Some(mut prev) = node.as_ref().prev { @@ -350,6 +389,8 @@ unsafe fn new_passback_end(mut node: NonNull, _end: GraphNodeEnd) node = prev; } } + +#[derive(Debug)] enum GraphNodeEnd { /// E.g. `exit` syscall Valid, @@ -409,7 +450,7 @@ unsafe fn new_append( .collect() } -use std::ptr::NonNull; +#[derive(Debug)] pub struct NewStateNode { pub state: TypeValueState, pub statement: NonNull, @@ -424,9 +465,6 @@ pub struct NewStateNode { pub cost: Option, } -use std::alloc; -use std::ptr; - enum IfBool { True, False, @@ -533,8 +571,21 @@ unsafe fn explore_if( } } -pub unsafe fn explore(nodes: NonNull) -> NonNull { - let roots = get_possible_states(&nodes.as_ref().statement, &TypeValueState::new()) +unsafe fn print_ast(nodes: NonNull) { + let mut stack = vec![(nodes, 0)]; + while let Some((current, indent)) = stack.pop() { + println!("{}{:?}", " ".repeat(indent), current.as_ref().statement); + if let Some(next) = current.as_ref().next { + stack.push((next, indent)); + } + if let Some(child) = current.as_ref().child { + stack.push((child, indent + 1)); + } + } +} + +pub unsafe fn roots(node: NonNull) -> Vec> { + get_possible_states(&node.as_ref().statement, &TypeValueState::new()) .into_iter() .map(|state| { let ptr = alloc::alloc(alloc::Layout::new::()).cast::(); @@ -542,7 +593,7 @@ pub unsafe fn explore(nodes: NonNull) -> NonNull { ptr, NewStateNode { state, - statement: nodes, + statement: node, prev: None, next: (None, None), unexplored: (Vec::new(), Vec::new()), @@ -553,133 +604,85 @@ pub unsafe fn explore(nodes: NonNull) -> NonNull { ); NonNull::new(ptr).unwrap() }) - .collect::>(); - let mut stack = roots.clone(); + .collect() +} - while let Some(mut current) = stack.pop() { - let current_ref = current.as_mut(); - let ast_node = current_ref.statement.as_ref(); - let statement = &ast_node.statement; - - match statement.op { - Op::Intrinsic(Intrinsic::If(Cmp::Eq)) => match statement.arg.as_slice() { - [Value::Variable(Variable { identifier, .. }), Value::Literal(Literal::Integer(x))] => - { - let _scope = current_ref.scope; - let y = current_ref - .state - .get(identifier) - .unwrap() - .integer() - .unwrap(); - - let if_bool = if y.value() == Some(*x) { - IfBool::True - } else if y.excludes(*x) { - IfBool::False - } else { - IfBool::Unknown - }; +pub unsafe fn explore_node( + mut current: NonNull, + stack: &mut Vec>, +) { + let current_ref = current.as_mut(); + let ast_node = current_ref.statement.as_ref(); + let statement = &ast_node.statement; - explore_if(if_bool, current, &mut stack); - } - _ => todo!(), - }, - Op::Intrinsic(Intrinsic::If(Cmp::Lt)) => match statement.arg.as_slice() { - [Value::Variable(Variable { identifier, .. }), Value::Literal(Literal::Integer(x))] => - { - let _scope = current_ref.scope; - let y = current_ref - .state - .get(identifier) - .unwrap() - .integer() - .unwrap(); - - let if_bool = if y.max() < *x { - IfBool::True - } else if y.min() >= *x { - IfBool::False - } else { - IfBool::Unknown - }; - explore_if(if_bool, current, &mut stack); - } - _ => todo!(), - }, - Op::Intrinsic(Intrinsic::Loop) => { - match statement.arg.as_slice() { - [] => { - current_ref.unexplored = ( - if let Some(child) = ast_node.child { - // Since `new_append` applies the same loop limit with the last - // element -1, we need to add a loop limit to the loop node so - // the children get the limit. - // TODO Do this in a better way. - current_ref.loop_limit.push(DEFAULT_LOOP_LIMIT); - let temp = new_append(current, Some(current), child, &mut stack); - // Since the loop isn't actually in a loop we pop it after. - current_ref.loop_limit.pop(); - temp - } else if let Some(next) = ast_node.next { - new_append(current, current_ref.scope, next, &mut stack) - } - // If this AST node has no next, look for next node in parents. - else { - 'outer: loop { - let mut preceding_opt = ast_node.preceding; - let parent = loop { - match preceding_opt { - None => break 'outer Vec::new(), - Some(Preceding::Previous(previous)) => { - preceding_opt = previous.as_ref().preceding; - } - Some(Preceding::Parent(parent)) => break parent, - } - }; - - // If this would exit a loop, the next statement is the 1st statement of the loop. - if parent.as_ref().statement.op - == Op::Intrinsic(Intrinsic::Loop) - { - debug_assert_eq!( - current_ref.scope.unwrap().as_ref().statement, - parent - ); - let parent_child = parent.as_ref().child.unwrap(); - break new_append( - current, - current_ref.scope, - parent_child, - &mut stack, - ); - } - // Else if this wouldn't exit a loop, the next statement is the next statement of this parent if there is a next. - else if let Some(parent_next) = parent.as_ref().next { - break new_append( - current, - current_ref.scope, - parent_next, - &mut stack, - ); - } - } - }, - Vec::new(), - ) - } - _ => todo!(), - } + debug_assert!(current_ref.unexplored.0.is_empty()); + debug_assert!(current_ref.unexplored.1.is_empty()); + debug_assert!(current_ref.next.0.is_none()); + debug_assert!(current_ref.next.1.is_none()); + + match statement.op { + Op::Intrinsic(Intrinsic::If(Cmp::Eq)) => match statement.arg.as_slice() { + [Value::Variable(Variable { identifier, .. }), Value::Literal(Literal::Integer(x))] => { + let _scope = current_ref.scope; + let y = current_ref + .state + .get(identifier) + .unwrap() + .integer() + .unwrap(); + + let if_bool = if y.value() == Some(*x) { + IfBool::True + } else if y.excludes(*x) { + IfBool::False + } else { + IfBool::Unknown + }; + + current_ref.unexplored = explore_if(if_bool, current, stack); + } + _ => todo!(), + }, + Op::Intrinsic(Intrinsic::If(Cmp::Lt)) => match statement.arg.as_slice() { + [Value::Variable(Variable { identifier, .. }), Value::Literal(Literal::Integer(x))] => { + let _scope = current_ref.scope; + let y = current_ref + .state + .get(identifier) + .unwrap() + .integer() + .unwrap(); + + let if_bool = if y.max() < *x { + IfBool::True + } else if y.min() >= *x { + IfBool::False + } else { + IfBool::Unknown + }; + current_ref.unexplored = explore_if(if_bool, current, stack); } - Op::Intrinsic(Intrinsic::Break) => match statement.arg.as_slice() { + _ => todo!(), + }, + Op::Intrinsic(Intrinsic::Loop) => { + match statement.arg.as_slice() { [] => { - let prev_scope_graph_node = current_ref.scope.unwrap(); - let scope_node = prev_scope_graph_node.as_ref().statement; - current.as_mut().unexplored = ( - if let Some(next) = scope_node.as_ref().next { - let scope = prev_scope_graph_node.as_ref().scope; - new_append(current, scope, next, &mut stack) - } else { + current_ref.unexplored = ( + if let Some(child) = ast_node.child { + // Since `new_append` applies the same loop limit with the last + // element -1, we need to add a loop limit to the loop node so + // the children get the limit. + // TODO Do this in a better way. + current_ref.loop_limit.push(DEFAULT_LOOP_LIMIT); + let temp = new_append(current, Some(current), child, stack); + // Since the loop isn't actually in a loop we pop it after. + current_ref.loop_limit.pop(); + temp + } else if let Some(next) = ast_node.next { + new_append(current, current_ref.scope, next, stack) + } + // If this AST node has no next, look for next node in parents. + else { 'outer: loop { let mut preceding_opt = ast_node.preceding; let parent = loop { @@ -703,7 +706,7 @@ pub unsafe fn explore(nodes: NonNull) -> NonNull { current, current_ref.scope, parent_child, - &mut stack, + stack, ); } // Else if this wouldn't exit a loop, the next statement is the next statement of this parent if there is a next. @@ -712,41 +715,123 @@ pub unsafe fn explore(nodes: NonNull) -> NonNull { current, current_ref.scope, parent_next, - &mut stack, + stack, ); } } }, Vec::new(), - ); + ) } _ => todo!(), - }, - // See 2 - Op::Syscall(Syscall::Exit) => {} - // See 1 & 2 - _ => { - let scope = current_ref.scope; - current_ref.unexplored = ( - new_append(current, scope, ast_node.next.unwrap(), &mut stack), + } + } + Op::Intrinsic(Intrinsic::Break) => match statement.arg.as_slice() { + [] => { + let prev_scope_graph_node = current_ref.scope.unwrap(); + let scope_node = prev_scope_graph_node.as_ref().statement; + current.as_mut().unexplored = ( + if let Some(next) = scope_node.as_ref().next { + let scope = prev_scope_graph_node.as_ref().scope; + new_append(current, scope, next, stack) + } else { + 'outer: loop { + let mut preceding_opt = ast_node.preceding; + let parent = loop { + match preceding_opt { + None => break 'outer Vec::new(), + Some(Preceding::Previous(previous)) => { + preceding_opt = previous.as_ref().preceding; + } + Some(Preceding::Parent(parent)) => break parent, + } + }; + + // If this would exit a loop, the next statement is the 1st statement of the loop. + if parent.as_ref().statement.op == Op::Intrinsic(Intrinsic::Loop) { + debug_assert_eq!( + current_ref.scope.unwrap().as_ref().statement, + parent + ); + let parent_child = parent.as_ref().child.unwrap(); + break new_append(current, current_ref.scope, parent_child, stack); + } + // Else if this wouldn't exit a loop, the next statement is the next statement of this parent if there is a next. + else if let Some(parent_next) = parent.as_ref().next { + break new_append(current, current_ref.scope, parent_next, stack); + } + } + }, Vec::new(), ); } + _ => todo!(), + }, + // See 2 + Op::Syscall(Syscall::Exit) => {} + // See 1 & 2 + _ => { + let scope = current_ref.scope; + current_ref.unexplored = ( + new_append(current, scope, ast_node.next.unwrap(), stack), + Vec::new(), + ); } + } +} - // `exit` syscalls and loop limits will `continue` and not his this. - if current_ref.unexplored.0.is_empty() && current_ref.unexplored.1.is_empty() { - if statement.op == Op::Syscall(Syscall::Exit) { - new_passback_end(current, GraphNodeEnd::Valid); - } else { - new_passback_end(current, GraphNodeEnd::Invalid); - } - } else if let Some(0) = current_ref.loop_limit.last() { - new_passback_end(current, GraphNodeEnd::Loop); +pub unsafe fn close_path(current: NonNull) { + if current.as_ref().unexplored.0.is_empty() && current.as_ref().unexplored.1.is_empty() { + if current.as_ref().statement.as_ref().statement.op == Op::Syscall(Syscall::Exit) { + new_passback_end(current, GraphNodeEnd::Valid); + } else { + new_passback_end(current, GraphNodeEnd::Invalid); } + } else if let Some(0) = current.as_ref().loop_limit.last() { + new_passback_end(current, GraphNodeEnd::Loop); } +} + +pub enum Explore { + Finished(NonNull), + Current(NonNull), +} + +pub struct Explorer<'a> { + roots: &'a [NonNull], + stack: Vec>, +} +impl<'a> Explorer<'a> { + pub unsafe fn new(roots: &'a [NonNull]) -> Self { + Self { + roots, + stack: Vec::from(roots), + } + } + pub unsafe fn next(&mut self) -> Explore { + if let Some(current) = self.stack.pop() { + explore_node(current, &mut self.stack); + close_path(current); + Explore::Current(current) + } else { + Explore::Finished(pick_path(self.roots)) + } + } +} + +pub unsafe fn explore(roots: &[NonNull]) -> NonNull { + dbg!(&roots); + + let mut stack = Vec::from(roots); + while let Some(current) = stack.pop() { + explore_node(current, &mut stack); + close_path(current); + } + pick_path(roots) +} - let mut iter = roots.into_iter(); +pub unsafe fn pick_path(roots: &[NonNull]) -> NonNull { + let mut iter = roots.iter().copied(); let mut best = iter.next().unwrap(); for root in iter { // Deallocate all nodes now longer in the lowest cost path.