Skip to content

Commit 64a69b4

Browse files
committed
Wasm: specialization of number comparisons
1 parent 4bafb2e commit 64a69b4

File tree

6 files changed

+541
-11
lines changed

6 files changed

+541
-11
lines changed

compiler/lib-wasm/generate.ml

+100-3
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ module Generate (Target : Target_sig.S) = struct
3636
{ live : int array
3737
; in_cps : Effects.in_cps
3838
; deadcode_sentinal : Var.t
39+
; types : Typing.typ Var.Tbl.t
3940
; blocks : block Addr.Map.t
4041
; closures : Closure_conversion.closure Var.Map.t
4142
; global_context : Code_generation.context
@@ -233,6 +234,39 @@ module Generate (Target : Target_sig.S) = struct
233234
f context (transl_prim_arg x) (transl_prim_arg y) (transl_prim_arg z)
234235
| _ -> invalid_arity name l ~expected:3)
235236

237+
let get_type ctx p =
238+
match p with
239+
| Pv x -> Var.Tbl.get ctx.types x
240+
| Pc c -> Typing.constant_type c
241+
242+
let register_comparison name cmp_int cmp_boxed_int cmp_float =
243+
register_prim name `Mutable (fun ctx _ transl_prim_arg l ->
244+
match l with
245+
| [ x; y ] -> (
246+
let x' = transl_prim_arg x in
247+
let y' = transl_prim_arg y in
248+
match get_type ctx x, get_type ctx y with
249+
| Number Int, Number Int -> cmp_int x' y'
250+
| Number Int32, Number Int32 ->
251+
let* x' = Memory.unbox_int32 x' in
252+
let* y' = Memory.unbox_int32 y' in
253+
Value.val_int (return (W.BinOp (I32 cmp_boxed_int, x', y')))
254+
| Number Nativeint, Number Nativeint ->
255+
let* x' = Memory.unbox_nativeint x' in
256+
let* y' = Memory.unbox_nativeint y' in
257+
Value.val_int (return (W.BinOp (I32 cmp_boxed_int, x', y')))
258+
| Number Int64, Number Int64 ->
259+
let* x' = Memory.unbox_int64 x' in
260+
let* y' = Memory.unbox_int64 y' in
261+
Value.val_int (return (W.BinOp (I64 cmp_boxed_int, x', y')))
262+
| Number Float, Number Float -> float_comparison cmp_float x' y'
263+
| _ ->
264+
let* f = register_import ~name (Fun (func_type 2)) in
265+
let* x' = x' in
266+
let* y' = y' in
267+
return (W.Call (f, [ x'; y' ])))
268+
| _ -> invalid_arity name l ~expected:2)
269+
236270
let () =
237271
register_bin_prim "caml_array_unsafe_get" `Mutable Memory.gen_array_get;
238272
register_bin_prim "caml_floatarray_unsafe_get" `Mutable Memory.float_array_get;
@@ -605,7 +639,66 @@ module Generate (Target : Target_sig.S) = struct
605639
l
606640
~init:(return [])
607641
in
608-
Memory.allocate ~tag:0 ~deadcode_sentinal:ctx.deadcode_sentinal l)
642+
Memory.allocate ~tag:0 ~deadcode_sentinal:ctx.deadcode_sentinal l);
643+
register_comparison "caml_greaterthan" (fun y x -> Value.lt x y) (Gt S) Gt;
644+
register_comparison "caml_greaterequal" (fun y x -> Value.le x y) (Ge S) Ge;
645+
register_comparison "caml_lessthan" Value.lt (Lt S) Lt;
646+
register_comparison "caml_lessequal" Value.le (Le S) Le;
647+
register_comparison
648+
"caml_equal"
649+
(fun x y ->
650+
let* x = x in
651+
let* y = y in
652+
return (W.RefEq (x, y)))
653+
Eq
654+
Eq;
655+
register_comparison
656+
"caml_notequal"
657+
(fun x y ->
658+
let* x = x in
659+
let* y = y in
660+
return (W.UnOp (I32 Eqz, RefEq (x, y))))
661+
Ne
662+
Ne;
663+
register_prim "caml_compare" `Mutable (fun ctx _ transl_prim_arg l ->
664+
match l with
665+
| [ x; y ] -> (
666+
let x' = transl_prim_arg x in
667+
let y' = transl_prim_arg y in
668+
match get_type ctx x, get_type ctx y with
669+
| Number Int, Number Int ->
670+
Value.val_int
671+
Arith.(
672+
(Value.int_val y' < Value.int_val x')
673+
- (Value.int_val x' < Value.int_val y'))
674+
| Number Int32, Number Int32 ->
675+
let* f = register_import ~name:"caml_int32_compare" (Fun (func_type 2)) in
676+
let* x' = Memory.unbox_int32 x' in
677+
let* y' = Memory.unbox_int32 y' in
678+
return (W.Call (f, [ x'; y' ]))
679+
| Number Nativeint, Number Nativeint ->
680+
let* f =
681+
register_import ~name:"caml_nativeint_compare" (Fun (func_type 2))
682+
in
683+
let* x' = Memory.unbox_nativeint x' in
684+
let* y' = Memory.unbox_nativeint y' in
685+
return (W.Call (f, [ x'; y' ]))
686+
| Number Int64, Number Int64 ->
687+
let* f = register_import ~name:"caml_int64_compare" (Fun (func_type 2)) in
688+
let* x' = Memory.unbox_int64 x' in
689+
let* y' = Memory.unbox_int64 y' in
690+
return (W.Call (f, [ x'; y' ]))
691+
| Number Float, Number Float ->
692+
let* f = register_import ~name:"caml_float_compare" (Fun (func_type 2)) in
693+
let* x' = Memory.unbox_int64 x' in
694+
let* y' = Memory.unbox_int64 y' in
695+
return (W.Call (f, [ x'; y' ]))
696+
| _ ->
697+
let* f = register_import ~name:"caml_compare" (Fun (func_type 2)) in
698+
let* x' = x' in
699+
let* y' = y' in
700+
return (W.Call (f, [ x'; y' ])))
701+
| _ -> invalid_arity "caml_compare" l ~expected:2)
609702

610703
let rec translate_expr ctx context x e =
611704
match e with
@@ -1175,7 +1268,8 @@ module Generate (Target : Target_sig.S) = struct
11751268
~should_export
11761269
~warn_on_unhandled_effect
11771270
*)
1178-
~deadcode_sentinal =
1271+
~deadcode_sentinal
1272+
~types =
11791273
global_context.unit_name <- unit_name;
11801274
let p, closures = Closure_conversion.f p in
11811275
(*
@@ -1185,6 +1279,7 @@ module Generate (Target : Target_sig.S) = struct
11851279
{ live = live_vars
11861280
; in_cps
11871281
; deadcode_sentinal
1282+
; types
11881283
; blocks = p.blocks
11891284
; closures
11901285
; global_context
@@ -1292,8 +1387,10 @@ let start () = make_context ~value_type:Gc_target.Value.value
12921387

12931388
let f ~context ~unit_name p ~live_vars ~in_cps ~deadcode_sentinal =
12941389
let t = Timer.make () in
1390+
let state, info = Global_flow.f' ~fast:false p in
1391+
let types = Typing.f ~state ~info p in
12951392
let p = fix_switch_branches p in
1296-
let res = G.f ~context ~unit_name ~live_vars ~in_cps ~deadcode_sentinal p in
1393+
let res = G.f ~context ~unit_name ~live_vars ~in_cps ~deadcode_sentinal ~types p in
12971394
if times () then Format.eprintf " code gen.: %a@." Timer.print t;
12981395
res
12991396

0 commit comments

Comments
 (0)