Calculate task metric aggregates on-the-fly to reduce memory usage #1543

Merged · 16 commits · Feb 21, 2025
Changes from 4 commits
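
The change below drops the per-task accumulator lookups (getStageTaskIds / filterAccumTaskUpdatesForStage) in favor of per-stage aggregates that AccumInfo folds in as task updates arrive, so the tool no longer retains every task's value. A minimal sketch of the idea in Scala — the names RunningAgg and StageAggregator are hypothetical, not the actual AccumInfo internals:

// Hypothetical sketch of on-the-fly per-stage aggregation; not the real
// AccumInfo implementation, just the memory-saving idea behind the PR.
import scala.collection.mutable

case class RunningAgg(min: Long, max: Long, sum: Long, count: Long) {
  // Fold one task update into the running aggregate: O(1) work, O(1) state.
  def add(v: Long): RunningAgg =
    RunningAgg(math.min(min, v), math.max(max, v), sum + v, count + 1)
}

object RunningAgg {
  val empty: RunningAgg = RunningAgg(Long.MaxValue, Long.MinValue, 0L, 0L)
}

class StageAggregator {
  private val perStage = mutable.HashMap.empty[Int, RunningAgg]

  // Called once per task update; memory stays proportional to the number
  // of stages rather than the number of tasks.
  def onTaskUpdate(stageId: Int, value: Long): Unit =
    perStage.update(stageId, perStage.getOrElse(stageId, RunningAgg.empty).add(value))

  def statsForStage(stageId: Int): Option[RunningAgg] = perStage.get(stageId)
}

One caveat the sketch makes visible: a pure running fold yields min/max/sum but not an exact median, so the real code must either retain enough per-stage data for a median or approximate it.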
@@ -29,7 +29,7 @@ import org.apache.spark.sql.execution.ui.{SparkPlanGraph, SparkPlanGraphCluster,
 import org.apache.spark.sql.rapids.tool.{AppBase, RDDCheckHelper, SqlPlanInfoGraphBuffer, SqlPlanInfoGraphEntry}
 import org.apache.spark.sql.rapids.tool.profiling.ApplicationInfo
 import org.apache.spark.sql.rapids.tool.qualification.QualificationAppInfo
-import org.apache.spark.sql.rapids.tool.store.{AccumInfo, DataSourceRecord}
+import org.apache.spark.sql.rapids.tool.store.DataSourceRecord
 import org.apache.spark.sql.rapids.tool.util.ToolsPlanGraph

 /**
@@ -114,9 +114,9 @@ class AppSQLPlanAnalyzer(app: AppBase, appIndex: Int) extends AppAnalysisBase(ap
    * @param stageId The ID of the stage.
    * @return A seq of task IDs corresponding to the given stage ID.
    */
-  private def getStageTaskIds(stageId: Int): Seq[Long] = {
-    app.taskManager.getAllTasksStageAttempt(stageId).map(_.taskId)(breakOut).distinct
-  }
+  // private def getStageTaskIds(stageId: Int): Seq[Long] = {
+  //   app.taskManager.getAllTasksStageAttempt(stageId).map(_.taskId)(breakOut).distinct
+  // }

   /**
    * Retrieves task update values from the accumulator info for the specified stage ID.
@@ -126,13 +126,13 @@ class AppSQLPlanAnalyzer(app: AppBase, appIndex: Int) extends AppAnalysisBase(ap
    * @return An array of task update values (`Long`) corresponding to the tasks
    *         in the specified stage.
    */
-  private def filterAccumTaskUpdatesForStage(accumInfo: AccumInfo, stageTaskIds: Seq[Long])
-    : Array[Long] = {
-    stageTaskIds.collect {
-      case taskId if accumInfo.taskUpdatesMap.contains(taskId) =>
-        accumInfo.taskUpdatesMap(taskId)
-    }(breakOut)
-  }
+  // private def filterAccumTaskUpdatesForStage(accumInfo: AccumInfo, stageTaskIds: Seq[Long])
+  //   : Array[Long] = {
+  //   stageTaskIds.collect {
+  //     case taskId if accumInfo.taskUpdatesMap.contains(taskId) =>
+  //       accumInfo.taskUpdatesMap(taskId)
+  //   }(breakOut)
+  // }

   /**
    * Connects Operators to Stages using AccumulatorIDs.
@@ -406,7 +406,7 @@ class AppSQLPlanAnalyzer(app: AppBase, appIndex: Int) extends AppAnalysisBase(ap
       // TODO: currently if stage IDs is empty, the result is skipped
       val stageIds = sqlAccums.head.stageIds
       stageIds.flatMap { stageId =>
-        val stageTaskIds = getStageTaskIds(stageId)
+        // val stageTaskIds = getStageTaskIds(stageId)
         val nodeName = sqlAccums.head.nodeName

         // Initialize a map to store statistics for each IO metric
@@ -426,8 +426,9 @@ class AppSQLPlanAnalyzer(app: AppBase, appIndex: Int) extends AppAnalysisBase(ap
             Some(StatisticsMetrics(sqlAccum.min, sqlAccum.median, sqlAccum.max, sqlAccum.total))
           } else {
             // Retrieve task updates which correspond to the current stage
-            val filteredTaskUpdates = filterAccumTaskUpdatesForStage(accumInfo, stageTaskIds)
-            StatisticsMetrics.createOptionalFromArr(filteredTaskUpdates)
+            // val filteredTaskUpdates = filterAccumTaskUpdatesForStage(accumInfo, stageTaskIds)
+            // StatisticsMetrics.createOptionalFromArr(filteredTaskUpdates)
+            accumInfo.calculateAccStatsForStage(stageId)
           }
         }

@@ -478,11 +479,11 @@ class AppSQLPlanAnalyzer(app: AppBase, appIndex: Int) extends AppAnalysisBase(ap
       val accumInfo = accumMapEntry._2
       accumInfo.stageValuesMap.keys.flatMap( stageId => {
         // Retrieve the task updates that correspond to the current stage
-        val filteredTaskUpdates =
-          filterAccumTaskUpdatesForStage(accumInfo, getStageTaskIds(stageId))
+        // val filteredTaskUpdates =
+        //   filterAccumTaskUpdatesForStage(accumInfo, getStageTaskIds(stageId))

         // Get the task updates that belong to that stage
-        StatisticsMetrics.createOptionalFromArr(filteredTaskUpdates) match {
+        accumInfo.calculateAccStatsForStage(stageId) match {
           case Some(stat) =>
             // Reuse AccumProfileResults to avoid allocating new objects
             val accumProfileResults = AccumProfileResults(
@@ -499,6 +500,23 @@ class AppSQLPlanAnalyzer(app: AppBase, appIndex: Int) extends AppAnalysisBase(ap
             Some(accumProfileResults)
           case _ => None
         }
+        // StatisticsMetrics.createOptionalFromArr(filteredTaskUpdates) match {
+        //   case Some(stat) =>
+        //     // Reuse AccumProfileResults to avoid allocating new objects
+        //     val accumProfileResults = AccumProfileResults(
+        //       appIndex,
+        //       stageId,
+        //       accumInfo.infoRef,
+        //       min = stat.min,
+        //       median = stat.med,
+        //       max = stat.max,
+        //       total = stat.total)
+        //     if (isDiagnosticMetrics(accumInfo.infoRef.name.value)) {
+        //       updateStageDiagnosticMetrics(accumProfileResults)
+        //     }
+        //     Some(accumProfileResults)
+        //   case _ => None
+        // }
       })
     }(breakOut)
   }
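
Both call sites above replace the filter-then-aggregate pair with a single accumInfo.calculateAccStatsForStage(stageId). That method's body is not part of this diff; a hedged sketch of the shape it could take, assuming per-stage running aggregates like the StageAggregator sketch above:

// Hypothetical only — the real AccumInfo.calculateAccStatsForStage is not
// shown in this diff. Assumes a perStage map of RunningAgg from the sketch.
def calculateAccStatsForStage(stageId: Int): Option[StatisticsMetrics] =
  perStage.get(stageId).map { agg =>
    // A running fold cannot give an exact median; the midpoint below is a
    // placeholder estimate, not necessarily what the tool actually computes.
    StatisticsMetrics(agg.min, (agg.min + agg.max) / 2, agg.max, agg.sum)
  }
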
@@ -399,15 +399,21 @@ class AppSparkMetricsAnalyzer(app: AppBase) extends AppAnalysisBase(app) {

       val accumHelperObj = if (app.isPhoton) { // If this is a Photon app, use the photonHelper
         // For max peak memory, we need to look at the accumulators at the task level.
-        val peakMemoryValues = tasksInStage.flatMap { taskModel =>
-          photonPeakMemoryAccumInfos.flatMap { accumInfo =>
-            accumInfo.taskUpdatesMap.get(taskModel.taskId)
-          }
-        }
+        // val peakMemoryValues = tasksInStage.flatMap { taskModel =>
+        //   photonPeakMemoryAccumInfos.flatMap { accumInfo =>
+        //     accumInfo.taskUpdatesMap.get(taskModel.taskId)
+        //   }
+        // }
+        val peakMemoryValues = photonPeakMemoryAccumInfos.flatMap { accumInfo =>
+          accumInfo.stageValuesMap.get(sm.stageInfo.stageId)
+        }.map(_._1)
         // For sum of shuffle write time, we need to look at the accumulators at the stage level.
+        // val shuffleWriteValues = photonShuffleWriteTimeAccumInfos.flatMap { accumInfo =>
+        //   accumInfo.stageValuesMap.get(sm.stageInfo.stageId)
+        // }
         val shuffleWriteValues = photonShuffleWriteTimeAccumInfos.flatMap { accumInfo =>
           accumInfo.stageValuesMap.get(sm.stageInfo.stageId)
-        }
+        }.map(_._1)
         new AggAccumPhotonHelper(shuffleWriteValues, peakMemoryValues)
       } else {
         // For non-Photon apps, use the task metrics directly.
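
In the Photon branch above, peak memory is now read from accumInfo.stageValuesMap instead of walking taskUpdatesMap per task, and the trailing .map(_._1) suggests the map's values became a tuple of aggregates. A speculative sketch of that layout — the tuple's field order is an assumption, not confirmed by the diff:

// Assumed shape: stageId -> (max, total). The actual tuple contents of
// AccumInfo.stageValuesMap are not visible in this diff.
import scala.collection.mutable

val stageValuesMap = mutable.HashMap.empty[Int, (Long, Long)]

def addTaskUpdate(stageId: Int, update: Long): Unit = {
  val (curMax, curTotal) = stageValuesMap.getOrElse(stageId, (Long.MinValue, 0L))
  stageValuesMap.update(stageId, (math.max(curMax, update), curTotal + update))
}

// Mirrors the diff's access pattern: first tuple element for a stage, if any.
def firstValueForStage(stageId: Int): Option[Long] =
  stageValuesMap.get(stageId).map(_._1)
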
@@ -19,7 +19,7 @@ package com.nvidia.spark.rapids.tool.analysis
 import org.apache.spark.sql.rapids.tool.util.InPlaceMedianArrView.{chooseMidpointPivotInPlace, findMedianInPlace}

 // Store (min, median, max, total) for a given metric
-case class StatisticsMetrics(min: Long, med: Long, max: Long, total: Long)
+case class StatisticsMetrics(min: Long, med: Long, var max: Long, var total: Long)

 object StatisticsMetrics {
   // a static variable used to represent zero-statistics instead of allocating a dummy record
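
Making max and total mutable (while min and med stay immutable) presumably lets aggregation code fold new values into an existing StatisticsMetrics in place rather than allocating a copy per update — an assumed rationale, not one stated in the diff:

// Illustration of the in-place update the vars enable (assumed usage):
val stats = StatisticsMetrics(min = 5L, med = 7L, max = 9L, total = 21L)
stats.max = math.max(stats.max, 12L)  // mutate in place, no allocation
stats.total += 12L
// The all-val alternative allocates a fresh record on every update:
// val updated = stats.copy(max = math.max(stats.max, 12L), total = stats.total + 12L)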