diff --git a/stream/src/main/scala/org/apache/pekko/stream/impl/fusing/Ops.scala b/stream/src/main/scala/org/apache/pekko/stream/impl/fusing/Ops.scala index 5368955e7be..8acc1ca7c31 100644 --- a/stream/src/main/scala/org/apache/pekko/stream/impl/fusing/Ops.scala +++ b/stream/src/main/scala/org/apache/pekko/stream/impl/fusing/Ops.scala @@ -178,25 +178,26 @@ import pekko.util.ccompat._ override def initialAttributes: Attributes = DefaultAttributes.dropWhile and SourceLocation.forLambda(p) def createLogic(inheritedAttributes: Attributes) = - new SupervisedGraphStageLogic(inheritedAttributes, shape) with InHandler with OutHandler { - override def onPush(): Unit = { - val elem = grab(in) - withSupervision(() => p(elem)) match { - case Some(flag) => - if (flag) pull(in) - else { - push(out, elem) - setHandler(in, rest) + new GraphStageLogic(shape) with InHandler with OutHandler { + private lazy val decider = inheritedAttributes.mandatoryAttribute[SupervisionStrategy].decider + + override def onPush(): Unit = + try { + val elem = grab(in) + if (p(elem)) { + pull(in) + } else { + push(out, elem) + setHandler(in, () => push(out, grab(in))) + } + } catch { + case NonFatal(ex) => + decider(ex) match { + case Supervision.Stop => failStage(ex) + case Supervision.Resume => if (!hasBeenPulled(in)) pull(in) + case Supervision.Restart => if (!hasBeenPulled(in)) pull(in) } - case None => // do nothing } - } - - def rest = new InHandler { - def onPush() = push(out, grab(in)) - } - - override def onResume(t: Throwable): Unit = if (!hasBeenPulled(in)) pull(in) override def onPull(): Unit = pull(in) @@ -252,23 +253,25 @@ private[stream] object Collect { override def initialAttributes: Attributes = DefaultAttributes.collect and SourceLocation.forLambda(pf) def createLogic(inheritedAttributes: Attributes) = - new SupervisedGraphStageLogic(inheritedAttributes, shape) with InHandler with OutHandler { - + new GraphStageLogic(shape) with InHandler with OutHandler { + private lazy val decider = inheritedAttributes.mandatoryAttribute[SupervisionStrategy].decider import Collect.NotApplied - val wrappedPf = () => pf.applyOrElse(grab(in), NotApplied) - - override def onPush(): Unit = withSupervision(wrappedPf) match { - case Some(result) => - result match { + override def onPush(): Unit = + try { + pf.applyOrElse(grab(in), NotApplied) match { case NotApplied => pull(in) case result: Out @unchecked => push(out, result) case _ => throw new RuntimeException() // won't happen, compiler exhaustiveness check pleaser } - case None => // do nothing - } - - override def onResume(t: Throwable): Unit = if (!hasBeenPulled(in)) pull(in) + } catch { + case NonFatal(ex) => + decider(ex) match { + case Supervision.Stop => failStage(ex) + case Supervision.Resume => if (!hasBeenPulled(in)) pull(in) + case Supervision.Restart => if (!hasBeenPulled(in)) pull(in) + } + } override def onPull(): Unit = pull(in) @@ -846,25 +849,25 @@ private[stream] object Collect { override def initialAttributes: Attributes = DefaultAttributes.limitWeighted and SourceLocation.forLambda(costFn) def createLogic(inheritedAttributes: Attributes) = - new SupervisedGraphStageLogic(inheritedAttributes, shape) with InHandler with OutHandler { + new GraphStageLogic(shape) with InHandler with OutHandler { + private lazy val decider = inheritedAttributes.mandatoryAttribute[SupervisionStrategy].decider + private var left = n - override def onPush(): Unit = { - val elem = grab(in) - withSupervision(() => costFn(elem)) match { - case Some(weight) => - left -= weight - if (left >= 0) push(out, elem) else failStage(new StreamLimitReachedException(n)) - case None => // do nothing + override def onPush(): Unit = + try { + val elem = grab(in) + left -= costFn(elem) + if (left >= 0) push(out, elem) else failStage(new StreamLimitReachedException(n)) + } catch { + case NonFatal(ex) => decider(ex) match { + case Supervision.Stop => failStage(ex) + case Supervision.Resume => if (!hasBeenPulled(in)) pull(in) + case Supervision.Restart => + left = n + if (!hasBeenPulled(in)) pull(in) + } } - } - - override def onResume(t: Throwable): Unit = if (!hasBeenPulled(in)) pull(in) - - override def onRestart(t: Throwable): Unit = { - left = n - if (!hasBeenPulled(in)) pull(in) - } override def onPull(): Unit = pull(in)