Skip to content

Commit

Permalink
perf: Micro optimization for collect* operator.
Browse files Browse the repository at this point in the history
  • Loading branch information
He-Pin committed Jan 20, 2024
1 parent 637d72a commit 29774b8
Show file tree
Hide file tree
Showing 3 changed files with 142 additions and 8 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

/*
* Copyright (C) 2014-2022 Lightbend Inc. <https://www.lightbend.com>
*/

package org.apache.pekko.stream

import com.typesafe.config.ConfigFactory
import org.apache.pekko
import org.apache.pekko.stream.ActorAttributes.SupervisionStrategy
import org.apache.pekko.stream.Attributes.SourceLocation
import org.apache.pekko.stream.impl.Stages.DefaultAttributes
import org.apache.pekko.stream.impl.fusing.Collect
import org.apache.pekko.stream.stage.{ GraphStage, GraphStageLogic, InHandler, OutHandler }
import org.openjdk.jmh.annotations._
import pekko.actor.ActorSystem
import pekko.stream.scaladsl._

import java.util.concurrent.TimeUnit
import scala.annotation.nowarn
import scala.concurrent._
import scala.concurrent.duration._
import scala.util.control.NonFatal

/** Constants shared by the `CollectBenchmark` class and its JMH annotations. */
object CollectBenchmark {
// Number of elements each benchmark invocation pushes through the stream (see the `.take(...)` calls).
// It is also passed to `@OperationsPerInvocation`, so it has to stay a compile-time constant:
// keep it a `final val` initialised with a literal and without an explicit type annotation.
final val OperationsPerInvocation = 10000000
}

/**
 * JMH throughput benchmark comparing two implementations of the `collect` operator:
 *
 *  - `Collect`: the current implementation from `org.apache.pekko.stream.impl.fusing`
 *  - `SimpleCollect`: a local copy of the previous, pattern-matching based implementation,
 *    kept here only as the baseline
 *
 * Both graphs push `OperationsPerInvocation` elements from `Source.repeat(1)` through the
 * operator under test into `Sink.ignore`, and every benchmark invocation blocks until the
 * materialized future of the sink completes.
 */
@State(Scope.Benchmark)
@OutputTimeUnit(TimeUnit.SECONDS)
@BenchmarkMode(Array(Mode.Throughput))
@nowarn("msg=deprecated")
class CollectBenchmark {
  import CollectBenchmark._

  // Pekko reads its settings from the `pekko` config root. A block placed under `akka` is
  // silently ignored, which would leave the default dispatcher configuration untouched.
  private val config = ConfigFactory.parseString("""
    pekko.actor.default-dispatcher {
      executor = "fork-join-executor"
      fork-join-executor {
        parallelism-factor = 1
      }
    }
    """)

  // Implicit so that `run()` below can materialize the graphs with this system.
  private implicit val system: ActorSystem = ActorSystem("CollectBenchmark", config)

  /** Terminates the actor system once the benchmark is done, waiting at most 5 seconds. */
  @TearDown
  def shutdown(): Unit = {
    Await.result(system.terminate(), 5.seconds)
  }

  // Graph under test: the current `Collect` stage with a partial function that accepts every element.
  private val newCollect = Source
    .repeat(1)
    .via(new Collect({ case elem => elem }))
    .take(OperationsPerInvocation)
    .toMat(Sink.ignore)(Keep.right)

  // Baseline graph: identical shape, but using the local `SimpleCollect` stage.
  private val oldCollect = Source
    .repeat(1)
    .via(new SimpleCollect({ case elem => elem }))
    .take(OperationsPerInvocation)
    .toMat(Sink.ignore)(Keep.right)

  /**
   * Baseline `collect` stage that inspects the result of `pf.applyOrElse` with a pattern match.
   *
   * The logic is deliberately left in its original form: it is the reference point that
   * `Collect` is measured against, so it must not be "optimized".
   *
   * @param pf partial function applied to every incoming element; elements for which it is
   *           not defined are dropped and the next element is requested from upstream
   */
  private class SimpleCollect[In, Out](pf: PartialFunction[In, Out])
      extends GraphStage[FlowShape[In, Out]] {
    val in = Inlet[In]("SimpleCollect.in")
    val out = Outlet[Out]("SimpleCollect.out")
    override val shape = FlowShape(in, out)

    override def initialAttributes: Attributes = DefaultAttributes.collect and SourceLocation.forLambda(pf)

    override def createLogic(inheritedAttributes: Attributes): GraphStageLogic =
      new GraphStageLogic(shape) with InHandler with OutHandler {
        // Looked up lazily: the supervision decider is only needed when `pf` throws.
        private lazy val decider = inheritedAttributes.mandatoryAttribute[SupervisionStrategy].decider
        import Collect.NotApplied

        override def onPush(): Unit =
          try {
            // `NotApplied` is the fallback handed to `applyOrElse`; getting it back means
            // `pf` was not defined for the grabbed element.
            pf.applyOrElse(grab(in), NotApplied) match {
              case NotApplied             => pull(in)
              case result: Out @unchecked => push(out, result)
              case _                      => throw new RuntimeException()
            }
          } catch {
            case NonFatal(ex) =>
              decider(ex) match {
                case Supervision.Stop    => failStage(ex)
                case Supervision.Resume  => if (!hasBeenPulled(in)) pull(in)
                case Supervision.Restart => if (!hasBeenPulled(in)) pull(in)
              }
          }

        override def onPull(): Unit = pull(in)

        setHandlers(in, out, this)
      }

    override def toString = "SimpleCollect"
  }

  /** Runs the baseline graph (`SimpleCollect`) to completion. */
  @Benchmark
  @OperationsPerInvocation(OperationsPerInvocation)
  def benchOldCollect(): Unit =
    Await.result(oldCollect.run(), Duration.Inf)

  /** Runs the graph using the current `Collect` stage to completion. */
  @Benchmark
  @OperationsPerInvocation(OperationsPerInvocation)
  def benchNewCollect(): Unit =
    Await.result(newCollect.run(), Duration.Inf)

}
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,11 @@ private[pekko] final class CollectWhile[In, Out](pf: PartialFunction[In, Out]) e

override final def onPush(): Unit =
try {
pf.applyOrElse(grab(in), NotApplied) match {
case NotApplied => completeStage()
case result: Out @unchecked => push(out, result)
case _ => throw new RuntimeException() // won't happen, compiler exhaustiveness check pleaser
val result = pf.applyOrElse(grab(in), NotApplied)
if (result.asInstanceOf[AnyRef] eq NotApplied) {
completeStage()
} else {
push(out, result.asInstanceOf[Out])
}
} catch {
case NonFatal(ex) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -258,10 +258,16 @@ private[stream] object Collect {

override def onPush(): Unit =
try {
pf.applyOrElse(grab(in), NotApplied) match {
case NotApplied => pull(in)
case result: Out @unchecked => push(out, result)
case _ => throw new RuntimeException() // won't happen, compiler exhaustiveness check pleaser
val result = pf.applyOrElse(grab(in), NotApplied)
// 1. `applyOrElse` is faster than (`pf.isDefinedAt` and then `pf.apply`)
// 2. using a reference comparison here instead of pattern matching generates less and faster bytecode,
// e.g. just a simple `IF_ACMPNE`; the same trick is used in the `CollectWhile` operator.
// If you are interested, see the associated PR for this change and the
// current implementation of `scala.collection.IterableOnceOps.collectFirst`.
if (result.asInstanceOf[AnyRef] eq Collect.NotApplied) {
pull(in)
} else {
push(out, result.asInstanceOf[Out])
}
} catch {
case NonFatal(ex) =>
Expand Down

0 comments on commit 29774b8

Please sign in to comment.