Skip to content

Commit

Permalink
Implement tail calls
Browse files Browse the repository at this point in the history
  • Loading branch information
zombiezen committed Dec 21, 2024
1 parent b87f7ed commit eb7c4f5
Show file tree
Hide file tree
Showing 4 changed files with 112 additions and 12 deletions.
41 changes: 29 additions & 12 deletions internal/luacode/instruction.go
Original file line number Diff line number Diff line change
Expand Up @@ -569,21 +569,40 @@ const (
// A B k if (not R[B] == k) then pc++ else R[A] := R[B] (
OpTestSet OpCode = 67 // TESTSET

// A B C R[A], ... ,R[A+C-2] := R[A](R[A+1], ... ,R[A+B-1])
// OpCall calls a function.
//
// - R[A] is the function to call.
// When the function returns, R[A] is also where the first result will be stored.
// - B is the number of arguments to pass to the function plus one.
// If B is zero, this indicates to use all values after R[A] on the stack
// as arguments.
// - C is the number of expected results plus one.
// If C is zero, then all results from the function will be pushed onto the stack,
// starting at R[A].
OpCall OpCode = 68 // CALL
// A B C k return R[A](R[A+1], ... ,R[A+B-1])
// OpTailCall calls a function as the return of the function,
// replacing the stack entry of the calling function.
//
// - R[A] is the function to call.
// - B is the number of arguments to pass to the function plus one.
// If B is zero, this indicates to use all values after R[A] on the stack
// as arguments.
// - C > 0 means the calling function is vararg,
// so that any effects of [OpVarargPrep] must be corrected before returning;
// in this case, (C - 1) is its number of fixed parameters.
// - k should be true if there are upvalues that need to be closed.
// (The language does not permit tail calls in blocks with to-be-closed variables in scope.)
OpTailCall OpCode = 69 // TAILCALL

// OpReturn instructs control flow to return to the function's caller.
//
// A B C k return R[A], ... ,R[A+B-2]
//
// If (B == 0) then return up to 'top'.
// 'k' specifies that the function builds upvalues,
// which may need to be closed.
// C > 0 means the function is vararg,
// so that any effects of [OpVarargPrep] must be corrected before returning;
// in this case, (C - 1) is its number of fixed parameters.
// - R[A] is the first result to return.
// - B is the number of results to return.
// If B is zero, then return up to 'top'.
// - C > 0 means the function is vararg,
// so that any effects of [OpVarargPrep] must be corrected before returning;
// in this case, (C - 1) is its number of fixed parameters.
// - k should be true if there are upvalues and/or to-be-closed variables that need to be closed.
OpReturn OpCode = 70 // RETURN
// OpReturn0 instructs control flow to return to the function's caller
// with zero results.
Expand All @@ -598,8 +617,6 @@ const (
// that is variadic,
// has variables referenced by its functions,
// or has to-be-closed variables.
//
// A return R[A]
OpReturn1 OpCode = 72 // RETURN1

// A Bx update counters; if loop continues then pc-=Bx;
Expand Down
32 changes: 32 additions & 0 deletions internal/mylua/vm.go
Original file line number Diff line number Diff line change
Expand Up @@ -965,6 +965,38 @@ func (l *State) exec() (err error) {
return err
}
}
case luacode.OpTailCall:
if maxTBC, hasTBC := l.tbc.Max(); hasTBC && maxTBC >= uint(frame.registerStart()) {
return fmt.Errorf("%s: internal error: cannot make tail call when block has to-be-closed variables in scope",
sourceLocation(f.proto, frame.pc-1))
}

numArguments := int(i.ArgB()) - 1
numResults := l.frame().numResults
// TODO(soon): Validate ArgA.
functionIndex := frame.registerStart() + int(i.ArgA())

l.closeUpvalues(frame.registerStart())
if numArguments < 0 {
// Varargs: read from top.
numArguments = len(l.stack) - (functionIndex + 1)
} else {
l.setTop(functionIndex + 1 + numArguments)
}
fp := frame.framePointer()
copy(l.stack[fp:], l.stack[functionIndex:])
l.setTop(fp + 1 + numArguments)
l.callStack = l.callStack[:len(l.callStack)-1]
isLua, err := l.prepareCall(numArguments, numResults)
if err != nil {
return err
}
if isLua {
frame, f, err = l.loadLuaFrame()
if err != nil {
return err
}
}
case luacode.OpReturn:
// TODO(soon): Validate ArgA+numResults.
resultStackStart := frame.registerStart() + int(i.ArgA())
Expand Down
29 changes: 29 additions & 0 deletions internal/mylua/vm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -563,4 +563,33 @@ func TestVM(t *testing.T) {
t.Errorf("emit sequence = %v; want %v", got, want)
}
})

t.Run("TailCall", func(t *testing.T) {
state := new(State)
defer func() {
if err := state.Close(); err != nil {
t.Error("Close:", err)
}
}()

const source = `local function factorial(n, acc)` + "\n" +
`acc = acc or 1` + "\n" +
`if n == 0 then return acc end` + "\n" +
`return factorial(n - 1, acc * n)` + "\n" +
`end` + "\n" +
`return factorial(3)` + "\n"
if err := state.Load(strings.NewReader(source), luacode.Source(source), "t"); err != nil {
t.Fatal(err)
}
if err := state.Call(0, 1, 0); err != nil {
t.Fatal(err)
}
if !state.IsNumber(-1) {
t.Fatalf("top of stack is %v; want number", state.Type(-1))
}
const want = 6.0
if got, ok := state.ToNumber(-1); got != want || !ok {
t.Errorf("(return value, is number) = (%g, %t); want (%g, true)", got, ok, want)
}
})
}
22 changes: 22 additions & 0 deletions sets/bit.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,28 @@ func (s *Bit) Len() int {
return total
}

// Min returns the smallest value in the set.
func (s *Bit) Min() (_ uint, nonEmpty bool) {
if s == nil {
return 0, false
}
for x := range s.All() {
return x, true
}
return 0, false
}

// Max returns the largest value in the set.
func (s *Bit) Max() (_ uint, nonEmpty bool) {
if s == nil {
return 0, false
}
for x := range s.Reversed() {
return x, true
}
return 0, false
}

// All returns an iterator of the elements of s.
// Elements are in ascending order.
func (s *Bit) All() iter.Seq[uint] {
Expand Down

0 comments on commit eb7c4f5

Please sign in to comment.