Consider nulls order in TopN pushdown (#1187) (#1193)
birdstorm authored and marsishandsome committed Nov 8, 2019
1 parent bf6bb9a commit d9060f6
Showing 7 changed files with 109 additions and 28 deletions.
29 changes: 20 additions & 9 deletions core/src/main/scala/org/apache/spark/sql/TiStrategy.scala
@@ -32,7 +32,7 @@ import com.pingcap.tispark.{BasicExpression, TiConfigConst, TiDBRelation}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.analysis.CleanupAliases
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, _}
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeMap, AttributeSet, Expression, IntegerLiteral, NamedExpression, SortOrder, SubqueryExpression}
import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Attribute, AttributeMap, AttributeSet, Descending, Expression, IntegerLiteral, IsNull, NamedExpression, NullsFirst, NullsLast, SortOrder, SubqueryExpression}
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.plans.logical._
@@ -355,17 +355,28 @@ case class TiStrategy(getOrCreateTiContext: SparkSession => TiContext)(sparkSess
val aliases = AttributeMap(projectList.collect {
case a: Alias => a.toAttribute -> a
})
val refinedSortOrder = sortOrders.map { sortOrder =>
// Order by desc/asc + nulls first/last
//
// 1. Order by asc + nulls first:
// order by col asc nulls first = order by col asc
// 2. Order by desc + nulls first:
// order by col desc nulls first = order by col is null desc, col desc
// 3. Order by asc + nulls last:
// order by col asc nulls last = order by col is null asc, col asc
// 4. Order by desc + nulls last:
// order by col desc nulls last = order by col desc
val refinedSortOrder = sortOrders.flatMap { sortOrder: SortOrder =>
val newSortExpr = sortOrder.child.transformUp {
case a: Attribute => aliases.getOrElse(a, a)
}
val trimedExpr = CleanupAliases.trimNonTopLevelAliases(newSortExpr)
SortOrder(
trimedExpr,
sortOrder.direction,
sortOrder.nullOrdering,
sortOrder.sameOrderExpressions
)
val trimmedExpr = CleanupAliases.trimNonTopLevelAliases(newSortExpr)
val trimmedSortOrder = sortOrder.copy(child = trimmedExpr)
(sortOrder.direction, sortOrder.nullOrdering) match {
case (_ @Ascending, _ @NullsLast) | (_ @Descending, _ @NullsFirst) =>
sortOrder.copy(child = IsNull(trimmedExpr)) :: trimmedSortOrder :: Nil
case _ =>
trimmedSortOrder :: Nil
}
}
if (refinedSortOrder.exists(
order => !TiUtil.isSupportedBasicExpression(order.child, source, blacklist)
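The rewrite above needs only the (direction, null ordering) pair of each sort key: when the requested placement is what a plain ASC or DESC already produces (rules 1 and 4 in the comment), the key is pushed down unchanged; otherwise an extra "col IS NULL" key is prepended in front of it (rules 2 and 3). Below is a minimal, dependency-free sketch of that rule; Dir, NullOrd and SortKey are hypothetical stand-ins for Catalyst's SortDirection, NullOrdering and SortOrder, so this is an illustration of the mapping rather than the production code.

object NullsOrderRewriteSketch {
  sealed trait Dir
  case object Asc extends Dir
  case object Desc extends Dir

  sealed trait NullOrd
  case object First extends NullOrd
  case object Last extends NullOrd

  // A pushed-down sort key: an expression (rendered as SQL text here) plus its direction.
  final case class SortKey(expr: String, dir: Dir)

  // Expand one ORDER BY key into the list of keys that is actually pushed down.
  def rewrite(col: String, dir: Dir, nulls: NullOrd): List[SortKey] =
    (dir, nulls) match {
      // ASC NULLS LAST and DESC NULLS FIRST need an extra key (rules 2 and 3 above):
      // "col IS NULL" is prepended with the same direction as the column itself.
      case (Asc, Last) | (Desc, First) =>
        SortKey(s"$col IS NULL", dir) :: SortKey(col, dir) :: Nil
      // ASC NULLS FIRST and DESC NULLS LAST are what plain ASC/DESC already give
      // (rules 1 and 4 above), so the key passes through unchanged.
      case _ =>
        SortKey(col, dir) :: Nil
    }

  // rewrite("tp_int", Asc, Last)
  //   == List(SortKey("tp_int IS NULL", Asc), SortKey("tp_int", Asc))
  // i.e. ORDER BY tp_int ASC NULLS LAST  ==>  ORDER BY tp_int IS NULL ASC, tp_int ASC
}

This mirrors what the flatMap in the diff does with sortOrder.copy(child = IsNull(trimmedExpr)).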
35 changes: 24 additions & 11 deletions core/src/test/scala/org/apache/spark/sql/BaseTiSparkSuite.scala
@@ -36,15 +36,15 @@ class BaseTiSparkSuite extends QueryTest with SharedSQLContext {

private def tiCatalog = ti.tiCatalog

protected def querySpark(query: String): List[List[Any]] = {
protected def queryViaTiSpark(query: String): List[List[Any]] = {
val df = sql(query)
val schema = df.schema.fields

dfData(df, schema)
}

protected def queryTiDB(query: String): List[List[Any]] = {
val resultSet = tidbStmt.executeQuery(query)
protected def queryTiDBViaJDBC(query: String): List[List[Any]] = {
val resultSet = callWithRetry(tidbStmt.executeQuery(query))
val rsMetaData = resultSet.getMetaData
val retSet = ArrayBuffer.empty[List[Any]]
val retSchema = ArrayBuffer.empty[String]
@@ -135,9 +135,13 @@ class BaseTiSparkSuite extends QueryTest with SharedSQLContext {
}
}

override def beforeAll(): Unit = {
def beforeAllWithoutLoadData(): Unit = {
super.beforeAll()
setLogLevel("WARN")
}

override def beforeAll(): Unit = {
beforeAllWithoutLoadData()
loadTestData()
}

@@ -193,18 +197,27 @@ class BaseTiSparkSuite extends QueryTest with SharedSQLContext {
protected def judge(str: String, skipped: Boolean = false, checkLimit: Boolean = true): Unit =
runTest(str, skipped = skipped, skipJDBC = true, checkLimit = checkLimit)

protected def compSparkWithTiDB(sql: String, checkLimit: Boolean = true): Boolean =
compSqlResult(sql, querySpark(sql), queryTiDB(sql), checkLimit)
protected def compSparkWithTiDB(qSpark: String,
qTiDB: String = null,
checkLimit: Boolean = true): Unit =
if (qTiDB == null) {
compSparkWithTiDB(qSpark, checkLimit)
} else {
runTest(qSpark, rTiDB = queryTiDBViaJDBC(qTiDB), skipJDBC = true, checkLimit = checkLimit)
}

private def compSparkWithTiDB(sql: String, checkLimit: Boolean): Unit =
runTest(sql, skipJDBC = true, checkLimit = checkLimit)

protected def checkSparkResult(sql: String,
result: List[List[Any]],
checkLimit: Boolean = true): Unit =
assert(compSqlResult(sql, querySpark(sql), result, checkLimit))
assert(compSqlResult(sql, queryViaTiSpark(sql), result, checkLimit))

protected def checkSparkResultContains(sql: String,
result: List[Any],
checkLimit: Boolean = true): Unit =
assert(querySpark(sql).exists(x => compSqlResult(sql, List(x), List(result), checkLimit)))
assert(queryViaTiSpark(sql).exists(x => compSqlResult(sql, List(x), List(result), checkLimit)))

protected def explainSpark(str: String, skipped: Boolean = false): Unit =
try {
@@ -335,7 +348,7 @@ class BaseTiSparkSuite extends QueryTest with SharedSQLContext {

if (r1 == null) {
try {
r1 = querySpark(qSpark)
r1 = queryViaTiSpark(qSpark)
} catch {
case e: Throwable => fail(e)
}
@@ -352,7 +365,7 @@

if (!skipJDBC && r2 == null) {
try {
r2 = querySpark(qJDBC)
r2 = queryViaTiSpark(qJDBC)
} catch {
case e: Throwable =>
logger.warn(s"Spark with JDBC failed when executing:$qJDBC", e) // JDBC failed
@@ -362,7 +375,7 @@
if (skipJDBC || !compSqlResult(qSpark, r1, r2, checkLimit)) {
if (!skipTiDB && r3 == null) {
try {
r3 = queryTiDB(qSpark)
r3 = queryTiDBViaJDBC(qSpark)
} catch {
case e: Throwable => logger.warn(s"TiDB failed when executing:$qSpark", e) // TiDB failed
}
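compSparkWithTiDB now takes an optional second statement. Called with only qSpark it keeps the old single-query behaviour, comparing TiSpark's result for that statement against TiDB's; when qTiDB is supplied, TiSpark's result for qSpark is checked against the JDBC result of the differently phrased qTiDB. A short usage sketch follows, with illustrative queries only; the IssueTestSuite change below exercises exactly the two-query form.

// One statement, compared between TiSpark and TiDB as before.
compSparkWithTiDB("select id_dt, tp_int from full_data_type_table order by id_dt limit 2")

// TiSpark syntax on the left, an equivalent TiDB reformulation on the right.
compSparkWithTiDB(
  qSpark = "select tp_int from full_data_type_table order by tp_int asc nulls last limit 2",
  qTiDB = "select tp_int from full_data_type_table order by tp_int is null asc, tp_int asc limit 2"
)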
30 changes: 30 additions & 0 deletions core/src/test/scala/org/apache/spark/sql/IssueTestSuite.scala
@@ -19,6 +19,36 @@ import com.pingcap.tispark.TiConfigConst
import org.apache.spark.sql.functions.{col, sum}

class IssueTestSuite extends BaseTiSparkSuite {

// https://github.com/pingcap/tispark/issues/1186
test("Consider nulls order when performing TopN") {
// table `full_data_type_table` contains a single line of nulls
compSparkWithTiDB(
qSpark =
"select id_dt, tp_int, tp_bigint from full_data_type_table order by tp_int asc nulls last limit 2",
qTiDB =
"select id_dt, tp_int, tp_bigint from full_data_type_table order by tp_int is null asc, tp_int asc limit 2"
)
compSparkWithTiDB(
qSpark =
"select id_dt, tp_int, tp_bigint from full_data_type_table order by tp_int desc nulls first limit 2",
qTiDB =
"select id_dt, tp_int, tp_bigint from full_data_type_table order by tp_int is null desc, tp_int desc limit 2"
)
compSparkWithTiDB(
qSpark =
"select id_dt, tp_int, tp_bigint from full_data_type_table order by tp_int asc nulls first limit 2",
qTiDB =
"select id_dt, tp_int, tp_bigint from full_data_type_table order by tp_int asc limit 2"
)
compSparkWithTiDB(
qSpark =
"select id_dt, tp_int, tp_bigint from full_data_type_table order by tp_int desc nulls last limit 2",
qTiDB =
"select id_dt, tp_int, tp_bigint from full_data_type_table order by tp_int desc limit 2"
)
}

// https://github.com/pingcap/tispark/issues/1161
test("No Match Column") {
tidbStmt.execute("DROP TABLE IF EXISTS t_no_match_column")
29 changes: 28 additions & 1 deletion core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -17,7 +17,7 @@

package org.apache.spark.sql

import java.sql.{Date, Timestamp}
import java.sql.{Date, ResultSet, Timestamp}
import java.text.SimpleDateFormat
import java.util.TimeZone

@@ -407,6 +407,33 @@ abstract class QueryTest extends PlanTest {
s"The physical plan has missing inputs:\n${query.queryExecution.executedPlan}"
)
}

protected def callWithRetry[A](execute: => A): A = {
callWithRetry(execute, retryOnFailure = 3)
}

/**
* Execute a command with retry number = `retryOnFailure`
*
* @param execute command that returns anything
* @param retryOnFailure number of remaining retries before it fails
* @param exception last exception thrown
* @return result of command
*/
protected def callWithRetry[A](execute: => A,
retryOnFailure: Int,
exception: Exception = null): A = {
if (retryOnFailure <= 0) {
fail(exception)
} else
try {
execute
} catch {
case e: Exception =>
logger.info(s"Error occurs when calling with retry, remain retries: $retryOnFailure", e)
callWithRetry(execute, retryOnFailure - 1, e)
}
}
}

object QueryTest {
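callWithRetry takes the command as a by-name parameter, so the expression is re-evaluated on every attempt; once the retry budget runs out, the most recent exception is reported through fail. A hedged usage sketch: tidbStmt is the suite's JDBC statement as above, and query stands for any SQL string.

// Default budget of 3 attempts, which is how queryTiDBViaJDBC wraps its JDBC call above.
val resultSet = callWithRetry(tidbStmt.executeQuery(query))

// The budget can also be passed explicitly for particularly flaky statements.
val relaxedResultSet = callWithRetry(tidbStmt.executeQuery(query), retryOnFailure = 5)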
@@ -80,7 +80,7 @@ class TPCDSQuerySuite extends BaseTiSparkSuite {
}
}

val res = querySpark(queryString)
val res = queryViaTiSpark(queryString)
println(s"TiSpark finished $q")
}
} catch {
@@ -79,7 +79,7 @@ class TPCHQuerySuite extends BaseTiSparkSuite {
}
case _ =>
}
val res = querySpark(queryString)
val res = queryViaTiSpark(queryString)
println(s"TiSpark finished $name")
res
} catch {
@@ -100,7 +100,7 @@
throw new AssertionError("JDBC plan should not use CoprocessorRDD as data source node!")
case _ =>
}
val res = querySpark(queryString)
val res = queryViaTiSpark(queryString)
println(s"Spark JDBC finished $name")
res
} catch {
@@ -100,7 +100,7 @@ class TxnTestSuite extends BaseTiSparkSuite {
test("resolveLock concurrent test") {
ti.tiConf.setIsolationLevel(IsolationLevel.SI)

val start = querySpark(sumString).head.head
val start = queryViaTiSpark(sumString).head.head

val threads =
scala.util.Random.shuffle(
@@ -109,7 +109,7 @@
i / 100 match {
case 0 =>
doThread(i, () => {
querySpark(q1String)
queryViaTiSpark(q1String)
})
case 1 =>
doThread(
@@ -129,7 +129,7 @@
(i - 200) / 20 match {
case 0 =>
doThread(i, () => {
querySpark(q2String)
queryViaTiSpark(q2String)
})
case 1 =>
doThread(
@@ -168,7 +168,7 @@
t.join()
}

val end = querySpark(sumString).head.head
val end = queryViaTiSpark(sumString).head.head
if (start != end) {
fail(s"""Failed With
| error transaction