Skip to content

Commit

Permalink
Make passthrough work with state edges for terminating loops
Browse files Browse the repository at this point in the history
  • Loading branch information
kirstenmg committed Nov 30, 2024
1 parent 6b05164 commit dd3e209
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 68 deletions.
1 change: 1 addition & 0 deletions dag_in_context/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ pub fn prologue() -> String {
include_str!("utility/expr_size.egg"),
include_str!("utility/drop_at.egg"),
include_str!("interval_analysis.egg"),
include_str!("loop_iteration_analysis.egg"),
include_str!("optimizations/switch_rewrites.egg"),
include_str!("optimizations/select.egg"),
include_str!("optimizations/peepholes.egg"),
Expand Down
76 changes: 76 additions & 0 deletions dag_in_context/src/loop_iteration_analysis.egg
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
;; Analysis to get the number of iterations of a loop
(ruleset loop-iter-analysis)

;; inputs, outputs -> number of iterations
;; The minimum possible guess is 1 because of do-while loops
(function LoopNumItersGuess (Expr Expr) i64 :merge (max 1 (min old new)))

;; Marks loops that we know will terminate
(relation TerminatingLoop (Expr Expr))

;; by default, guess that all loops run 1000 times
(rule ((DoWhile inputs outputs))
((set (LoopNumItersGuess inputs outputs) 1000))
:ruleset loop-iter-analysis)

;; For a loop that is false, its num iters is 1
(rule
((= loop (DoWhile inputs outputs))
(= (Const (Bool false) ty ctx) (Get outputs 0)))
((set (LoopNumItersGuess inputs outputs) 1)
(TerminatingLoop inputs outputs))
:ruleset loop-iter-analysis)

;; Figure out number of iterations for a loop with constant bounds and initial value
;; and i is updated before checking pred
;; TODO: we could make it work for decrementing loops
(rule
((= lhs (DoWhile inputs outputs))
(= num-inputs (tuple-length inputs))
(= pred (Get outputs 0))
;; iteration counter starts at start_const
(= (Const (Int start_const) _ty1 _ctx1) (Get inputs counter_i))
;; updated counter at counter_i
(= next_counter (Get outputs (+ counter_i 1)))
;; increments by some constant each loop
(= next_counter (Bop (Add) (Get (Arg _ty _ctx) counter_i)
(Const (Int increment) _ty2 _ctx2)))
(> increment 0)
;; while next_counter less than end_constant
(= pred (Bop (LessThan) next_counter
(Const (Int end_constant) _ty3 _ctx3)))
;; end constant is at least start constant
(>= end_constant start_const)
)
(
(set (LoopNumItersGuess inputs outputs) (/ (- end_constant start_const) increment))
(TerminatingLoop inputs outputs)
)
:ruleset loop-iter-analysis)

;; Figure out number of iterations for a loop with constant bounds and initial value
;; and i is updated after checking pred
(rule
((= lhs (DoWhile inputs outputs))
(= num-inputs (tuple-length inputs))
(= pred (Get outputs 0))
;; iteration counter starts at start_const
(= (Const (Int start_const) _ty1 _ctx1) (Get inputs counter_i))
;; updated counter at counter_i
(= next_counter (Get outputs (+ counter_i 1)))
;; increments by a constant each loop
(= next_counter (Bop (Add) (Get (Arg _ty _ctx) counter_i)
(Const (Int increment) _ty2 _ctx2)))
(> increment 0)
;; while this counter less than end_constant
(= pred (Bop (LessThan) (Get (Arg _ty _ctx) counter_i)
(Const (Int end_constant) _ty3 _ctx3)))
;; end constant is at least start constant
(>= end_constant start_const)
)
(
(set (LoopNumItersGuess inputs outputs) (+ (/ (- end_constant start_const) increment) 1))
(TerminatingLoop inputs outputs)
)
:ruleset loop-iter-analysis)

68 changes: 1 addition & 67 deletions dag_in_context/src/optimizations/loop_unroll.egg
Original file line number Diff line number Diff line change
@@ -1,75 +1,9 @@
;; Some simple simplifications of loops
;; Depends on loop iteration analysis
(ruleset loop-unroll)
(ruleset loop-peel)
(ruleset loop-iters-analysis)

;; inputs, outputs -> number of iterations
;; The minimum possible guess is 1 because of do-while loops
;; TODO: dead loop deletion can turn loops with a false condition to a body
(function LoopNumItersGuess (Expr Expr) i64 :merge (max 1 (min old new)))

;; by default, guess that all loops run 1000 times
(rule ((DoWhile inputs outputs))
((set (LoopNumItersGuess inputs outputs) 1000))
:ruleset loop-iters-analysis)

;; For a loop that is false, its num iters is 1
(rule
((= loop (DoWhile inputs outputs))
(= (Const (Bool false) ty ctx) (Get outputs 0)))
((set (LoopNumItersGuess inputs outputs) 1))
:ruleset loop-iters-analysis)

;; Figure out number of iterations for a loop with constant bounds and initial value
;; and i is updated before checking pred
;; TODO: we could make it work for decrementing loops
(rule
((= lhs (DoWhile inputs outputs))
(= pred (Get outputs 0))
;; iteration counter starts at start_const
(= (Const (Int start_const) _ty1 _ctx1) (Get inputs counter_i))
;; updated counter at counter_i
(= next_counter (Get outputs (+ counter_i 1)))
;; increments by some constant each loop
(= next_counter (Bop (Add) (Get (Arg _ty _ctx) counter_i)
(Const (Int increment) _ty2 _ctx2)))
(> increment 0)
;; while next_counter less than end_constant
(= pred (Bop (LessThan) next_counter
(Const (Int end_constant) _ty3 _ctx3)))
;; end constant is at least start constant
(>= end_constant start_const)
)
(
(set (LoopNumItersGuess inputs outputs) (/ (- end_constant start_const) increment))
)
:ruleset loop-iters-analysis)

;; Figure out number of iterations for a loop with constant bounds and initial value
;; and i is updated after checking pred
(rule
((= lhs (DoWhile inputs outputs))
(= pred (Get outputs 0))
;; iteration counter starts at start_const
(= (Const (Int start_const) _ty1 _ctx1) (Get inputs counter_i))
(= body-arg (Get (Arg _ty _ctx) counter_i))
;; updated counter at counter_i
(= next_counter (Get outputs (+ counter_i 1)))
;; increments by a constant each loop
(= next_counter (Bop (Add) body-arg
(Const (Int increment) _ty2 _ctx2)))
(> increment 0)
;; while this counter less than end_constant
(= pred (Bop (LessThan) body-arg
(Const (Int end_constant) _ty3 _ctx3)))
;; end constant is at least start constant
(>= end_constant start_const)
)
(
(set (LoopNumItersGuess inputs outputs) (+ (/ (- end_constant start_const) increment) 1))
)
:ruleset loop-iters-analysis)

;; loop peeling rule
;; Only peel loops that we know iterate < 3 times
(function LoopPeeledPlaceholder (Expr) Assumption :unextractable)
Expand Down
13 changes: 12 additions & 1 deletion dag_in_context/src/optimizations/passthrough.egg
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
;; Relies on loop iteration analysis
(ruleset passthrough)


;; Pass through thetas
;; Pass through thetas: pure case
(rule ((= lhs (Get loop i))
(= loop (DoWhile inputs pred-outputs))
(= (Get pred-outputs (+ i 1)) (Get (Arg _ty _ctx) i))
Expand All @@ -13,6 +14,16 @@
((union lhs (Get inputs i)))
:ruleset passthrough)

;; Pass through thetas: state edge case
(rule ((= lhs (Get loop i))
(= loop (DoWhile inputs pred-outputs))
(= (Get pred-outputs (+ i 1)) (Get (Arg _ty _ctx) i))
;; It is OK to pass through state edges as long as the loop terminates
(TerminatingLoop inputs pred-outputs)
)
((union lhs (Get inputs i)))
:ruleset passthrough)

;; Pass through switch arguments
(rule ((= lhs (Get switch i))
(= switch (Switch pred inputs branches))
Expand Down

0 comments on commit dd3e209

Please sign in to comment.