Skip to content

Compiler: allow to inline primitives defined in the js runtime into the generated code #1928

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 5 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions compiler/bin-js_of_ocaml/check_runtime.ml
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ let f (runtime_files, bytecode, target_env) =
Config.set_target `JavaScript;
Config.set_effects_backend `Disabled;
Linker.reset ();
Generate.reset ();
let runtime_files, builtin =
List.partition_map runtime_files ~f:(fun name ->
match Builtins.find name with
Expand Down
1 change: 1 addition & 0 deletions compiler/bin-js_of_ocaml/compile.ml
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ let run
Jsoo_cmdline.Arg.eval common;
Config.set_effects_backend effects;
Linker.reset ();
Generate.reset ();
(match output_file with
| `Stdout, _ -> ()
| `Name name, _ when debug_mem () -> Debug.start_profiling name
Expand Down
1 change: 1 addition & 0 deletions compiler/bin-js_of_ocaml/link.ml
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ let f
Config.set_target `JavaScript;
Jsoo_cmdline.Arg.eval common;
Linker.reset ();
Generate.reset ();
let with_output f =
match output_file with
| None -> f stdout
Expand Down
1 change: 1 addition & 0 deletions compiler/lib-dynlink/js_of_ocaml_compiler_dynlink.ml
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ let () =
Config.set_effects_backend (Jsoo_runtime.Sys.Config.effects ());
Linker.reset ();
List.iter aliases ~f:(fun (a, b) -> Primitive.alias a b);
Generate.reset ();
(* this needs to stay synchronized with toplevel.js *)
let toplevel_compile (s : string) (debug : Instruct.debug_event list array) :
unit -> J.t =
Expand Down
1 change: 1 addition & 0 deletions compiler/lib-runtime-files/gen/gen.ml
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ let () =
| `Effects b -> Js_of_ocaml_compiler.Config.set_effects_backend b);
List.iter Js_of_ocaml_compiler.Target_env.all ~f:(fun target_env ->
Js_of_ocaml_compiler.Linker.reset ();
Js_of_ocaml_compiler.Generate.reset ();
List.iter fragments ~f:(fun (filename, frags) ->
Js_of_ocaml_compiler.Linker.load_fragments ~target_env ~filename frags);
let linkinfos = Js_of_ocaml_compiler.Linker.init () in
Expand Down
1 change: 1 addition & 0 deletions compiler/lib/annot_lexer.mll
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ rule main = parse
| "Requires" {TRequires}
| "Version" {TVersion}
| "Weakdef" {TWeakdef}
| "Inline" {TInline}
| "Always" {TAlways}
| "If" {TIf}
| "Alias" {TAlias}
Expand Down
3 changes: 2 additions & 1 deletion compiler/lib/annot_parser.mly
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
* Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
*)

%token TProvides TRequires TVersion TWeakdef TIf TAlways TAlias
%token TProvides TRequires TVersion TWeakdef TInline TIf TAlways TAlias
%token TA_Pure TA_Const TA_Mutable TA_Mutator TA_Shallow TA_Object_literal
%token<string> TIdent TIdent_percent TVNum
%token TComma TColon EOF EOL LE LT GE GT EQ LPARENT RPARENT
Expand All @@ -40,6 +40,7 @@ annot:
| TVersion TColon l=separated_nonempty_list(TComma,version) endline
{ `Version (l) }
| TWeakdef endline { `Weakdef }
| TInline endline { `Inline }
| TAlways endline { `Always }
| TDeprecated endline { `Deprecated $1 }
| TAlias TColon name=TIdent endline { `Alias (name) }
Expand Down
191 changes: 56 additions & 135 deletions compiler/lib/generate.ml
Original file line number Diff line number Diff line change
Expand Up @@ -365,13 +365,6 @@ let one = J.ENum (J.Num.of_targetint Targetint.one)

let zero = J.ENum (J.Num.of_targetint Targetint.zero)

let plus_int x y =
match x, y with
| J.ENum y, x when J.Num.is_zero y -> x
| x, J.ENum y when J.Num.is_zero y -> x
| J.ENum x, J.ENum y -> J.ENum (J.Num.add x y)
| x, y -> J.EBin (J.Plus, x, y)

let bool e = J.ECond (e, one, zero)

(****)
Expand Down Expand Up @@ -1082,16 +1075,6 @@ let register_un_prims names ?(need_loc = false) k f =

let register_un_prim name k f = register_un_prims [ name ] k f

let register_un_prim_ctx name k f =
register_prims [ name ] k (fun name l ctx loc ->
match l with
| [ x ] ->
let open Expr_builder in
let* cx = access' ~ctx x in
let* () = info (kind k) in
return (f ctx cx loc)
| _ -> invalid_arity name l ~loc ~expected:1)

let register_bin_prims names k f =
register_prims names k (fun name l ctx loc ->
match l with
Expand Down Expand Up @@ -1119,28 +1102,7 @@ let register_tern_prims names k f =

let register_tern_prim name k f = register_tern_prims [ name ] k f

let register_un_math_prim name prim =
let prim = Utf8_string.of_string_exn prim in
register_un_prim name `Pure (fun cx loc ->
J.call (J.dot (s_var "Math") prim) [ cx ] loc)

let register_bin_math_prim name prim =
let prim = Utf8_string.of_string_exn prim in
register_bin_prims [ name ] `Pure (fun cx cy loc ->
J.call (J.dot (s_var "Math") prim) [ cx; cy ] loc)

let _ =
register_un_prim_ctx "%caml_format_int_special" `Pure (fun ctx cx loc ->
let s = J.EBin (J.Plus, str_js_utf8 "", cx) in
ocaml_string ~ctx ~loc s);
register_un_prim "%direct_obj_tag" `Pure (fun cx _loc -> Mlvalue.Block.tag cx);
register_bin_prims
[ "caml_array_unsafe_get"
; "caml_array_unsafe_get_float"
; "caml_floatarray_unsafe_get"
]
`Mutable
(fun cx cy _ -> Mlvalue.Array.field cx cy);
register_un_prims
[ "caml_int32_of_int"
; "caml_int32_to_int"
Expand All @@ -1154,83 +1116,6 @@ let _ =
]
`Pure
(fun cx _ -> cx);
register_bin_prims
[ "%int_add"; "caml_int32_add"; "caml_nativeint_add" ]
`Pure
(fun cx cy _ ->
match cx, cy with
| J.EBin (J.Minus, cz, J.ENum n), J.ENum m ->
to_int (J.EBin (J.Plus, cz, J.ENum (J.Num.add m (J.Num.neg n))))
| _ -> to_int (plus_int cx cy));
register_bin_prims
[ "%int_sub"; "caml_int32_sub"; "caml_nativeint_sub" ]
`Pure
(fun cx cy _ ->
match cx, cy with
| J.EBin (J.Minus, cz, J.ENum n), J.ENum m ->
to_int (J.EBin (J.Minus, cz, J.ENum (J.Num.add n m)))
| _ -> to_int (J.EBin (J.Minus, cx, cy)));
register_bin_prim "%direct_int_mul" `Pure (fun cx cy _ ->
to_int (J.EBin (J.Mul, cx, cy)));
register_bin_prim "%direct_int_div" `Pure (fun cx cy _ ->
to_int (J.EBin (J.Div, cx, cy)));
register_bin_prim "%direct_int_mod" `Pure (fun cx cy _ ->
to_int (J.EBin (J.Mod, cx, cy)));
register_bin_prims
[ "%int_and"; "caml_int32_and"; "caml_nativeint_and" ]
`Pure
(fun cx cy _ -> J.EBin (J.Band, cx, cy));
register_bin_prims
[ "%int_or"; "caml_int32_or"; "caml_nativeint_or" ]
`Pure
(fun cx cy _ -> J.EBin (J.Bor, cx, cy));
register_bin_prims
[ "%int_xor"; "caml_int32_xor"; "caml_nativeint_xor" ]
`Pure
(fun cx cy _ -> J.EBin (J.Bxor, cx, cy));
register_bin_prims
[ "%int_lsl"; "caml_int32_shift_left"; "caml_nativeint_shift_left" ]
`Pure
(fun cx cy _ -> J.EBin (J.Lsl, cx, cy));
register_bin_prims
[ "%int_lsr"
; "caml_int32_shift_right_unsigned"
; "caml_nativeint_shift_right_unsigned"
]
`Pure
(fun cx cy _ -> to_int (J.EBin (J.Lsr, cx, cy)));
register_bin_prims
[ "%int_asr"; "caml_int32_shift_right"; "caml_nativeint_shift_right" ]
`Pure
(fun cx cy _ -> J.EBin (J.Asr, cx, cy));
register_un_prims
[ "%int_neg"; "caml_int32_neg"; "caml_nativeint_neg" ]
`Pure
(fun cx _ -> to_int (J.EUn (J.Neg, cx)));
register_bin_prim "caml_eq_float" `Pure (fun cx cy _ ->
bool (J.EBin (J.EqEqEq, cx, cy)));
register_bin_prim "caml_neq_float" `Pure (fun cx cy _ ->
bool (J.EBin (J.NotEqEq, cx, cy)));
register_bin_prim "caml_ge_float" `Pure (fun cx cy _ -> bool (J.EBin (J.Le, cy, cx)));
register_bin_prim "caml_le_float" `Pure (fun cx cy _ -> bool (J.EBin (J.Le, cx, cy)));
register_bin_prim "caml_gt_float" `Pure (fun cx cy _ -> bool (J.EBin (J.Lt, cy, cx)));
register_bin_prim "caml_lt_float" `Pure (fun cx cy _ -> bool (J.EBin (J.Lt, cx, cy)));
register_bin_prim "caml_add_float" `Pure (fun cx cy _ -> J.EBin (J.Plus, cx, cy));
register_bin_prim "caml_sub_float" `Pure (fun cx cy _ -> J.EBin (J.Minus, cx, cy));
register_bin_prim "caml_mul_float" `Pure (fun cx cy _ -> J.EBin (J.Mul, cx, cy));
register_bin_prim "caml_div_float" `Pure (fun cx cy _ -> J.EBin (J.Div, cx, cy));
register_un_prim "caml_neg_float" `Pure (fun cx _ -> J.EUn (J.Neg, cx));
register_bin_prim "caml_fmod_float" `Pure (fun cx cy _ -> J.EBin (J.Mod, cx, cy));
register_tern_prims
[ "caml_array_unsafe_set"
; "caml_array_unsafe_set_float"
; "caml_floatarray_unsafe_set"
; "caml_array_unsafe_set_addr"
]
`Mutator
(fun cx cy cz _ -> J.EBin (J.Eq, Mlvalue.Array.field cx cy, cz));
register_un_prims [ "caml_alloc_dummy"; "caml_alloc_dummy_float" ] `Pure (fun _ _ ->
J.array []);
register_un_prims
[ "caml_int_of_float"
; "caml_int32_of_float"
Expand All @@ -1240,20 +1125,6 @@ let _ =
]
`Pure
(fun cx _loc -> to_int cx);
register_un_math_prim "caml_abs_float" "abs";
register_un_math_prim "caml_acos_float" "acos";
register_un_math_prim "caml_asin_float" "asin";
register_un_math_prim "caml_atan_float" "atan";
register_bin_math_prim "caml_atan2_float" "atan2";
register_un_math_prim "caml_ceil_float" "ceil";
register_un_math_prim "caml_cos_float" "cos";
register_un_math_prim "caml_exp_float" "exp";
register_un_math_prim "caml_floor_float" "floor";
register_un_math_prim "caml_log_float" "log";
register_bin_math_prim "caml_power_float" "pow";
register_un_math_prim "caml_sin_float" "sin";
register_un_math_prim "caml_sqrt_float" "sqrt";
register_un_math_prim "caml_tan_float" "tan";
register_un_prim "caml_js_from_bool" `Pure (fun cx _ ->
J.EUn (J.Not, J.EUn (J.Not, cx)));
register_un_prim "caml_js_to_bool" `Pure (fun cx _ -> to_int cx);
Expand Down Expand Up @@ -1318,6 +1189,17 @@ let remove_unused_tail_args ctx exact trampolined args =
else args
else args

(* var substitution *)
class subst sub =
object
inherit Js_traverse.map as super

method expression x =
match x with
| EVar v -> ( try sub v with Not_found -> super#expression x)
| _ -> super#expression x
end

let rec translate_expr ctx loc x e level : (_ * J.statement_list) Expr_builder.t =
let open Expr_builder in
match e with
Expand Down Expand Up @@ -1539,13 +1421,52 @@ let rec translate_expr ctx loc x e level : (_ * J.statement_list) Expr_builder.t
let name = Primitive.resolve name_orig in
match internal_prim name with
| Some f -> f name l ctx loc
| None ->
| None -> (
if String.starts_with name ~prefix:"%"
then failwith (Printf.sprintf "Unresolved internal primitive: %s" name);
let prim = Share.get_prim (runtime_fun ctx) name ctx.Ctx.share in
let* () = info ~need_loc:true (kind (Primitive.kind name)) in
let* args = list_map (fun x -> access' ~ctx x) l in
return (J.call prim args loc))
match Linker.inline ~name with
| Some (req, f)
when Option.is_none ctx.Ctx.exported_runtime || List.is_empty req -> (
let c = new Js_traverse.rename_variable ~esm:false in
let f = c#expression f in
match f with
| EFun
( None
, ( { async = false; generator = false }
, { list = params; rest = None }
, [ (Return_statement (Some body, _), _) ]
, _loc ) )
when List.length params = List.length l ->
let* l = list_map (fun x -> access' ~ctx x) l in
let params =
List.map params ~f:(fun (x, _) ->
match x with
| BindingIdent x -> x
| BindingPattern _ -> assert false)
in
let sub =
let t = Hashtbl.create (List.length l) in
List.iter2 params l ~f:(fun p x ->
let k =
match p with
| J.V v -> v
| _ -> assert false
in
Hashtbl.add t k x);

fun x ->
match x with
| J.S _ -> J.EVar x
| J.V x -> Hashtbl.find t x
in
let r = new subst sub in
return (r#expression body)
| _ -> assert false)
| None | Some _ ->
let prim = Share.get_prim (runtime_fun ctx) name ctx.Ctx.share in
let* () = info ~need_loc:true (kind (Primitive.kind name)) in
let* args = list_map (fun x -> access' ~ctx x) l in
return (J.call prim args loc)))
| Not, [ x ] ->
let* cx = access' ~ctx x in
return (J.EBin (J.Minus, one, cx))
Expand Down Expand Up @@ -2289,7 +2210,7 @@ let f
if times () then Format.eprintf " code gen.: %a@." Timer.print t';
p

let init () =
let reset () =
Hashtbl.iter
(fun name (k, _) -> Primitive.register name k None None)
internal_primitives
2 changes: 1 addition & 1 deletion compiler/lib/generate.mli
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,4 @@ val f :
-> deadcode_sentinal:Code.Var.t
-> Javascript.program

val init : unit -> unit
val reset : unit -> unit
7 changes: 7 additions & 0 deletions compiler/lib/javascript.ml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ module Num : sig

val is_neg : t -> bool

val is_int : t -> bool

(** Arithmetic *)

val add : t -> t -> t
Expand Down Expand Up @@ -134,6 +136,11 @@ end = struct

let is_neg s = Char.equal s.[0] '-'

let is_int s =
String.for_all s ~f:(function
| '0' .. '9' | '-' -> true
| _ -> false)

let neg s =
match String.drop_prefix s ~prefix:"-" with
| None -> "-" ^ s
Expand Down
2 changes: 2 additions & 0 deletions compiler/lib/javascript.mli
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ module Num : sig

val is_neg : t -> bool

val is_int : t -> bool

(** Arithmetic *)

val add : t -> t -> t
Expand Down
20 changes: 10 additions & 10 deletions compiler/lib/js_traverse.ml
Original file line number Diff line number Diff line change
Expand Up @@ -1702,23 +1702,23 @@ class simpl =

method expression e =
let e = super#expression e in
let is_zero x =
match Num.to_string x with
| "0" | "0." -> true
| _ -> false
in
match e with
| EBin (Plus, e1, e2) -> (
match e1, e2 with
| _, ENum n when Num.is_neg n -> EBin (Minus, e1, ENum (Num.neg n))
| ENum n, _ when Num.is_neg n -> EBin (Minus, e2, ENum (Num.neg n))
| ENum zero, (ENum _ as x) when is_zero zero -> x
| (ENum _ as x), ENum zero when is_zero zero -> x
| ENum n1, ENum n2 when Num.is_int n1 && Num.is_int n2 -> ENum (Num.add n1 n2)
| _, ENum n when Num.is_neg n ->
m#expression (EBin (Minus, e1, ENum (Num.neg n)))
| ENum n, _ when Num.is_neg n ->
m#expression (EBin (Minus, e2, ENum (Num.neg n)))
| ENum zero, x when Num.is_zero zero -> x
| x, ENum zero when Num.is_zero zero -> x
| _ -> e)
| EBin (Minus, e1, e2) -> (
match e1, e2 with
| EBin (Minus, e0, ENum n1), ENum n2 when Num.is_int n1 && Num.is_int n2 ->
EBin (Minus, e0, ENum (Num.add n1 n2))
| _, ENum n when Num.is_neg n -> EBin (Plus, e1, ENum (Num.neg n))
| (ENum _ as x), ENum zero when is_zero zero -> x
| (ENum _ as x), ENum zero when Num.is_zero zero -> x
| _ -> e)
| EFun
(None, (({ generator = false; async = true | false }, _, body, _) as fun_decl))
Expand Down
Loading
Loading