Skip to content

Commit

Permalink
#16 [asm] saved registers to stack before invoke
Browse files Browse the repository at this point in the history
  • Loading branch information
vityaman committed Feb 12, 2024
1 parent d45363f commit 69e5ab1
Show file tree
Hide file tree
Showing 7 changed files with 90 additions and 37 deletions.
54 changes: 45 additions & 9 deletions sleepy/asmik/emit.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import override

import sleepy.tafka.representation as taf
from sleepy.tafka.walker import TafkaWalker
from sleepy.tafka import Context, TafkaWalker, Usages

from .argument import Immediate, Integer, Unassigned
from .argument import PhysicalRegister as PhysReg
Expand All @@ -20,6 +20,7 @@
Orb,
Remi,
Slti,
Stor,
Xorb,
mov,
movi,
Expand All @@ -42,15 +43,23 @@ def temporary(self) -> VirtReg:
return next(self.sequence)


class AsmikEmitListener(TafkaWalker.Listener):
class AsmikEmitListener(TafkaWalker.ContextedListener):
def __init__(self) -> None:
super().__init__()

self.memory = Memory()
self.registers = VirtualRegisters()

self.resolved: dict[str, int] = {}

self.procedure: taf.Procedure
self.usages: Usages

@override
def enter_procedure(self, procedure: taf.Procedure) -> None:
self.usages = Usages.analyzed(procedure)
self.procedure = procedure

addr = self.memory.data_put(IntegerData(self.next_instr_addr))
self.resolved[repr(procedure.const)] = addr
for i, param in enumerate(procedure.parameters):
Expand All @@ -63,6 +72,7 @@ def exit_procedure(self, procedure: taf.Procedure) -> None:

@override
def enter_block(self, block: taf.Block) -> None:
super().enter_block(block)
self.resolved[repr(block.label)] = self.next_instr_addr

@override
Expand All @@ -71,7 +81,7 @@ def exit_block(self, block: taf.Block) -> None:

@override
def enter_statement(self, statement: taf.Statement) -> None:
pass
super().enter_statement(statement)

@override
def exit_statement(self, statement: taf.Statement) -> None:
Expand Down Expand Up @@ -104,21 +114,38 @@ def on_invokation(
target: taf.Var,
source: taf.Invokation,
) -> None:
def push(register: Reg) -> None:
print(f"push {register}")
self.emit(Stor(Reg.sp(), register))
self.emit(Addim(Reg.sp(), Reg.sp(), Integer(8)))

def pop(register: Reg) -> None:
print(f"pop {register}")
self.emit(Addim(Reg.sp(), Reg.sp(), Integer(-8)))
self.emit(Load(register, Reg.sp()))

local_vars = list(self.procedure.locals)

for local in local_vars:
if self.is_alive(local):
push(self.registers.binded_to(local))
push(Reg.ra())

for i, arg in enumerate(source.args):
arg_reg = self.registers.binded_to(arg)
self.emit(mov(PhysReg.arg(i + 1), arg_reg))

prev_ra = self.registers.temporary()
self.emit(mov(prev_ra, Reg.ra()))

proc_reg = self.registers.binded_to(source.closure)
self.emit(Addim(Reg.ra(), Reg.ip(), Integer(4)))
self.emit(Brn(Reg.ze(), proc_reg))

res_reg = self.registers.binded_to(target)
self.emit(mov(res_reg, Reg.a1()))
pop(Reg.ra())
for local in local_vars[::-1]:
if self.is_alive(local):
pop(self.registers.binded_to(local))

self.emit(mov(Reg.ra(), prev_ra))
result = self.registers.binded_to(target)
self.emit(mov(result, Reg.a1()))

@override
def on_load(self, target: taf.Var, source: taf.Load) -> None:
Expand Down Expand Up @@ -219,6 +246,15 @@ def addr_of(self, cnst: taf.Const) -> Immediate:
case _:
raise NotImplementedError

def is_alive(self, var: taf.Var) -> bool:
nxt = self.usages.next_read(var, self.context)
prv = self.usages.next_write(var, Context(-1, self.procedure.entry))
return (
nxt is not None #
and prv is not None #
and prv < self.position
)

@property
def next_instr_addr(self) -> int:
return len(self.memory.instr) * 4
12 changes: 6 additions & 6 deletions sleepy/interpreter/asmik.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def __init__(self) -> None:
self.registers["ze"] = 0
self.registers["ip"] = 0
self.registers["ra"] = self.STOP
self.registers["sp"] = 10000

self.running = False

Expand Down Expand Up @@ -62,6 +63,7 @@ def run(self) -> None:
self.running = False

def execute(self, instr: Instruction) -> None:
print(f"exec {instr}")
match instr:
case Addi(dst, lhs, rhs):
self.write(dst, self.read(lhs) + self.read(rhs))
Expand Down Expand Up @@ -101,9 +103,7 @@ def read(self, reg: Register) -> int:
return self.registers[repr(reg)]

def write(self, reg: Register, value: int) -> None:
match reg:
case "ze":
message = "ze is readonly"
raise SleepyError(message)
case _:
self.registers[repr(reg)] = value
if reg == Register.ze():
message = "ze is readonly"
raise SleepyError(message)
self.registers[repr(reg)] = value
2 changes: 1 addition & 1 deletion sleepy/tafka/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,4 @@
)
from .unit import TafkaUnit
from .usage import Usages
from .walker import TafkaWalker
from .walker import Context, TafkaWalker
40 changes: 26 additions & 14 deletions sleepy/tafka/emit.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ def __init__(self, unit: ProgramUnit) -> None:

self.vars = MetaTable[taf.Var]()

self.current_block = self.main.entry
self.current_procedure = self.main
self.current_block = self.current_procedure.entry
self.last_result = taf.Var("0", taf.Int())

@override
Expand Down Expand Up @@ -122,31 +123,38 @@ def visit_application_variable(

@override
def visit_lambda(self, tree: program.Closure) -> None:
current_procedure = self.current_procedure
current_block = self.current_block

label = self.next_lbl()

params = [
self.next_var(taf.Kind.from_sleepy(param.kind))
for param in tree.parameters
]
procedure = taf.Procedure(
name=label.name,
entry=taf.Block(label, statements=[]),
parameters=[
self.next_var(taf.Kind.from_sleepy(param.kind))
for param in tree.parameters
],
value=taf.Unknown(),
)

for param, var in zip(tree.parameters, params, strict=True):
self.vars[param.name] = var
self.current_procedure = procedure
self.current_block = procedure.entry

body = taf.Block(label, statements=[])
for param, var in zip(
tree.parameters,
procedure.parameters,
strict=True,
):
self.vars[param.name] = var

self.current_block = body
for statement in tree.statements:
self.visit_expression(statement)

self.emit_statement(taf.Return(self.last_result))

value = self.last_result.kind

self.current_block = current_block
self.current_procedure = current_procedure

procedure = taf.Procedure(label.name, body, params, value)
self.procedures.append(procedure)

self.emit_intermidiate(
Expand All @@ -173,13 +181,17 @@ def visit_definition(self, tree: program.Definition) -> None:
def emit_statement(self, statement: taf.Statement) -> None:
if isinstance(statement, taf.Set):
self.last_result = statement.target
if isinstance(statement, taf.Return):
self.current_procedure.value = self.last_result.kind
self.current_block.statements.append(statement)

def emit_intermidiate(self, rvalue: taf.RValue) -> None:
self.emit_statement(taf.Set(self.next_var(rvalue.value), rvalue))

def next_var(self, kind: taf.Kind) -> taf.Var:
return taf.Var(next(self.var_names), kind)
var = taf.Var(next(self.var_names), kind)
self.current_procedure.locals.add(var)
return var

def next_lbl(self) -> taf.Label:
return taf.Label(next(self.lbl_names))
3 changes: 2 additions & 1 deletion sleepy/tafka/representation/block.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import cast

from .kind import Kind, Signature
Expand Down Expand Up @@ -60,6 +60,7 @@ class Procedure(Node):
entry: Block
parameters: list[Var]
value: Kind
locals: set[Var] = field(default_factory=lambda: set(), init=False)

@property
def signature(self) -> Signature:
Expand Down
10 changes: 7 additions & 3 deletions sleepy/tafka/representation/kind.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,21 +18,21 @@ def from_sleepy(cls, kind: SleepyKind) -> "Kind":
raise NotImplementedError


@dataclass(repr=False)
@dataclass(repr=False, unsafe_hash=True)
class Unknown(Kind):
@override
def __repr__(self) -> str:
return "?"


@dataclass(repr=False)
@dataclass(repr=False, unsafe_hash=True)
class Int(Kind):
@override
def __repr__(self) -> str:
return "int"


@dataclass(repr=False)
@dataclass(repr=False, unsafe_hash=True)
class Bool(Kind):
@override
def __repr__(self) -> str:
Expand All @@ -44,6 +44,10 @@ class Signature(Kind):
params: list[Kind]
value: Kind

@override
def __hash__(self) -> int:
return hash(str(self.params)) + hash(self.value)

@override
def __repr__(self) -> str:
return f"({', '.join(repr(_) for _ in self.params)}) -> {self.value!r}"
6 changes: 3 additions & 3 deletions sleepy/tafka/representation/symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,19 @@
from .node import Node


@dataclass(repr=False)
@dataclass(repr=False, unsafe_hash=True)
class Symbol(Node):
name: str
kind: Kind


@dataclass(repr=False)
@dataclass(repr=False, unsafe_hash=True)
class Const(Symbol):
def __repr__(self) -> str:
return f"${self.name}: {self.kind}"


@dataclass(repr=False)
@dataclass(repr=False, unsafe_hash=True)
class Var(Symbol):
def __repr__(self) -> str:
return f"%{self.name}: {self.kind}"

0 comments on commit 69e5ab1

Please sign in to comment.