Skip to content

Use more precise Wasm types #1907

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 10 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion compiler/lib-wasm/closure_conversion.ml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ open Code
type closure =
{ functions : (Var.t * int) list
; free_variables : Var.t list
; mutable id : int option
}

module SCC = Strongly_connected_components.Make (Var)
Expand Down Expand Up @@ -144,7 +145,8 @@ let rec traverse var_depth closures program pc depth =
in
List.iter
~f:(fun (f, _) ->
closures := Var.Map.add f { functions; free_variables } !closures)
closures :=
Var.Map.add f { functions; free_variables; id = None } !closures)
functions;
fun_lst)
components
Expand Down
1 change: 1 addition & 0 deletions compiler/lib-wasm/closure_conversion.mli
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
type closure =
{ functions : (Code.Var.t * int) list
; free_variables : Code.Var.t list
; mutable id : int option
}

val f : Code.program -> Code.program * closure Code.Var.Map.t
242 changes: 224 additions & 18 deletions compiler/lib-wasm/code_generation.ml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ https://github.com/llvm/llvm-project/issues/58438
type constant_global =
{ init : W.expression option
; constant : bool
; typ : W.value_type
}

type context =
Expand All @@ -46,6 +47,7 @@ type context =
; types : (Var.t, Wasm_ast.type_field) Hashtbl.t
; mutable closure_envs : Var.t Var.Map.t
(** GC: mapping of recursive functions to their shared environment *)
; closure_types : (W.value_type option list, int) Hashtbl.t
; mutable apply_funs : Var.t IntMap.t
; mutable cps_apply_funs : Var.t IntMap.t
; mutable curry_funs : Var.t IntMap.t
Expand All @@ -71,6 +73,7 @@ let make_context ~value_type =
; type_names = Hashtbl.create 128
; types = Hashtbl.create 128
; closure_envs = Var.Map.empty
; closure_types = Hashtbl.create 128
; apply_funs = IntMap.empty
; cps_apply_funs = IntMap.empty
; curry_funs = IntMap.empty
Expand Down Expand Up @@ -196,6 +199,68 @@ let heap_type_sub (ty : W.heap_type) (ty' : W.heap_type) st =
(* I31, struct, array and none have no other subtype *)
| _, (I31 | Type _ | Struct | Array | None_) -> false, st

(*ZZZ*)
let rec type_index_lub ty ty' st =
if Var.equal ty ty'
then Some ty
else
let type_field = Hashtbl.find st.context.types ty in
match type_field.supertype with
| None -> None
| Some ty -> (
match type_index_lub ty ty' st with
| Some ty -> Some ty
| None -> (
let type_field = Hashtbl.find st.context.types ty' in
match type_field.supertype with
| None -> None
| Some ty' -> type_index_lub ty ty' st))

let heap_type_lub (ty : W.heap_type) (ty' : W.heap_type) =
match ty, ty' with
| (Func | Extern), _ | _, (Func | Extern) -> assert false
| None_, _ -> return ty'
| _, None_ | Struct, Struct | Array, Array -> return ty
| Any, _ | _, Any -> return W.Any
| Eq, _
| _, Eq
| (Struct | Array | Type _), I31
| I31, (Struct | Array | Type _)
| Struct, Array
| Array, Struct -> return (Eq : W.heap_type)
| Struct, Type t | Type t, Struct -> (
fun st ->
let type_field = Hashtbl.find st.context.types t in
match type_field.typ with
| Struct _ -> W.Struct, st
| Array _ | Func _ -> W.Eq, st)
| Array, Type t | Type t, Array -> (
fun st ->
let type_field = Hashtbl.find st.context.types t in
match type_field.typ with
| Array _ -> W.Struct, st
| Struct _ | Func _ -> W.Eq, st)
| Type t, Type t' -> (
let* r = fun st -> type_index_lub t t' st, st in
match r with
| Some t'' -> return (Type t'' : W.heap_type)
| None -> (
fun st ->
let type_field = Hashtbl.find st.context.types t in
let type_field' = Hashtbl.find st.context.types t' in
match type_field.typ, type_field'.typ with
| Struct _, Struct _ -> (Struct : W.heap_type), st
| Array _, Array _ -> W.Array, st
| (Array _ | Struct _ | Func _), (Array _ | Struct _ | Func _) -> W.Eq, st))
| I31, I31 -> return W.I31

let value_type_lub (ty : W.value_type) (ty' : W.value_type) =
match ty, ty' with
| Ref { nullable; typ }, Ref { nullable = nullable'; typ = typ' } ->
let* typ = heap_type_lub typ typ' in
return (W.Ref { nullable = nullable || nullable'; typ })
| _ -> assert false

let register_global name ?exported_name ?(constant = false) typ init st =
st.context.other_fields <-
W.Global { name; exported_name; typ; init } :: st.context.other_fields;
Expand All @@ -204,6 +269,7 @@ let register_global name ?exported_name ?(constant = false) typ init st =
name
{ init = (if not typ.mut then Some init else None)
; constant = (not typ.mut) || constant
; typ = typ.typ
}
st.context.constant_globals;
(), st
Expand Down Expand Up @@ -498,6 +564,69 @@ let load x =
| Local (_, x, _) -> return (W.LocalGet x)
| Expr e -> e

let value_type st = st.context.value_type, st

let rec variable_type x st =
match Var.Map.find_opt x st.vars with
| Some (Local (_, _, typ)) -> typ, st
| Some (Expr e) ->
(let* e = e in
expression_type e)
st
| None -> None, st

and expression_type (e : W.expression) st =
match e with
| Const _
| UnOp _
| BinOp _
| I32WrapI64 _
| I64ExtendI32 _
| F32DemoteF64 _
| F64PromoteF32 _
| BlockExpr _
| Call _
| RefFunc _
| Call_ref _
| I31Get _
| ArrayGet _
| ArrayLen _
| RefTest _
| RefEq _
| RefNull _
| Try _ -> None, st
| LocalGet x | LocalTee (x, _) -> variable_type x st
| GlobalGet x ->
( (try
let typ = (Var.Map.find x st.context.constant_globals).typ in
if Poly.equal typ st.context.value_type
then None
else
Some
(match typ with
| Ref { typ; nullable = true } -> Ref { typ; nullable = false }
| _ -> typ)
with Not_found -> None)
, st )
| Seq (_, e') -> expression_type e' st
| Pop typ -> Some typ, st
| RefI31 _ -> Some (Ref { nullable = false; typ = I31 }), st
| ArrayNew (ty, _, _)
| ArrayNewFixed (ty, _)
| ArrayNewData (ty, _, _, _)
| StructNew (ty, _) -> Some (Ref { nullable = false; typ = Type ty }), st
| StructGet (_, ty, i, _) -> (
match (Hashtbl.find st.context.types ty).typ with
| Struct l -> (
match (List.nth l i).typ with
| Value typ ->
(if Poly.equal typ st.context.value_type then None else Some typ), st
| Packed _ -> assert false)
| Array _ | Func _ -> assert false)
| RefCast (typ, _) | Br_on_cast (_, _, typ, _) | Br_on_cast_fail (_, typ, _, _) ->
Some (Ref typ), st
| IfExpr (_, _, _, _) | ExternConvertAny _ -> None, st

let tee ?typ x e =
let* e = e in
let* b = is_small_constant e in
Expand All @@ -506,12 +635,53 @@ let tee ?typ x e =
let* () = register_constant x e in
return e
else
let* typ =
match typ with
| Some _ -> return typ
| None -> expression_type e
in
let* i = add_var ?typ x in
return (W.LocalTee (i, e))

let should_make_global x st = Var.Set.mem x st.context.globalized_variables, st

let value_type st = st.context.value_type, st
let get_constant x st = Hashtbl.find_opt st.context.constants x, st

let placeholder_value typ f =
let* c = get_constant typ in
match c with
| None ->
let x = Var.fresh () in
let* () = register_constant typ (W.GlobalGet x) in
let* () =
register_global
~constant:true
x
{ mut = false; typ = Ref { nullable = false; typ = Type typ } }
(f typ)
in
return (W.GlobalGet x)
| Some c -> return c

let array_placeholder typ = placeholder_value typ (fun typ -> ArrayNewFixed (typ, []))

let default_value val_typ st =
match val_typ with
| W.Ref { typ = I31 | Eq | Any; _ } -> (W.RefI31 (Const (I32 0l)), val_typ, None), st
| W.Ref { typ = Type typ; nullable = false } -> (
match (Hashtbl.find st.context.types typ).typ with
| Array _ ->
(let* placeholder = array_placeholder typ in
return (placeholder, val_typ, None))
st
| Struct _ | Func _ ->
( ( W.RefNull (Type typ)
, W.Ref { typ = Type typ; nullable = true }
, Some { W.typ = Type typ; nullable = false } )
, st ))
| W.Ref { nullable = true; _ }
| W.Ref { typ = Func | Extern | Struct | Array | None_; _ }
| I32 | I64 | F32 | F64 -> assert false

let rec store ?(always = false) ?typ x e =
let* e = e in
Expand All @@ -527,25 +697,40 @@ let rec store ?(always = false) ?typ x e =
let* b = should_make_global x in
if b
then
let* typ =
match typ with
| Some typ -> return typ
| None -> value_type
in
let* () =
let* b = global_is_registered x in
if b
then return ()
else
register_global
~constant:true
x
{ mut = true; typ }
(W.RefI31 (Const (I32 0l)))
let* typ =
match typ with
| Some typ -> return typ
| None -> (
if always
then value_type
else
let* typ = expression_type e in
match typ with
| None -> value_type
| Some typ -> return typ)
in
let* default, typ', cast = default_value typ in
let* () =
register_constant
x
(match cast with
| Some typ -> W.RefCast (typ, W.GlobalGet x)
| None -> W.GlobalGet x)
in
register_global ~constant:true x { mut = true; typ = typ' } default
in
let* () = register_constant x (W.GlobalGet x) in
instr (GlobalSet (x, e))
else
let* typ =
match typ with
| Some _ -> return typ
| None -> if always then return None else expression_type e
in
let* i = add_var ?typ x in
instr (LocalSet (i, e))

Expand Down Expand Up @@ -578,13 +763,28 @@ let push e =
instr (Push e')
| _ -> instr (Push e)

let blk' ty l st =
let instrs = st.instrs in
let (), st = l { st with instrs = [] } in
let ty, st =
match st.instrs with
| Push e :: _ ->
(let* ty' = expression_type e in
match ty' with
| None -> return ty
| Some ty' -> return { ty with W.result = [ ty' ] })
st
| _ -> ty, st
in
(List.rev st.instrs, ty), { st with instrs }

let loop ty l =
let* instrs = blk l in
instr (Loop (ty, instrs))
let* instrs, ty' = blk' ty l in
instr (Loop (ty', instrs))

let block ty l =
let* instrs = blk l in
instr (Block (ty, instrs))
let* instrs, ty' = blk' ty l in
instr (Block (ty', instrs))

let block_expr ty l =
let* instrs = blk l in
Expand Down Expand Up @@ -657,7 +857,7 @@ let init_code context = instrs context.init_code

let function_body ~context ~param_names ~body =
let st = { var_count = 0; vars = Var.Map.empty; instrs = []; context } in
let (), st = body st in
let res, st = body st in
let local_count, body = st.var_count, List.rev st.instrs in
let local_types = Array.make local_count (Var.fresh (), None) in
List.iteri ~f:(fun i x -> local_types.(i) <- x, None) param_names;
Expand All @@ -675,4 +875,10 @@ let function_body ~context ~param_names ~body =
|> (fun a -> Array.sub a ~pos:param_count ~len:(Array.length a - param_count))
|> Array.to_list
in
locals, body
locals, res, body

let eval ~context e =
let st = { var_count = 0; vars = Var.Map.empty; instrs = []; context } in
let r, st = e st in
assert (st.var_count = 0 && List.is_empty st.instrs);
r
Loading
Loading