diff --git a/src/symbolic_choice.ml b/src/symbolic_choice.ml index e79935812..48c3c1d9d 100644 --- a/src/symbolic_choice.ml +++ b/src/symbolic_choice.ml @@ -9,93 +9,6 @@ open Hc exception Assertion of Expr.t * Thread.t -let check sym_bool thread (S (solver_module, solver)) = - let pc = Thread.pc thread in - let no = Bool.not sym_bool in - let no = Expr.simplify no in - match no.node.e with - | Val True -> false - | Val False -> true - | _ -> - let check = no :: pc in - let module Solver = (val solver_module) in - let r = Solver.check solver check in - not r - -let list_select b ({ Thread.pc; _ } as thread) (S (solver_module, s)) = - let v = Expr.simplify b in - match v.node.e with - | Val True -> [ (true, thread) ] - | Val False -> [ (false, thread) ] - | Val (Num (I32 _)) -> assert false - | _ -> ( - let module Solver = (val solver_module) in - let with_v = v :: pc in - let with_not_v = Bool.not v :: pc in - let sat_true = Solver.check s with_v in - let sat_false = Solver.check s with_not_v in - match (sat_true, sat_false) with - | false, false -> [] - | true, false | false, true -> [ (sat_true, thread) ] - | true, true -> - let thread1 = Thread.clone { thread with pc = with_v } in - let thread2 = Thread.clone { thread with pc = with_not_v } in - [ (true, thread1); (false, thread2) ] ) - -let fix_symbol (e : Expr.t) ~pc ~choices = - match e.node.e with - | Symbol sym -> (pc, sym) - | _ -> - let symbol_name = Format.sprintf "choice_i32_%i" choices in - let sym = Symbol.(symbol_name @: Ty_bitv S32) in - let assign = Expr.(Relop (Eq, mk_symbol sym, e) @: Ty_bitv S32) in - (assign :: pc, sym) - -let clone_if_needed ~orig_pc cases = - match cases with - | [ (i, thread) ] -> [ (i, { thread with Thread.pc = orig_pc }) ] - | cases -> - List.map - (fun (i, thread) -> - let thread = Thread.clone thread in - (i, thread) ) - cases - -let not_value sym value = - Expr.(Relop (Ne, mk_symbol sym, const_i32 value) @: Ty_bitv S32) - -let list_select_i32 sym_int thread (S (solver_module, solver)) = - let pc = Thread.pc thread in - let sym_int = Expr.simplify sym_int in - let orig_pc = pc in - let pc, symbol = fix_symbol sym_int ~pc ~choices:thread.choices in - match sym_int.node.e with - | Val (Num (I32 i)) -> [ (i, thread) ] - | _ -> - let module Solver = (val solver_module) in - let rec find_values values = - let additionnal = List.map (not_value symbol) values in - if not (Solver.check solver (additionnal @ pc)) then [] - else begin - let model = Solver.model ~symbols:[ symbol ] solver in - match model with - | None -> assert false (* ? *) - | Some model -> ( - let v = Model.evaluate model symbol in - match v with - | None -> assert false (* ? *) - | Some (Num (I32 i)) -> begin - let cond = Expr.Bitv.I32.(Expr.mk_symbol symbol = v i) in - let pc = cond :: pc in - let case = (i, { thread with pc; choices = thread.choices + 1 }) in - case :: find_values (i :: values) - end - | Some _ -> assert false ) - end - in - let cases = find_values [] in - clone_if_needed ~orig_pc cases - module Minimalist = struct type err = | Assert_fail @@ -162,21 +75,20 @@ module WQ = struct { mutex : Mutex.t ; cond : Condition.t ; queue : 'a Queue.t - ; mutable producers : int + ; mutable pledges : int ; mutable failed : bool } - let take_as_producer q = + let take q pledge = Mutex.lock q.mutex; - q.producers <- q.producers - 1; let r = try while Queue.is_empty q.queue do - if q.producers = 0 || q.failed then raise Exit; + if q.pledges = 0 || q.failed then raise Exit; Condition.wait q.cond q.mutex done; let v = Queue.pop q.queue in - q.producers <- q.producers + 1; + if pledge then q.pledges <- q.pledges + 1; Some v with Exit -> Condition.broadcast q.cond; @@ -185,39 +97,24 @@ module WQ = struct Mutex.unlock q.mutex; r - let take_as_consumer q = + let make_pledge q = Mutex.lock q.mutex; - let r = - try - while Queue.is_empty q.queue do - if q.producers = 0 || q.failed then raise Exit; - Condition.wait q.cond q.mutex - done; - let v = Queue.pop q.queue in - Some v - with Exit -> None - in - Mutex.unlock q.mutex; - r + q.pledges <- q.pledges + 1; + Mutex.unlock q.mutex - let rec read_as_seq (q : 'a t) : 'a Seq.t = - fun () -> - match take_as_consumer q with - | None -> Nil - | Some v -> Cons (v, read_as_seq q) - - let produce (q : 'a t) (f : 'a -> unit) = - let rec loop () = - match take_as_producer q with - | None -> () - | Some v -> - f v; - loop () - in + let end_pledge q = Mutex.lock q.mutex; - q.producers <- q.producers + 1; - Mutex.unlock q.mutex; - loop () + q.pledges <- q.pledges - 1; + Condition.broadcast q.cond; + Mutex.unlock q.mutex + + let rec read_as_seq (q : 'a t) ?(finalizer = Fun.const ()) : 'a Seq.t = + fun () -> + match take q false with + | None -> + finalizer (); + Nil + | Some v -> Cons (v, read_as_seq q ~finalizer) let push v q = Mutex.lock q.mutex; @@ -232,215 +129,445 @@ module WQ = struct Condition.broadcast q.cond; Mutex.unlock q.mutex - let with_produce q f ~started = - Mutex.lock q.mutex; - q.producers <- q.producers + 1; - Mutex.unlock q.mutex; - started (); - match f () with - | () -> - Mutex.lock q.mutex; - q.producers <- q.producers - 1; - if q.producers = 0 then Condition.broadcast q.cond; - Mutex.unlock q.mutex - | exception e -> - let bt = Printexc.get_raw_backtrace () in - fail q; - Printexc.raise_with_backtrace e bt - let init () = { mutex = Mutex.create () ; cond = Condition.create () ; queue = Queue.create () - ; producers = 0 + ; pledges = 0 ; failed = false } end -module Counter = struct - type t = - { mutable c : int - ; mutex : Mutex.t - ; cond : Condition.t - } +module Multicore = struct + (* + Multicore is based on several layers of monad transformers defined here + in submodules. The module as a whole is made to provide a monad to explore in parallel + different possibilites, with a notion of priority. + *) + module Prio = struct + (* + Currently there is no real notion of priority. Future extensions adding it will ho here. + *) + type t = Default - let incr t = - Mutex.lock t.mutex; - t.c <- t.c + 1; - Condition.broadcast t.cond; - Mutex.unlock t.mutex + let default = Default + end - let wait_n t n = - Mutex.lock t.mutex; - while t.c < n do - Condition.wait t.cond t.mutex - done; - Mutex.unlock t.mutex + module CoreImpl : sig + (* + The core implementation of the monad. It is isolated in a module to restict its exposed interface + and maintain its invariant. In particular, choose must guarantee that the Thread.t is cloned in each branch. + Using functions defined here should be foolproof. + *) + type 'a t - let create () = { c = 0; mutex = Mutex.create (); cond = Condition.create () } -end + val return : 'a -> 'a t -module Multicore = struct - type 'a st = St of (Thread.t -> solver -> 'a * Thread.t) [@@unboxed] + val bind : 'a t -> ('a -> 'b t) -> 'b t - type 'a t = - | Empty : 'a t - | Ret : 'a st -> 'a t - | Retv : 'a -> 'a t - | Bind : 'a t * ('a -> 'b t) -> 'b t - | Assert : vbool -> unit t - | Choice : vbool -> bool t - | Choice_i32 : int32 -> Stdlib.Int32.t t - | Trap : Trap.t -> 'a t - - type 'a eval = - | EVal of 'a - | ETrap of Trap.t - | EAssert of Encoding.Expr.t - - type 'a run_result = ('a eval * Thread.t) Seq.t - - let return v = Retv v [@@inline] - - let bind : type a b. a t -> (a -> b t) -> b t = - fun v f -> - match v with - | Empty -> Empty - | Trap t -> Trap t - | Retv v -> f v - | Assert _ | Ret _ | Choice _ | Choice_i32 _ | Bind _ -> Bind (v, f) - [@@inline] + val ( let* ) : 'a t -> ('a -> 'b t) -> 'b t - let ( let* ) = bind + val map : 'a t -> ('a -> 'b) -> 'b t - let map v f = - let* v in - return (f v) + val ( let+ ) : 'a t -> ('a -> 'b) -> 'b t - let ( let+ ) = map + val stop : 'a t - let select (cond : vbool) = - match cond.node.e with - | Val True -> Retv true - | Val False -> Retv false - | _ -> Choice cond - [@@inline] + val assertion_fail : Expr.t -> 'a t - let select_i32 (i : int32) = - match i.node.e with Val (Num (I32 v)) -> Retv v | _ -> Choice_i32 i + val trap : Trap.t -> 'a t + + val thread : Thread.t t + + val yield : unit t + + val solver : solver t + + val with_thread : (Thread.t -> 'a) -> 'a t + + val set_thread : Thread.t -> unit t + + val modify_thread : (Thread.t -> Thread.t) -> unit t + + (* + Indicates a possible choice between two values. Thread duplication + is already handled by choose and should not be done before by the caller. + *) + val choose : 'a t -> 'a t -> 'a t - let trap t = Trap t + type 'a eval = + | EVal of 'a + | ETrap of Trap.t + | EAssert of Encoding.Expr.t + + type 'a run_result = ('a eval * Thread.t) Seq.t + + val run : workers:int -> 'a t -> Thread.t -> 'a run_result + end = struct + module Schedulable = struct + (* + A monad representing computation that can be cooperatively scheduled and may need + Worker Local Storage (WLS). Computations can yield, and fork (Choice). + *) + type ('a, 'wls) t = Sched of ('wls -> ('a, 'wls) status) [@@unboxed] + + and ('a, 'wls) status = + | Now of 'a + | Yield of Prio.t * ('a, 'wls) t + | Choice of (('a, 'wls) status * ('a, 'wls) status) + | Stop + + let run (Sched mxf) wls = mxf wls + + let return x : _ t = Sched (Fun.const (Now x)) + + let return_status status = Sched (Fun.const status) + + let rec bind (mx : ('a, 'wls) t) (f : 'a -> ('b, 'wls) t) : _ t = + let rec bind_status (x : _ status) (f : _ -> _ status) : _ status = + match x with + | Now x -> f x + | Yield (prio, lx) -> + Yield (prio, Sched (fun wls -> bind_status (run lx wls) f)) + | Choice (mx1, mx2) -> Choice (bind_status mx1 f, bind_status mx2 f) + | Stop -> Stop + in + Sched + (fun wls -> + let argumented_f x = run (f x) wls in + match run mx wls with + | Yield (prio, mx) -> Yield (prio, bind mx f) + | x -> bind_status x argumented_f ) + + let ( let* ) = bind + + let map x f = + let* x in + return (f x) - let with_thread f = Ret (St (fun t _sol -> (f t, t))) [@@inline] + let ( let+ ) = map - let thread = Ret (St (fun t _sol -> (t, t))) + let yield prio = return_status (Yield (prio, Sched (Fun.const (Now ())))) - let solver = Ret (St (fun t sol -> (sol, t))) + let choose a b = Sched (fun wls -> Choice (run a wls, run b wls)) + + let stop : ('a, 'b) t = return_status Stop + + let worker_local : ('a, 'a) t = Sched (fun wls -> Now wls) + end + + module Scheduler = struct + (* + A scheduler for Schedulable values. + *) + type ('a, 'wls) work_queue = ('a, 'wls) Schedulable.t WQ.t + + type 'a res_queue = 'a WQ.t + + type ('a, 'wls) t = + { work_queue : ('a, 'wls) work_queue + ; res_writer : 'a res_queue + } + + let init_scheduler () = + let work_queue = WQ.init () in + let res_writer = WQ.init () in + { work_queue; res_writer } + + let add_init_task sched task = WQ.push task sched.work_queue + + let rec work wls sched = + let rec handle_status (t : _ Schedulable.status) sched = + match t with + | Stop -> () + | Now x -> WQ.push x sched.res_writer + | Yield (_prio, f) -> WQ.push f sched.work_queue + | Choice (m1, m2) -> + handle_status m1 sched; + handle_status m2 sched + in + match WQ.take sched.work_queue true with + | None -> () + | Some f -> begin + handle_status (Schedulable.run f wls) sched; + WQ.end_pledge sched.work_queue; + work wls sched + end + + let spawn_worker sched wls_init = + WQ.make_pledge sched.res_writer; + Domain.spawn (fun () -> + let wls = wls_init () in + try + work wls sched; + WQ.end_pledge sched.res_writer + with e -> + let bt = Printexc.get_raw_backtrace () in + WQ.fail sched.work_queue; + WQ.end_pledge sched.res_writer; + Printexc.raise_with_backtrace e bt ) + end + + module State = struct + (* + Add a notion of State to the Schedulable monad + ("Transformer without module functor" style) + *) + module M = Schedulable + + type 'a t = St of (Thread.t -> ('a * Thread.t, solver) M.t) [@@unboxed] + + let run (St mxf) st = mxf st + + let return x = St (fun st -> M.return (x, st)) + + let lift x = + let ( let+ ) = M.( let+ ) in + St + (fun st -> + let+ x in + (x, st) ) + + let bind mx f = + St + (fun st -> + let ( let* ) = M.( let* ) in + let* x, new_st = run mx st in + run (f x) new_st ) + + let ( let* ) = bind + + let map x f = + let* x in + return (f x) + + let liftF2 f x y = St (fun st -> f (run x st) (run y st)) + + let ( let+ ) = map + + let with_state f = St (fun st -> M.return (f st)) + + let modify_state f = St (fun st -> M.return ((), f st)) + end + + module Eval = struct + (* + Add a notion of faillibility to the evaluation + ("Transformer without module functor" style) + *) + module M = State + + type 'a eval = + | EVal of 'a + | ETrap of Trap.t + | EAssert of Encoding.Expr.t + + type 'a t = 'a eval M.t + + let return x : _ t = M.return (EVal x) + + let lift x = + let ( let+ ) = M.( let+ ) in + let+ x in + EVal x + + let bind (mx : _ t) f : _ t = + let ( let* ) = M.( let* ) in + let* mx in + match mx with + | EVal x -> f x + | ETrap _ as mx -> M.return mx + | EAssert _ as mx -> M.return mx + + let ( let* ) = bind + + let map mx f = + let ( let+ ) = M.( let+ ) in + let+ mx in + match mx with + | EVal x -> EVal (f x) + | ETrap _ as mx -> mx + | EAssert _ as mx -> mx + + let ( let+ ) = map + end + + include Eval + + (* + Here we define functions to seamlessly + operate on the three monads layers + *) + + let lift_schedulable (v : ('a, _) Schedulable.t) : 'a t = + lift (State.lift v) + + let with_thread f = lift (State.with_state (fun st -> (f st, st))) + + let thread = with_thread Fun.id + + let modify_thread f = lift (State.modify_state f) + + let set_thread st = modify_thread (Fun.const st) + + let clone_thread = modify_thread Thread.clone + + let solver = lift_schedulable Schedulable.worker_local + + let choose a b = + let a = + let* () = clone_thread in + a + in + let b = + let* () = clone_thread in + b + in + State.liftF2 Schedulable.choose a b + + let yield = lift_schedulable @@ Schedulable.yield Prio.default + + let stop = lift_schedulable Schedulable.stop + + type 'a run_result = ('a eval * Thread.t) Seq.t + + let run ~workers t thread = + let open Scheduler in + let sched = init_scheduler () in + add_init_task sched (State.run t thread); + let join_handles = + Array.map + (fun () -> spawn_worker sched fresh_solver) + (Array.init workers (Fun.const ())) + in + WQ.read_as_seq sched.res_writer ~finalizer:(fun () -> + Array.iter Domain.join join_handles ) + + let trap t = State.return (ETrap t) + + let assertion_fail c = State.return (EAssert c) + end + + (* + We can now use CoreImpl only through its exposed signature which + maintains all invariants. + *) + + include CoreImpl let add_pc (c : vbool) = match c.node.e with - | Val True -> Retv () - | Val False -> Empty - | _ -> Ret (St (fun t _sol -> ((), { t with pc = c :: t.pc }))) + | Val True -> return () + | Val False -> stop + | _ -> + let* thread in + let new_thread = { thread with pc = c :: thread.pc } in + set_thread new_thread [@@inline] - let assertion c = Assert c + (* + Yielding is currently done each time the solver is about to be called, + in check_reachability and get_model. + *) + let check_reachability = + let* () = yield in + let* (S (solver_module, s)) = solver in + let module Solver = (val solver_module) in + let* thread in + let sat = Solver.check s thread.pc in + if sat then return () else stop - type 'a global_state = - { w : hold WQ.t (* work *) - ; r : ('a eval * Thread.t) WQ.t (* results *) - ; start_counter : Counter.t - } + let get_model symbol = + let* () = yield in + let* (S (solver_module, s)) = solver in + let module Solver = (val solver_module) in + let+ thread in + let sat = Solver.check s thread.pc in + if not sat then None + else begin + let model = Solver.model ~symbols:[ symbol ] s in + match model with + | None -> + failwith "Unreachable: The problem is sat so a model should exist" + | Some model -> begin + match Model.evaluate model symbol with + | None -> + failwith + "Unreachable: The model exists so this symbol should evaluate" + | Some _ as v -> v + end + end - and 'a local_state = - { solver : solver - ; mutable next : hold option - ; global : 'a global_state - } + let get_model_or_stop symbol = + let* model = get_model symbol in + match model with Some v -> return v | None -> stop - and e_local_state = E_st : 'a local_state -> e_local_state [@@unboxed] - - and 'a cont = { k : Thread.t -> e_local_state -> 'a -> unit } [@@unboxed] - - and hold = H : Thread.t * 'a t * 'a cont -> hold - - let local_push st v = - match st.next with - | None -> st.next <- Some v - | Some _ -> WQ.push v st.global.w - - let rec step : type v. Thread.t -> v t -> v cont -> _ -> unit = - fun thread t cont st -> - match t with - | Empty -> () - | Retv v -> cont.k thread (E_st st) v - | Ret (St f) -> - let v, thread = f thread st.solver in - cont.k thread (E_st st) v - | Trap t -> WQ.push (ETrap t, thread) st.global.r - | Assert c -> - if check c thread st.solver then cont.k thread (E_st st) () - else - let no = Bool.not c in - let thread = { thread with pc = no :: thread.pc } in - WQ.push (EAssert c, thread) st.global.r - | Bind (v, f) -> - let k thread (E_st st) v = - let r = f v in - local_push st (H (thread, r, cont)) + let select (cond : Symbolic_value.vbool) = + let v = Expr.simplify cond in + match v.node.e with + | Val True -> return true + | Val False -> return false + | Val (Num (I32 _)) -> failwith "unreachable (type error)" + | _ -> + let true_branch = + let* () = add_pc v in + let+ () = check_reachability in + true in - step thread v { k } st - | Choice cond -> - let cases = list_select cond thread st.solver in - List.iter (fun (case, thread) -> cont.k thread (E_st st) case) cases - | Choice_i32 i -> - let cases = list_select_i32 i thread st.solver in - List.iter (fun (case, thread) -> cont.k thread (E_st st) case) cases - - let init_global () = - let w = WQ.init () in - let r = WQ.init () in - let start_counter = Counter.create () in - { w; r; start_counter } - - let push_first_work g thread t = - let k thread _st v = WQ.push (EVal v, thread) g.r in - WQ.push (H (thread, t, { k })) g.w - - let spawn_producer global _i = - let solver = fresh_solver () in - let st = { solver; next = None; global } in - let rec producer (H (thread, t, cont)) = - step thread t cont st; - match st.next with - | Some h -> - st.next <- None; - producer h - | None -> () - in - Domain.spawn (fun () -> - try - WQ.with_produce global.r - ~started:(fun () -> Counter.incr global.start_counter) - (fun () -> WQ.produce global.w producer) - with e -> - let bt = Printexc.get_raw_backtrace () in - WQ.fail global.w; - Printexc.raise_with_backtrace e bt ) - - let rec loop_and_do (s : 'a Seq.t) f : 'a Seq.t = - fun () -> - match s () with - | Cons (s, t) -> Cons (s, loop_and_do t f) - | Nil -> - f (); - Nil + let false_branch = + let* () = add_pc (Symbolic_value.Bool.not v) in + let+ () = check_reachability in + false + in + choose true_branch false_branch + [@@inline] + + let summary_symbol (e : Expr.t) = + let* thread in + match e.node.e with + | Symbol sym -> return (None, sym) + | _ -> + let choices = thread.choices in + let symbol_name = Format.sprintf "choice_i32_%i" choices in + let+ () = modify_thread (fun t -> { t with choices = choices + 1 }) in + let sym = Symbol.(symbol_name @: Ty_bitv S32) in + let assign = Expr.(Relop (Eq, mk_symbol sym, e) @: Ty_bitv S32) in + (Some assign, sym) + + let select_i32 (i : Symbolic_value.int32) = + let sym_int = Expr.simplify i in + match sym_int.node.e with + | Val (Num (I32 i)) -> return i + | _ -> + let* assign, symbol = summary_symbol sym_int in + let* () = + match assign with Some assign -> add_pc assign | None -> return () + in + let rec generator () = + let* possible_value = get_model_or_stop symbol in + let i = + match possible_value with + | Num (I32 i) -> i + | _ -> failwith "Unreachable: found symbol must be a value" + in + let this_value_cond = Expr.Bitv.I32.(Expr.mk_symbol symbol = v i) in + let not_this_value_cond = + (* != is **not** the physical equality here *) + Expr.Bitv.I32.(Expr.mk_symbol symbol != v i) + in + let this_val_branch = + let+ () = add_pc this_value_cond in + i + in + let not_this_val_branch = + let* () = add_pc not_this_value_cond in + generator () + in + choose this_val_branch not_this_val_branch + in + generator () - let run ~workers t thread = - let global = init_global () in - push_first_work global thread t; - let producers = Array.init workers (spawn_producer global) in - Counter.wait_n global.start_counter workers; - loop_and_do (WQ.read_as_seq global.r) (fun () -> - Array.iter Domain.join producers ) + let assertion c = + let* assertion_true = select c in + if assertion_true then return () else assertion_fail c end