Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: Micro optimization for collect* operator. #983

Merged
merged 1 commit into from
Jan 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* license agreements; and to You under the Apache License, version 2.0:
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* This file is part of the Apache Pekko project, which was derived from Akka.
*/

/*
* Copyright (C) 2009-2022 Lightbend Inc. <https://www.lightbend.com>
*/

package org.apache.pekko.stream

import com.typesafe.config.ConfigFactory
import org.apache.pekko
import org.apache.pekko.stream.ActorAttributes.SupervisionStrategy
import org.apache.pekko.stream.Attributes.SourceLocation
import org.apache.pekko.stream.impl.Stages.DefaultAttributes
import org.apache.pekko.stream.impl.fusing.Collect
import org.apache.pekko.stream.stage.{ GraphStage, GraphStageLogic, InHandler, OutHandler }
import org.openjdk.jmh.annotations._
import pekko.actor.ActorSystem
import pekko.stream.scaladsl._

import java.util.concurrent.TimeUnit
import scala.annotation.nowarn
import scala.concurrent._
import scala.concurrent.duration._
import scala.util.control.NonFatal

/**
 * Constants shared with the `CollectBenchmark` class, which imports them via `CollectBenchmark._`.
 */
object CollectBenchmark {
// Number of elements each benchmarked stream emits before completing (used with `.take(...)`),
// and the value handed to JMH's `@OperationsPerInvocation` on the benchmark methods.
// NOTE(review): because it is used as an annotation argument, this presumably has to stay a
// `final val` with a literal right-hand side and no type ascription — confirm before changing.
final val OperationsPerInvocation = 10000000
}

/**
 * JMH throughput benchmark comparing two implementations of the `collect` stream operator:
 *
 *  - `Collect` from `org.apache.pekko.stream.impl.fusing` (the "new" variant), and
 *  - `SimpleCollect`, defined below, which dispatches on the result of `applyOrElse`
 *    with a pattern match (the "old" variant).
 *
 * Both variants are driven by the same graph shape: an endless source of the constant `1`,
 * the operator under test with a partial function that accepts every element, a `take` of
 * `OperationsPerInvocation` elements, and `Sink.ignore`.
 *
 * Results are reported as throughput per second; one benchmark state (and therefore one
 * actor system) is shared for the whole benchmark (`Scope.Benchmark`).
 */
@State(Scope.Benchmark)
@OutputTimeUnit(TimeUnit.SECONDS)
@BenchmarkMode(Array(Mode.Throughput))
@nowarn("msg=deprecated")
class CollectBenchmark {
import CollectBenchmark._

// Dispatcher settings for the benchmark's actor system: the default dispatcher runs on a
// fork-join executor with `parallelism-factor = 1`.
private val config = ConfigFactory.parseString("""
pekko.actor.default-dispatcher {
executor = "fork-join-executor"
fork-join-executor {
parallelism-factor = 1
}
}
""")

// Implicit so that the `run()` calls in the benchmark methods can resolve what they need
// to materialize the graphs (presumably the system-wide materializer — confirm).
private implicit val system: ActorSystem = ActorSystem("CollectBenchmark", config)

/** Terminates the actor system after the benchmark, waiting at most 5 seconds for it to stop. */
@TearDown
def shutdown(): Unit = {
Await.result(system.terminate(), 5.seconds)
}

// Graph under test using the library's `Collect` stage. `Keep.right` keeps the materialized
// value of `Sink.ignore`, which the benchmark method awaits to detect completion.
private val newCollect = Source
.repeat(1)
.via(new Collect({ case elem => elem }))
.take(OperationsPerInvocation)
.toMat(Sink.ignore)(Keep.right)

// Same graph as `newCollect`, but using the locally defined `SimpleCollect` stage as baseline.
private val oldCollect = Source
.repeat(1)
.via(new SimpleCollect({ case elem => elem }))
.take(OperationsPerInvocation)
.toMat(Sink.ignore)(Keep.right)

/**
 * Baseline `collect` stage: emits `pf(elem)` for every element on which `pf` produces a
 * result and skips (pulls past) the others.
 *
 * Kept deliberately in the pattern-matching form so that it can be measured against
 * `Collect`; do not "optimize" it, or the comparison loses its meaning.
 *
 * @param pf partial function applied to each incoming element
 * @tparam In  element type consumed by the stage
 * @tparam Out element type emitted by the stage
 */
private class SimpleCollect[In, Out](pf: PartialFunction[In, Out])
extends GraphStage[FlowShape[In, Out]] {
val in = Inlet[In]("SimpleCollect.in")
val out = Outlet[Out]("SimpleCollect.out")
override val shape = FlowShape(in, out)

// Default attributes of the collect operator plus the source location of the supplied lambda.
override def initialAttributes: Attributes = DefaultAttributes.collect and SourceLocation.forLambda(pf)

def createLogic(inheritedAttributes: Attributes) =
new GraphStageLogic(shape) with InHandler with OutHandler {
// Lazy: the supervision decider is only looked up if an exception actually reaches the
// catch block in `onPush`.
private lazy val decider = inheritedAttributes.mandatoryAttribute[SupervisionStrategy].decider
import Collect.NotApplied

// `NotApplied` is passed as the fallback of `applyOrElse`, so getting it back means `pf`
// did not handle the element.
override def onPush(): Unit =
try {
pf.applyOrElse(grab(in), NotApplied) match {
// Element not handled by `pf`: drop it and request the next one.
case NotApplied => pull(in)
// Element handled: emit the mapped value downstream.
case result: Out @unchecked => push(out, result)
// Not expected to be reached; present only so the match has a catch-all case.
case _ => throw new RuntimeException()
}
} catch {
case NonFatal(ex) =>
// Let the configured supervision strategy decide how to react to the failure.
decider(ex) match {
case Supervision.Stop => failStage(ex)
// Resume and Restart both skip the failing element; the guard avoids pulling an
// inlet that has already been pulled.
case Supervision.Resume => if (!hasBeenPulled(in)) pull(in)
case Supervision.Restart => if (!hasBeenPulled(in)) pull(in)
}
}

// Downstream demand is forwarded upstream unchanged.
override def onPull(): Unit = pull(in)

// This logic instance acts as both the inlet and the outlet handler.
setHandlers(in, out, this)
}

override def toString = "SimpleCollect"
}

/**
 * Runs the `SimpleCollect` graph once and blocks until the stream has completed.
 * `@OperationsPerInvocation` makes JMH count one call as `OperationsPerInvocation` operations.
 */
@Benchmark
@OperationsPerInvocation(OperationsPerInvocation)
def benchOldCollect(): Unit =
Await.result(oldCollect.run(), Duration.Inf)

/**
 * Runs the `Collect` graph once and blocks until the stream has completed.
 * `@OperationsPerInvocation` makes JMH count one call as `OperationsPerInvocation` operations.
 */
@Benchmark
@OperationsPerInvocation(OperationsPerInvocation)
def benchNewCollect(): Unit =
Await.result(newCollect.run(), Duration.Inf)

}
Original file line number Diff line number Diff line change
Expand Up @@ -45,16 +45,26 @@ private[pekko] final class CollectWhile[In, Out](pf: PartialFunction[In, Out]) e

override final def onPush(): Unit =
try {
pf.applyOrElse(grab(in), NotApplied) match {
case NotApplied => completeStage()
case result: Out @unchecked => push(out, result)
case _ => throw new RuntimeException() // won't happen, compiler exhaustiveness check pleaser
// 1. `applyOrElse` is faster than calling `pf.isDefinedAt` and then `pf.apply`.
// 2. Using a reference comparison here instead of pattern matching can generate smaller and faster bytecode,
// e.g. just a simple `IF_ACMPNE`; the same trick is used in the `Collect` operator.
// If you are interested, see the associated PR for this change and the
// current implementation of `scala.collection.IterableOnceOps.collectFirst`.
val result = pf.applyOrElse(grab(in), NotApplied)
if (result.asInstanceOf[AnyRef] eq NotApplied) {
completeStage()
} else {
push(out, result.asInstanceOf[Out])
}
} catch {
case NonFatal(ex) =>
decider(ex) match {
case Supervision.Stop => failStage(ex)
case _ => pull(in)
case _ =>
// The !hasBeenPulled(in) check is not required here since it
// isn't possible to do an additional pull(in) due to the nature
// of how collect works
pull(in)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -258,10 +258,16 @@ private[stream] object Collect {

override def onPush(): Unit =
try {
pf.applyOrElse(grab(in), NotApplied) match {
case NotApplied => pull(in)
case result: Out @unchecked => push(out, result)
case _ => throw new RuntimeException() // won't happen, compiler exhaustiveness check pleaser
val result = pf.applyOrElse(grab(in), NotApplied)
// 1. `applyOrElse` is faster than calling `pf.isDefinedAt` and then `pf.apply`.
// 2. Using a reference comparison here instead of pattern matching can generate smaller and faster bytecode,
// e.g. just a simple `IF_ACMPNE`; the same trick is used in the `CollectWhile` operator.
// If you are interested, see the associated PR for this change and the
// current implementation of `scala.collection.IterableOnceOps.collectFirst`.
if (result.asInstanceOf[AnyRef] eq Collect.NotApplied) {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@mdedetrich I have updated the text; what do you think about it? You can change it directly, as I may not be able to push next time :)

Copy link
Contributor

@mdedetrich mdedetrich Jan 20, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For me it's good; we just need to wait for @pjfanning to re-review this. The code is not formatted, so you will need to re-push, though.

pull(in)
} else {
push(out, result.asInstanceOf[Out])
}
} catch {
case NonFatal(ex) =>
Expand Down