From a2566d5d593d50559748b1f4306037a06de099b1 Mon Sep 17 00:00:00 2001 From: "Mr.UNIX" Date: Mon, 29 Jan 2024 21:19:57 +0100 Subject: [PATCH] fix: recursive function calls fixed --- .../builtin_functions/DefineFunction.h | 3 +- tests/bytecode/compiler/functions_test.cpp | 69 +++++++++++++++++++ 2 files changed, 71 insertions(+), 1 deletion(-) diff --git a/src/bytecode/builtin_functions/DefineFunction.h b/src/bytecode/builtin_functions/DefineFunction.h index 9fca121..3dfcaa6 100644 --- a/src/bytecode/builtin_functions/DefineFunction.h +++ b/src/bytecode/builtin_functions/DefineFunction.h @@ -16,10 +16,11 @@ namespace Bytecode::BuiltinFunctions { result.push_back(new Store(reg)); } else { auto segment = new Segment({}); + compiler.program.declare_function(*args[0]->token->asString(), segment); for (auto argument: args[0]->children) compiler.program.declare_variable(*argument->token->asString()); compiler.compile(*args[1], segment->instructions); - compiler.program.declare_function(*args[0]->token->asString(), segment); + } } }; diff --git a/tests/bytecode/compiler/functions_test.cpp b/tests/bytecode/compiler/functions_test.cpp index ef51d12..a3e7811 100644 --- a/tests/bytecode/compiler/functions_test.cpp +++ b/tests/bytecode/compiler/functions_test.cpp @@ -1,9 +1,13 @@ #include "bytecode/compiler/Segment.h" #include "bytecode/instructions/Add.h" #include "bytecode/instructions/Call.h" +#include "bytecode/instructions/CondJumpIfNot.h" +#include "bytecode/instructions/Equals.h" +#include "bytecode/instructions/Jump.h" #include "bytecode/instructions/Load.h" #include "bytecode/instructions/LoadLiteral.h" #include "bytecode/instructions/Multiply.h" +#include "bytecode/instructions/Subtract.h" #include "parser/SyntaxTreeNode.h" #include @@ -128,3 +132,68 @@ TEST(compiler_functions, SimpleFunctionCall) { EXPECT_EQ(expected_result == compiler.program, true); } + +TEST(compiler_functions, RecursiveFunction) { + // (define (sum n) (if (= n 1) 1 (+ n (sum (- n 1))))) + const auto expression = SyntaxTreeNode( + new Token(Token::Symbol, "define"), + { + new SyntaxTreeNode( + new Token(Token::Symbol, "sum"), + { + new SyntaxTreeNode(new Token(Token::Symbol, "n")), + }), + new SyntaxTreeNode( + new Token(Token::Symbol, "if"), + { + new SyntaxTreeNode( + new Token(Token::Symbol, "="), + { + new SyntaxTreeNode(new Token(Token::Symbol, "n")), + new SyntaxTreeNode(new Token(1)), + }), + new SyntaxTreeNode(new Token(1)), + new SyntaxTreeNode( + new Token(Token::Symbol, "+"), + { + new SyntaxTreeNode(new Token(Token::Symbol, "n")), + new SyntaxTreeNode( + new Token(Token::Symbol, "sum"), + { + new SyntaxTreeNode( + new Token(Token::Symbol, "-"), + { + new SyntaxTreeNode(new Token(Token::Symbol, "n")), + new SyntaxTreeNode(new Token(1)), + }), + }), + }), + }), + }); + + auto expected_result = Program( + {{"sum", 1}}, + {{"n", 0}}, + { + new Segment({}), + new Segment({ + new Load(0), + new LoadLiteral(1), + new Equals(), + new CondJumpIfNot(6), + new LoadLiteral(1), + new Jump(12), + new Load(0), + new Load(0), + new LoadLiteral(1), + new Subtract(), + new Call(1), + new Add(), + }), + }); + + Compiler compiler = Compiler(); + compiler.compile(expression); + + EXPECT_EQ(expected_result == compiler.program, true); +}