Skip to content

Commit

Permalink
#19 [asmik] extracted virtual registers allocator
Browse files Browse the repository at this point in the history
  • Loading branch information
vityaman committed Feb 12, 2024
1 parent a9e5790 commit 47f9270
Showing 1 changed file with 37 additions and 36 deletions.
73 changes: 37 additions & 36 deletions sleepy/asmik/emit.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,24 +27,35 @@
from .memory import Memory


class VirtualRegisters:
def __init__(self) -> None:
self.sequence = (VirtReg(n) for n in range(10000))
self.binded: dict[str, Reg] = {}

def binded_to(self, var: tafka.Var) -> Reg:
var_repr = repr(var)
if var_repr not in self.binded:
self.binded[var_repr] = self.temporary()
return self.binded[var_repr]

def temporary(self) -> VirtReg:
return next(self.sequence)


class AsmikEmitListener(TafkaWalker.Listener):
def __init__(self) -> None:
self.virt_regs = (VirtReg(n) for n in range(10000))
self.memory = Memory()
self.regs: dict[str, Reg] = {}
self.registers = VirtualRegisters()

self.resolved: dict[str, int] = {}

self.block: tafka.Block
self.block_until: list[tafka.Block] = []

@override
def enter_procedure(self, procedure: tafka.Procedure) -> None:
addr = self.memory.data_put(IntegerData(self.next_instr_addr))
self.resolved[repr(procedure.const)] = addr
for i, param in enumerate(procedure.parameters):
param_reg = self.reg_var(param)
self.emit_i(mov(param_reg, PhysReg.arg(i + 1)))
register = self.registers.binded_to(param)
self.emit_i(mov(register, PhysReg.arg(i + 1)))

@override
def exit_procedure(self, procedure: tafka.Procedure) -> None:
Expand All @@ -53,7 +64,6 @@ def exit_procedure(self, procedure: tafka.Procedure) -> None:
@override
def enter_block(self, block: tafka.Block) -> None:
self.resolved[repr(block.label)] = self.next_instr_addr
self.block = block

@override
def exit_block(self, block: tafka.Block) -> None:
Expand All @@ -69,22 +79,22 @@ def exit_statement(self, statement: tafka.Statement) -> None:

@override
def on_return(self, ret: tafka.Return) -> None:
retr = self.reg_var(ret.value)
retr = self.registers.binded_to(ret.value)
self.emit_i(mov(Reg.a1(), retr))
self.emit_i(Brn(Reg.ze(), Reg.ra()))

@override
def on_goto(self, goto: tafka.Goto) -> None:
block_label = repr(goto.block.label)
label = self.reg_tmp()
label = self.registers.temporary()
self.emit_i(movi(label, Unassigned(block_label)))
self.emit_i(Brn(Reg.ze(), label))

@override
def on_conditional(self, conditional: tafka.Conditional) -> None:
else_label = repr(conditional.else_branch.label)
condition = self.reg_var(conditional.condition)
else_address = self.reg_tmp()
condition = self.registers.binded_to(conditional.condition)
else_address = self.registers.temporary()
self.emit_i(movi(else_address, Unassigned(else_label)))
self.emit_i(Brn(condition, else_address))

Expand All @@ -95,32 +105,32 @@ def on_invokation(
source: tafka.Invokation,
) -> None:
for i, arg in enumerate(source.args):
arg_reg = self.reg_var(arg)
arg_reg = self.registers.binded_to(arg)
self.emit_i(mov(PhysReg.arg(i + 1), arg_reg))

prev_ra = self.reg_tmp()
prev_ra = self.registers.temporary()
self.emit_i(mov(prev_ra, Reg.ra()))

proc_reg = self.reg_var(source.closure)
proc_reg = self.registers.binded_to(source.closure)
self.emit_i(Addim(Reg.ra(), Reg.ip(), Integer(4)))
self.emit_i(Brn(Reg.ze(), proc_reg))

res_reg = self.reg_var(target)
res_reg = self.registers.binded_to(target)
self.emit_i(mov(res_reg, Reg.a1()))

self.emit_i(mov(Reg.ra(), prev_ra))

@override
def on_load(self, target: tafka.Var, source: tafka.Load) -> None:
dst = self.reg_var(target)
dst = self.registers.binded_to(target)
addr = self.addr_of(source.constant)
self.emit_i(Addim(dst, Reg.ze(), addr))
self.emit_i(Load(dst, dst))

@override
def on_copy(self, target: tafka.Var, source: tafka.Copy) -> None:
dst = self.reg_var(target)
src = self.reg_var(source.argument)
dst = self.registers.binded_to(target)
src = self.registers.binded_to(source.argument)
self.emit_i(mov(dst, src))

@override
Expand All @@ -141,13 +151,13 @@ def on_rem(self, target: tafka.Var, source: tafka.Rem) -> None:

@override
def on_eq(self, target: tafka.Var, source: tafka.Eq) -> None:
dstr = self.reg_var(target)
lhsr = self.reg_var(source.left)
rhsr = self.reg_var(source.right)
dstr = self.registers.binded_to(target)
lhsr = self.registers.binded_to(source.left)
rhsr = self.registers.binded_to(source.right)

l2r = orb = dstr
r2l = self.reg_tmp()
neg = self.reg_tmp()
r2l = self.registers.temporary()
neg = self.registers.temporary()

self.emit_i(Slti(l2r, lhsr, rhsr))
self.emit_i(Slti(r2l, rhsr, lhsr))
Expand All @@ -172,9 +182,9 @@ def on_trivial_binary_operation(
target: tafka.Var,
source: tafka.BinaryOperator,
) -> None:
dstr = self.reg_var(target)
lhsr = self.reg_var(source.left)
rhsr = self.reg_var(source.right)
dstr = self.registers.binded_to(target)
lhsr = self.registers.binded_to(source.left)
rhsr = self.registers.binded_to(source.right)

instruction: Instruction
match source:
Expand All @@ -198,15 +208,6 @@ def on_trivial_binary_operation(
def emit_i(self, instr: Instruction) -> None:
self.memory.instr.append(instr)

def reg_var(self, var: tafka.Var) -> Reg:
var_repr = repr(var)
if var_repr not in self.regs:
self.regs[var_repr] = self.reg_tmp()
return self.regs[var_repr]

def reg_tmp(self) -> VirtReg:
return next(self.virt_regs)

def addr_of(self, cnst: tafka.Const) -> Immediate:
match cnst.kind:
case tafka.Int():
Expand Down

0 comments on commit 47f9270

Please sign in to comment.