diff --git a/iter.opam b/iter.opam index 410f54a..3526433 100644 --- a/iter.opam +++ b/iter.opam @@ -17,6 +17,7 @@ depends: [ "ounit2" {with-test} "mdx" {with-test & >= "1.3" } "odoc" {with-doc} + "containers" {with-test} ] tags: [ "iter" "iterator" "iter" "fold" ] homepage: "https://github.com/c-cube/iter/" diff --git a/src/Iter.ml b/src/Iter.ml index 6a9a5f7..b622a43 100644 --- a/src/Iter.ml +++ b/src/Iter.ml @@ -658,6 +658,16 @@ let take_while p seq k = exception ExitFoldWhile +let map_while f seq k = + let exception ExitMapWhile in + let consume x = + match f x with + | `Yield y -> k y + | `Return y -> k y; raise_notrace ExitMapWhile + | `Stop -> raise_notrace ExitMapWhile + in + try seq consume with ExitMapWhile -> () + let fold_while f s seq = let state = ref s in let consume x = diff --git a/src/Iter.mli b/src/Iter.mli index 1d7f5ad..7e64297 100644 --- a/src/Iter.mli +++ b/src/Iter.mli @@ -487,6 +487,17 @@ val take_while : ('a -> bool) -> 'a t -> 'a t Will work on an infinite iterator [s] if the predicate is false for at least one element of [s]. *) +val map_while : ('a -> [ `Yield of 'b | `Return of 'b | `Stop ]) -> 'a t -> 'b t +(** Maps over elements of the iterator, stopping early if the mapped function + returns [`Stop] or [`Return x]. At each iteration: + {ul + {- If [f] returns [`Yield y], [y] is added to the sequence and the + iteration continues.} + {- If [f] returns [`Stop], nothing is added to the sequence and the + iteration stops.} + {- If [f] returns [`Return y], [y] is added to the sequence and the + iteration stops.}} *) + val fold_while : ('a -> 'b -> 'a * [ `Stop | `Continue ]) -> 'a -> 'b t -> 'a (** Folds over elements of the iterator, stopping early if the accumulator returns [('a, `Stop)] diff --git a/tests/unit/dune b/tests/unit/dune index f1160d0..921c4dd 100644 --- a/tests/unit/dune +++ b/tests/unit/dune @@ -1,4 +1,4 @@ (tests (names t_iter) - (libraries iter qcheck-core qcheck-core.runner ounit2)) + (libraries iter qcheck-core qcheck-core.runner ounit2 containers)) diff --git a/tests/unit/t_iter.ml b/tests/unit/t_iter.ml index 2d6a2c4..b7b696d 100644 --- a/tests/unit/t_iter.ml +++ b/tests/unit/t_iter.ml @@ -237,6 +237,22 @@ let () = OUnit.assert_equal 2 n; () +let () = + OUnit.assert_equal + ~cmp:(CCList.equal Int.equal) + (1 -- 10 + |> map_while (fun x -> if x = 7 then `Return (x + 1) else `Yield (x - 1)) + |> to_list) + [0; 1; 2; 3; 4; 5; 8] + +let () = + OUnit.assert_equal + ~cmp:(List.equal Int.equal) + (1 -- 10 + |> map_while (fun x -> if x = 7 then `Stop else `Yield (x - 1)) + |> to_list) + [0; 1; 2; 3; 4; 5] + let () = 1 -- 5 |> drop 2 |> to_list |> OUnit.assert_equal [ 3; 4; 5 ] let () = 1 -- 5 |> rev |> to_list |> OUnit.assert_equal [ 5; 4; 3; 2; 1 ]