diff --git a/bpf-processor/src/main/java/me/bechberger/cast/CAST.java b/bpf-processor/src/main/java/me/bechberger/cast/CAST.java new file mode 100644 index 0000000..e95844e --- /dev/null +++ b/bpf-processor/src/main/java/me/bechberger/cast/CAST.java @@ -0,0 +1,1065 @@ +package me.bechberger.cast; + +import org.jetbrains.annotations.Nullable; + +import java.util.*; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +/** + * Represents an abstract syntax tree for C, + * loosely based on the grammar from lysator.liu.se, + * for generating eBPF C programs. + *

+ * Example: {@snippet : + * variableDefinition(struct(variable("myStruct"), + * List.of( + * structMember(Declarator.identifier("int"), variable("b"))) + * ), variable("myVar", sec("a")) + * ) + * } + */ +public interface CAST { + + List children(); + + sealed interface Expression extends CAST permits Declarator, InitDeclarator, Initializer, OperatorExpression, + PrimaryExpression { + + @Override + List children(); + + static PrimaryExpression.Constant constant(Object value) { + return new PrimaryExpression.Constant(value); + } + + static PrimaryExpression.Variable variable(String name) { + return new PrimaryExpression.Variable(name); + } + + static PrimaryExpression.Variable variable(String name, PrimaryExpression.CAnnotation... annotations) { + return new PrimaryExpression.Variable(name, annotations); + } + + static PrimaryExpression.ParenthesizedExpression parenthesizedExpression(Expression expression) { + return new PrimaryExpression.ParenthesizedExpression(expression); + } + + static PrimaryExpression.EnumerationConstant enumerationConstant(String name) { + return new PrimaryExpression.EnumerationConstant(name); + } + } + + /** + * {@snippet : + * primary_expression + * : IDENTIFIER + * | constant + * | string + * | '(' expression ')' + * ; + *} + */ + sealed interface PrimaryExpression extends Expression { + + @Override + default List children() { + return List.of(); + } + + @Override + String toString(); + + /** + * Annotation like @SEC("...") + * + * @param annotation + * @param value + */ + record CAnnotation(String annotation, String value) implements CAST { + @Override + public List children() { + return List.of(); + } + + @Override + public String toString() { + return annotation + "(" + Expression.constant(value) + ")"; + } + + static CAnnotation annotation(String annotation, String value) { + return new CAnnotation(annotation, value); + } + + static CAnnotation sec(String value) { + return new CAnnotation("SEC", value); + } + } + + /** + * Variable name for expressions + */ + record Variable(String name, CAnnotation... annotations) implements PrimaryExpression { + @Override + public String toString() { + String annString = Arrays.stream(annotations).map(Object::toString).collect(Collectors.joining(" ")); + return name + (annString.isEmpty() ? "" : " " + annString); + } + } + + record EnumerationConstant(String name) implements PrimaryExpression { + @Override + public String toString() { + return name; + } + } + + /** + * Constant value for expressions, escapes string + */ + record Constant(Object value) implements PrimaryExpression { + @Override + public String toString() { + if (value instanceof String) { +// Escape the string + return "\"" + value.toString().replace("\\", "\\\\").replace("\"", "\\\"") + "\""; + } else { + return value.toString(); + } + } + } + + /** + * Wraps an expression in parentheses + */ + record ParenthesizedExpression(Expression expression) implements PrimaryExpression { + @Override + public String toString() { + return "(" + expression + ")"; + } + + @Override + public List children() { + return List.of(expression); + } + } + } + + /** + * Operators with precedence and associativity, + * based on cppreference.com + */ + enum Operator { + SUFFIX_INCREMENT("++", 2), SUFFIX_DECREMENT("--", 2), FUNCTION_CALL("()", 2), SUBSCRIPT("[]", 2), + MEMBER_ACCESS(".", 2), POSTFIX_INCREMENT("++", 3), POSTFIX_DECREMENT("--", 3), UNARY_PLUS("+", 3), + UNARY_MINUS("-", 3), LOGICAL_NOT("!", 3), BITWISE_NOT("~", 3), DEREFERENCE("*", 3), ADDRESS_OF("&", 3), + SIZEOF("sizeof", 3), CAST("cast", 3), MULTIPLICATION("*", 5), DIVISION("/", 5), MODULUS("%", 5), ADDITION("+" + , 6), SUBTRACTION("-", 6), SHIFT_LEFT("<<", 7), SHIFT_RIGHT(">>", 7), LESS_THAN("<", 9), + LESS_THAN_OR_EQUAL("<=", 9), GREATER_THAN(">", 9), GREATER_THAN_OR_EQUAL(">=", 9), EQUAL("==", 10), + NOT_EQUAL("!=", 10), BITWISE_AND("&", 11), BITWISE_XOR("^", 12), BITWISE_OR("|", 13), LOGICAL_AND("&&", 14), + LOGICAL_OR("||", 15), CONDITIONAL("?", 16), ASSIGNMENT("=", 16), MULTIPLICATION_ASSIGNMENT("*=", 16), + DIVISION_ASSIGNMENT("/=", 16), MODULUS_ASSIGNMENT("%=", 16), ADDITION_ASSIGNMENT("+=", 16), + SUBTRACTION_ASSIGNMENT("-=", 16), SHIFT_LEFT_ASSIGNMENT("<<=", 16), SHIFT_RIGHT_ASSIGNMENT(">>=", 16), + BITWISE_AND_ASSIGNMENT("&=", 16), BITWISE_XOR_ASSIGNMENT("^=", 16), BITWISE_OR_ASSIGNMENT("|=", 16), COMMA("," + , 17); + + private static final Map OPERATORS = new HashMap<>(); + private static final Map ASSIGNMENT_OPERATORS = new HashMap<>(); + private static final Map UNARY_OPERATORS = new HashMap<>(); + private static final Map BINARY_OPERATORS = new HashMap<>(); + private static final Map POSTFIX_OPERATORS = new HashMap<>(); + + static { +// sort all operators into their respective maps + for (Operator op : Operator.values()) { + OPERATORS.put(op.op, op); + if (op.op.endsWith("=")) { + ASSIGNMENT_OPERATORS.put(op.op, op); + } else if (op.precedence == 3) { + UNARY_OPERATORS.put(op.op, op); + } else if (op.precedence == 2) { + POSTFIX_OPERATORS.put(op.op, op); + } else { + BINARY_OPERATORS.put(op.op, op); + } + } + } + + public enum Associativity { + LEFT, RIGHT + } + + public final String op; + public final int precedence; + + public final Associativity associativity; + + Operator(String op, int precedence) { + this.op = op; + this.precedence = precedence; + if (precedence == 3 || precedence == 16) { + this.associativity = Associativity.RIGHT; + } else { + this.associativity = Associativity.LEFT; + } + } + + public boolean isUnitary() { + return precedence == 3 || precedence == 2; + } + + @Override + public String toString() { + return op; + } + + static Operator binary(String op) { + return BINARY_OPERATORS.get(op); + } + + static Operator unary(String op) { + return UNARY_OPERATORS.get(op); + } + + static Operator postfix(String op) { + return POSTFIX_OPERATORS.get(op); + } + + + static Operator assignment(String op) { + return ASSIGNMENT_OPERATORS.get(op); + } + + static Operator fromString(String op) { + return OPERATORS.get(op); + } + } + + record OperatorExpression(Operator operator, Expression... expressions) implements Expression { + + @Override + public List children() { + return Arrays.asList(expressions); + } + + /** + * Takes care of operator precedence and associativity + */ + @Override + public String toString() { +// idea: +// if the operator is a binary operator, we need to check the precedence of the children +// if the precedence of the child is lower, we need to wrap it in parentheses +// if the precedence is the same, we need to check if the operator is left or right associative +// if it is left associative, we need to wrap the left child in parentheses +// if it is right associative, we need to wrap the right child in parentheses +// if the operator is a unary operator, we need to check if the child is a binary operator +// if the child is a binary operator, we need to wrap it in parentheses +// if the operator is a postfix operator, we need to wrap the child in parentheses +// if the operator is an assignment operator, we need to wrap the left child in parentheses +// if the operator is a ternary operator, we need to wrap the children in parentheses +// if the operator is a member access operator, we need to wrap the right child in parentheses +// if the operator is a subscript operator, we need to wrap the right child in parentheses +// if the operator is a function call operator, we need to wrap the left child in parentheses +// if the operator is a dereference operator, we need to wrap the child in parentheses +// if the operator is an address of operator, we need to wrap the child in parentheses +// if the operator is a sizeof operator, we need to wrap the child in parentheses +// if the operator is a logical not operator, we need to wrap the child in parentheses +// if the operator is a bitwise not operator, we need to wrap the child in parentheses +// if the operator is a unary plus operator, we need to wrap the child in parentheses +// if the operator is a unary minus operator, we need to wrap the child in parentheses +// if the operator is a suffix increment operator, we need to wrap the child in parentheses +// if the operator is a suffix decrement operator, we need to wrap the child in parentheses +// if the operator is a prefix increment operator, we need to wrap the child in parentheses +// if the operator is a prefix decrement operator, we need to wrap the child in parentheses +// if the operator is a bitwise and operator, we need to wrap the children in parentheses + +// if the operator is a binary operator, we need to check the precedence of the children + +// if the operator is a unary operator, we need to check if the child is a binary operator + if (operator().precedence == 3) { + Expression operator1 = children().getFirst(); + String op1String = operator1.toString(); + if (operator1 instanceof OperatorExpression operatorExpression) { + if (operatorExpression.operator().precedence < operator().precedence) { + op1String = "(" + operatorExpression + ")"; + } + } + return operator() + op1String; + } else { +// if the operator is a postfix operator, we need to wrap the child in parentheses + if (operator().precedence == 2) { + Expression operator1 = children().getFirst(); + String op1String = operator1.toString(); + if (operator1 instanceof OperatorExpression operatorExpression) { + op1String = "(" + operatorExpression + ")"; + } + return op1String + operator(); + } else { +// if the operator is an assignment operator, we need to wrap the left child in parentheses + if (operator().precedence == 16) { + Expression operator1 = children().get(0); + Expression operator2 = children().get(1); + String op1String = operator1.toString(); + String op2String = operator2.toString(); + if (operator1 instanceof OperatorExpression operatorExpression) { + op1String = "(" + operatorExpression + ")"; + } + if (operator2 instanceof OperatorExpression operatorExpression) { + if (operatorExpression.operator().precedence < operator().precedence) { + op2String = "(" + operatorExpression + ")"; + } + } + return op1String + " " + operator() + " " + op2String; + } else { +// if the operator is a ternary operator, we need to wrap the children in parentheses + if (operator() == Operator.CONDITIONAL) { + Expression operator1 = children().get(0); + Expression operator2 = children().get(1); + Expression operator3 = children().get(2); + String op1String = operator1.toString(); + String op2String = operator2.toString(); + String op3String = operator3.toString(); + if (operator1 instanceof OperatorExpression operatorExpression) { + op1String = "(" + operatorExpression + ")"; + } + if (operator2 instanceof OperatorExpression operatorExpression) { + op2String = "(" + operatorExpression + ")"; + } + if (operator3 instanceof OperatorExpression operatorExpression) { + op3String = "(" + operatorExpression + ")"; + } + return op1String + " ? " + op2String + " : " + op3String; + } else if (operator() == Operator.MEMBER_ACCESS) { + Expression operator1 = children().get(0); + Expression operator2 = children().get(1); + String op1String = operator1.toString(); + String op2String = operator2.toString(); + if (operator2 instanceof OperatorExpression operatorExpression) { + op2String = "(" + operatorExpression + ")"; + } + return op1String + "." + op2String; + } else if (operator() == Operator.SUBSCRIPT) { + Expression operator1 = children().get(0); + Expression operator2 = children().get(1); + String op1String = operator1.toString(); + String op2String = operator2.toString(); + if (operator2 instanceof OperatorExpression operatorExpression) { + op2String = "(" + operatorExpression + ")"; + } + return op1String + "[" + op2String + "]"; + } else if (operator() == Operator.FUNCTION_CALL) { + Expression func = children().getFirst(); + String funcString = func.toString(); + if (func instanceof OperatorExpression operatorExpression) { + funcString = "(" + funcString + ")"; + } + return funcString + "(" + children().stream().skip(1).map(Object::toString).collect(Collectors.joining(", ")) + ")"; + } else if (operator() == Operator.SIZEOF) { + Expression operator1 = children().getFirst(); + String op1String = operator1.toString(); + if (operator1 instanceof OperatorExpression operatorExpression) { + op1String = "(" + operatorExpression + ")"; + } + return "sizeof(" + op1String + ")"; + } else if (operator() == Operator.CAST) { + return "(" + children().get(0) + ")" + children().get(1); + } else if (operator().isUnitary()) { + Expression operator1 = children().getFirst(); + String op1String = operator1.toString(); + if (operator1 instanceof OperatorExpression operatorExpression) { + op1String = "(" + operatorExpression + ")"; + } + if (operator().associativity == Operator.Associativity.RIGHT) { + return operator() + op1String; + } else { + return op1String + operator(); + } + } else { + Expression operator1 = children().get(0); + Expression operator2 = children().get(1); + String op1String = operator1.toString(); + String op2String = operator2.toString(); + if (operator1 instanceof OperatorExpression operatorExpression) { + if (operatorExpression.operator().precedence < operator().precedence) { + op1String = "(" + operatorExpression + ")"; + } else if (operatorExpression.operator().precedence == operator().precedence) { + if (operatorExpression.operator().associativity == Operator.Associativity.LEFT) { + op1String = "(" + operatorExpression + ")"; + } + } + } + if (operator2 instanceof OperatorExpression operatorExpression) { + if (operatorExpression.operator().precedence < operator().precedence) { + op2String = "(" + operatorExpression + ")"; + } else if (operatorExpression.operator().precedence == operator().precedence) { + if (operatorExpression.operator().associativity == Operator.Associativity.RIGHT) { + op2String = "(" + operatorExpression + ")"; + } + } + } + return op1String + " " + operator() + " " + op2String; + } + } + + } + } + } + + public static OperatorExpression binary(String op, Expression left, Expression right) { + return new OperatorExpression(Operator.binary(op), left, right); + } + + public static OperatorExpression unary(String op, Expression expression) { + return new OperatorExpression(Operator.unary(op), expression); + } + + public static OperatorExpression postfix(String op, Expression expression) { + return new OperatorExpression(Operator.postfix(op), expression); + } + + public static OperatorExpression assignment(String op, Expression left, Expression right) { + return new OperatorExpression(Operator.assignment(op), left, right); + } + + public static OperatorExpression ternary(Expression condition, Expression trueExpression, + Expression falseExpression) { + return new OperatorExpression(Operator.CONDITIONAL, condition, trueExpression, falseExpression); + } + + public static OperatorExpression memberAccess(Expression left, Expression right) { + return new OperatorExpression(Operator.MEMBER_ACCESS, left, right); + } + + public static OperatorExpression arrayAccess(Expression left, Expression right) { + return new OperatorExpression(Operator.SUBSCRIPT, left, right); + } + + public static OperatorExpression call(Expression func, Expression... args) { + return new OperatorExpression(Operator.FUNCTION_CALL, + Stream.concat(Stream.of(func), Arrays.stream(args)).toArray(Expression[]::new)); + } + + public static OperatorExpression pointer(Expression expression) { + return new OperatorExpression(Operator.DEREFERENCE, expression); + } + + public static OperatorExpression cast(PrimaryExpression.Variable type, Expression expression) { + return new OperatorExpression(Operator.CAST, type, expression); + } + } + + record InitDeclarator(@Nullable PrimaryExpression.Variable name, Expression expression) implements Expression { + + @Override + public List children() { + return List.of(expression); + } + + @Override + public String toString() { + return (name == null ? "" : name + " = ") + expression; + } + } + + sealed interface Initializer extends Expression { + + record InitializerList(List declarators) implements Initializer { + @Override + public List children() { + return declarators; + } + + @Override + public String toString() { + return "{" + declarators.stream().map(InitDeclarator::toString).collect(Collectors.joining(", ")) + "}"; + } + } + } + + sealed interface Declarator extends Expression { + + default String toPrettyString(String indent, String increment) { + return indent + this; + } + + default String toPrettyString() { + return toPrettyString("", " "); + } + + record PointerDeclarator(Declarator declarator) implements Declarator { + @Override + public List children() { + return List.of(declarator); + } + + @Override + public String toString() { + return "*" + declarator; + } + } + + record ArrayDeclarator(Declarator declarator, @Nullable Expression size) implements Declarator { + @Override + public List children() { + return size == null ? List.of(declarator) : List.of(declarator, size); + } + + @Override + public String toString() { + return declarator + (size == null ? "[]" : "[" + size + "]"); + } + } + + /** + * Struct member with optional size for ebpf member declaration (e.g. u32 (var, 10)) + */ + record StructMember(Declarator declarator, PrimaryExpression.Variable name, + @Nullable PrimaryExpression ebpfSize) implements Declarator { + + @Override + public List children() { + return List.of(declarator, name); + } + + @Override + public String toString() { + return declarator + " " + name; + } + + @Override + public String toPrettyString(String indent, String increment) { + if (ebpfSize == null) { + return declarator.toPrettyString(indent, increment) + " " + name + ";"; + } + return indent + declarator + " (" + name + ", " + ebpfSize + ");"; + } + } + + record StructDeclarator(@Nullable PrimaryExpression.Variable name, + List members) implements Declarator { + @Override + public List children() { + if (name == null) { + return members; + } + return Stream.concat(Stream.of(name), members.stream()).collect(Collectors.toList()); + } + + @Override + public String toPrettyString(String indent, String increment) { + return indent + "struct " + (name == null ? "" : name + " ") + "{\n" + members.stream().map(m -> m.toPrettyString(indent + increment, increment)).collect(Collectors.joining("\n")) + "\n" + indent + "}"; + } + } + + record FunctionDeclarator(Declarator declarator, List parameters) implements Declarator { + @Override + public List children() { + return parameters; + } + + @Override + public String toString() { + return declarator + "(" + parameters.stream().map(Object::toString).collect(Collectors.joining(", ")) + ")"; + } + } + + record IdentifierDeclarator(PrimaryExpression.Variable name) implements Declarator { + @Override + public List children() { + return List.of(); + } + + @Override + public String toString() { + return name.toString(); + } + } + + static Declarator pointer(Declarator declarator) { + return new PointerDeclarator(declarator); + } + + static Declarator array(Declarator declarator, @Nullable Expression size) { + return new ArrayDeclarator(declarator, size); + } + + static Declarator function(Declarator declarator, List parameters) { + return new FunctionDeclarator(declarator, parameters); + } + + static Declarator identifier(PrimaryExpression.Variable name) { + return new IdentifierDeclarator(name); + } + + static Declarator identifier(String name) { + return new IdentifierDeclarator(new PrimaryExpression.Variable(name)); + } + + static Declarator struct(PrimaryExpression.Variable name, List members) { + return new StructDeclarator(name, members); + } + + static StructMember structMember(Declarator declarator, PrimaryExpression.Variable name) { + return new StructMember(declarator, name, null); + } + + static StructMember structMember(Declarator declarator, PrimaryExpression.Variable name, + PrimaryExpression ebpfSize) { + return new StructMember(declarator, name, ebpfSize); + } + } + + + interface Statement extends CAST { + + /** + * Pretty string representation of the AST + * + * @param indent indent of the current line, ignored by expressions + * @param increment increment of the indent + * @return pretty string representation of the AST + */ + default String toPrettyString(String indent, String increment) { + return indent + this.children().getFirst() + ";"; + } + + default String toPrettyString() { + return toPrettyString("", " "); + } + + record ExpressionStatement(Expression expression) implements Statement { + + @Override + public List children() { + return List.of(expression); + } + + @Override + public String toString() { + return toPrettyString("", ""); + } + } + + record VariableDefinition(Declarator type, PrimaryExpression.Variable name) implements Statement { + + @Override + public List children() { + return List.of(type, name); + } + + @Override + public String toPrettyString(String indent, String increment) { + return type.toPrettyString(indent, increment) + " " + name + ";"; + } + + } + + record CompoundStatement(List statements) implements Statement { + + @Override + public List children() { + return statements; + } + + @Override + public String toPrettyString(String indent, String increment) { + return indent + "{\n" + statements.stream().map(s -> s.toPrettyString(indent + increment, increment)).collect(Collectors.joining("\n")) + "\n" + indent + "}"; + } + + @Override + public String toString() { + return toPrettyString("", " "); + } + } + + record IfStatement(Expression condition, Statement thenStatement, + @Nullable Statement elseStatement) implements Statement { + + @Override + public List children() { + return elseStatement == null ? List.of(condition, thenStatement) : List.of(condition, thenStatement, + elseStatement); + } + + @Override + public String toPrettyString(String indent, String increment) { + return indent + "if (" + condition + ")\n" + thenStatement.toPrettyString(indent + increment, + increment) + (elseStatement == null ? "" : + " else\n" + elseStatement.toPrettyString(indent + increment, increment)); + } + + @Override + public String toString() { + return toPrettyString("", " "); + } + } + + record WhileStatement(Expression condition, Statement body) implements Statement { + + @Override + public List children() { + return List.of(condition, body); + } + + @Override + public String toPrettyString(String indent, String increment) { + return indent + "while (" + condition + ")\n" + body.toPrettyString(indent + increment, increment); + } + + @Override + public String toString() { + return toPrettyString("", " "); + } + } + + record ForStatement(@Nullable Expression init, @Nullable Expression condition, @Nullable Expression increment, + Statement body) implements Statement { + + @Override + public List children() { + List children = new ArrayList<>(); + if (init != null) { + children.add(init); + } + if (condition != null) { + children.add(condition); + } + if (increment != null) { + children.add(increment); + } + children.add(body); + return children; + } + + @Override + public String toPrettyString(String indent, String increment) { + return indent + "for (" + (init == null ? "" : init) + "; " + (condition == null ? "" : condition) + + "; " + (increment == null ? "" : increment) + ")\n" + body.toPrettyString(indent + increment, + increment); + } + + @Override + public String toString() { + return toPrettyString("", " "); + } + } + + record ReturnStatement(@Nullable Expression expression) implements Statement { + + @Override + public List children() { + return expression == null ? List.of() : List.of(expression); + } + + @Override + public String toPrettyString(String indent, String increment) { + return indent + "return" + (expression == null ? "" : " " + expression) + ";"; + } + + @Override + public String toString() { + return toPrettyString("", " "); + } + } + + record BreakStatement() implements Statement { + + @Override + public List children() { + return List.of(); + } + + @Override + public String toPrettyString(String indent, String increment) { + return indent + "break;"; + } + + @Override + public String toString() { + return toPrettyString("", " "); + } + } + + record ContinueStatement() implements Statement { + + @Override + public List children() { + return List.of(); + } + + @Override + public String toPrettyString(String indent, String increment) { + return indent + "continue;"; + } + + @Override + public String toString() { + return toPrettyString("", " "); + } + } + + record EmptyStatement() implements Statement { + + @Override + public List children() { + return List.of(); + } + + @Override + public String toPrettyString(String indent, String increment) { + return indent + ";"; + } + + @Override + public String toString() { + return toPrettyString("", " "); + } + } + + record DeclarationStatement(Declarator declarator, @Nullable Initializer initializer) implements Statement { + + @Override + public List children() { + return initializer == null ? List.of(declarator) : List.of(declarator, initializer); + } + + @Override + public String toPrettyString(String indent, String increment) { + return indent + declarator + (initializer == null ? "" : " = " + initializer) + ";"; + } + + @Override + public String toString() { + return toPrettyString("", " "); + } + } + + record StructDeclarationStatement(Declarator.StructDeclarator declarator) implements Statement { + + @Override + public List children() { + return List.of(declarator); + } + + @Override + public String toPrettyString(String indent, String increment) { + return indent + declarator + ";"; + } + + @Override + public String toString() { + return toPrettyString("", " "); + } + } + + record FunctionDeclarationStatement(Declarator.FunctionDeclarator declarator, + CompoundStatement body) implements Statement { + + @Override + public List children() { + return List.of(declarator, body); + } + + @Override + public String toPrettyString(String indent, String increment) { + return indent + declarator + "\n" + body.toPrettyString(indent + increment, increment); + } + + @Override + public String toString() { + return toPrettyString("", " "); + } + } + + record Define(String name, PrimaryExpression.Constant value) implements Statement { + + @Override + public List children() { + return List.of(value); + } + + @Override + public String toPrettyString(String indent, String increment) { + return indent + "#define " + name + " " + value; + } + + @Override + public String toString() { + return toPrettyString("", " "); + } + } + + record Include(String file) implements Statement { + + @Override + public List children() { + return List.of(); + } + + @Override + public String toPrettyString(String indent, String increment) { + return indent + "#include " + file; + } + + @Override + public String toString() { + return toPrettyString("", " "); + } + } + + record Typedef(Declarator declarator) implements Statement { + + @Override + public List children() { + return List.of(declarator); + } + + @Override + public String toPrettyString(String indent, String increment) { + return indent + "typedef " + declarator; + } + + @Override + public String toString() { + return toPrettyString("", " "); + } + } + + record CaseStatement(Expression expression, Statement body) implements Statement { + + @Override + public List children() { + return List.of(expression, body); + } + + @Override + public String toPrettyString(String indent, String increment) { + return indent + "case " + expression + ":\n" + body.toPrettyString(indent + increment, increment); + } + + @Override + public String toString() { + return toPrettyString("", " "); + } + } + + record DefaultStatement(Statement body) implements Statement { + + @Override + public List children() { + return List.of(body); + } + + @Override + public String toPrettyString(String indent, String increment) { + return indent + "default:\n" + body.toPrettyString(indent + increment, increment); + } + + @Override + public String toString() { + return toPrettyString("", " "); + } + } + + record SwitchStatement(Expression expression, Statement body) implements Statement { + + @Override + public List children() { + return List.of(expression, body); + } + + @Override + public String toPrettyString(String indent, String increment) { + return indent + "switch (" + expression + ")\n" + body.toPrettyString(indent + increment, increment); + } + + @Override + public String toString() { + return toPrettyString("", " "); + } + } + + static Statement expression(Expression expression) { + return new ExpressionStatement(expression); + } + + static Statement compound(Statement... statements) { + return new CompoundStatement(List.of(statements)); + } + + static Statement compound(List statements) { + return new CompoundStatement(statements); + } + + static Statement ifStatement(Expression condition, Statement thenStatement, @Nullable Statement elseStatement) { + return new IfStatement(condition, thenStatement, elseStatement); + } + + static Statement whileStatement(Expression condition, Statement body) { + return new WhileStatement(condition, body); + } + + static Statement forStatement(@Nullable Expression init, @Nullable Expression condition, + @Nullable Expression increment, Statement body) { + return new ForStatement(init, condition, increment, body); + } + + static Statement returnStatement(@Nullable Expression expression) { + return new ReturnStatement(expression); + } + + static Statement breakStatement() { + return new BreakStatement(); + } + + static Statement continueStatement() { + return new ContinueStatement(); + } + + static Statement emptyStatement() { + return new EmptyStatement(); + } + + static Statement declarationStatement(Declarator declarator, @Nullable Initializer initializer) { + return new DeclarationStatement(declarator, initializer); + } + + static Statement structDeclarationStatement(Declarator.StructDeclarator declarator) { + return new StructDeclarationStatement(declarator); + } + + static Statement functionDeclarationStatement(Declarator.FunctionDeclarator declarator, + CompoundStatement body) { + return new FunctionDeclarationStatement(declarator, body); + } + + static Statement define(String name, PrimaryExpression.Constant value) { + return new Define(name, value); + } + + static Statement include(String file) { + return new Include(file); + } + + static Statement typedef(Declarator declarator) { + return new Typedef(declarator); + } + + static Statement caseStatement(Expression expression, Statement body) { + return new CaseStatement(expression, body); + } + + static Statement defaultStatement(Statement body) { + return new DefaultStatement(body); + } + + static Statement switchStatement(Expression expression, Statement body) { + return new SwitchStatement(expression, body); + } + + static Statement variableDefinition(Declarator type, PrimaryExpression.Variable name) { + return new VariableDefinition(type, name); + } + } +} \ No newline at end of file diff --git a/bpf-processor/src/test/java/me/bechberger/cast/ASTTest.java b/bpf-processor/src/test/java/me/bechberger/cast/ASTTest.java new file mode 100644 index 0000000..84628b2 --- /dev/null +++ b/bpf-processor/src/test/java/me/bechberger/cast/ASTTest.java @@ -0,0 +1,50 @@ +package me.bechberger.cast; + +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; + +import java.util.List; +import java.util.stream.Stream; + +import static me.bechberger.cast.CAST.Declarator.struct; +import static me.bechberger.cast.CAST.Declarator.structMember; +import static me.bechberger.cast.CAST.Expression.constant; +import static me.bechberger.cast.CAST.Expression.variable; +import static me.bechberger.cast.CAST.OperatorExpression.binary; +import static me.bechberger.cast.CAST.PrimaryExpression.CAnnotation.sec; +import static me.bechberger.cast.CAST.Statement.*; +import static org.junit.jupiter.api.Assertions.assertEquals; + +/** + * Basic AST tests + */ +public class ASTTest { + + static Stream exprAstAndExpectedCode() { + return Stream.of(Arguments.of(expression(binary("+", constant(1), constant(2))), "1 + 2;")); + } + + @ParameterizedTest + @MethodSource("exprAstAndExpectedCode") + void testExpr(CAST.Statement statement, String expectedCode) { + assertEquals(expectedCode, statement.toPrettyString()); + } + + static Stream declAstAndExpectedCode() { + return Stream.of(Arguments.of(variableDefinition(struct(variable("myStruct"), + List.of(structMember(CAST.Declarator.identifier("int"), variable("b")))), variable("myVar", sec("a"))) + , "struct myStruct {\n int b;\n} myVar SEC(\"a\");"), + Arguments.of(variableDefinition(struct(variable("myStruct"), + List.of(structMember(CAST.Declarator.identifier("int"), variable("b"), constant(1)))), + variable( + "myVar", sec("a"))) + , "struct myStruct {\n int (b, 1);\n} myVar SEC(\"a\");")); + } + + @ParameterizedTest + @MethodSource("declAstAndExpectedCode") + void testDecl(Statement ast, String expectedCode) { + assertEquals(expectedCode, ast.toPrettyString()); + } +} \ No newline at end of file