Skip to content

Commit

Permalink
AccumInfo - Reverting test
Browse files Browse the repository at this point in the history
Signed-off-by: Sayed Bilal Bari <sbari@nvidia.com>
  • Loading branch information
sayedbilalbari committed Feb 13, 2025
1 parent 64606bd commit 76b2bff
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,7 @@ package com.nvidia.spark.rapids.tool.analysis
import org.apache.spark.sql.rapids.tool.util.InPlaceMedianArrView.{chooseMidpointPivotInPlace, findMedianInPlace}

// Store (min, median, max, total) for a given metric
case class StatisticsMetrics(
var min: Long,
var med: Long,
var max: Long,
var total: Long)
case class StatisticsMetrics(min: Long, med: Long, var max: Long, var total: Long)

object StatisticsMetrics {
// a static variable used to represent zero-statistics instead of allocating a dummy record
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,13 @@ class AccumInfo(val infoRef: AccumMetaRef) {
parsedUpdateValue.foreach{ value =>
val (total, stats) = stageValuesMap.getOrElse(stageId,
(0L, StatisticsMetrics(value, 0L, value, 0L)))
stats.min = Math.min(stats.min, value)
stats.med = (stats.med * stats.total + value) / ( stats.total + 1)
stats.max = Math.max(stats.max, value)
stats.total = stats.total + 1
stageValuesMap.put(stageId, (total + value, stats))
val newStats = StatisticsMetrics(
Math.min(stats.min, value),
(stats.med * stats.total + value) / ( stats.total + 1),
Math.max(stats.max, value),
stats.total + 1
)
stageValuesMap.put(stageId, (total + value, newStats))
}
}

Expand All @@ -96,18 +98,19 @@ class AccumInfo(val infoRef: AccumMetaRef) {

def calculateAccStats(): StatisticsMetrics = {
val reduced_val = stageValuesMap.values.reduce { (a, b) =>
a._2.min = Math.min(a._2.min, b._2.min)
a._2.med = (a._2.med * a._2.total + b._2.med * b._2.total) / (a._2.total + b._2.total)
a._2.max = Math.max(a._2.max, b._2.max)
a._2.total = a._2.total + b._2.total
(a._1 + b._1, a._2)
(a._1 + b._1,
StatisticsMetrics(
Math.min(a._2.min, b._2.min),
(a._2.med * a._2.total + b._2.med * b._2.total) / (a._2.total + b._2.total),
Math.max(a._2.max, b._2.max),
a._2.total + b._2.total
))
}
StatisticsMetrics(
reduced_val._2.min,
reduced_val._2.med,
reduced_val._2.max,
reduced_val._1
)
reduced_val._1)
}

def calculateAccStatsForStage(stageId: Int): Option[StatisticsMetrics] = {
Expand Down

0 comments on commit 76b2bff

Please sign in to comment.