From c5706be6a6768866fcf65283683629e6c2abbd41 Mon Sep 17 00:00:00 2001 From: Lantao Jin Date: Wed, 9 Oct 2024 14:07:19 +0800 Subject: [PATCH] Support aggregate functions in Eval expressions Signed-off-by: Lantao Jin --- .../spark/ppl/FlintSparkPPLEvalITSuite.scala | 58 +++++++++- .../src/main/antlr4/OpenSearchPPLParser.g4 | 1 + .../sql/ppl/CatalystQueryPlanVisitor.java | 14 ++- ...PLLogicalPlanEvalTranslatorTestSuite.scala | 107 ++++++++++++++++-- 4 files changed, 168 insertions(+), 12 deletions(-) diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLEvalITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLEvalITSuite.scala index e10b2e2a6..362a79341 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLEvalITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLEvalITSuite.scala @@ -429,6 +429,61 @@ class FlintSparkPPLEvalITSuite comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) } + test("test eval comma separated expressions with stats functions") { + val frame = sql(s""" + | source = $testTable | eval col1 = max(age), col2 = avg(age), col3 = min(age), col4 = sum(age), col5 = count(age) | fields col1, col2, col3, col4, col5 + | """.stripMargin) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = Array(Row(70, 36.25, 20, 145, 4)) + + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Double](_.getAs[Double](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + + val evalProjectList = Seq( + Alias( + UnresolvedFunction("max", Seq(UnresolvedAttribute("age")), isDistinct = false), + "col1")(), + Alias( + UnresolvedFunction("avg", Seq(UnresolvedAttribute("age")), isDistinct = false), + "col2")(), + Alias( + UnresolvedFunction("min", Seq(UnresolvedAttribute("age")), isDistinct = false), + "col3")(), + Alias( + UnresolvedFunction("sum", Seq(UnresolvedAttribute("age")), isDistinct = false), + "col4")(), + Alias( + UnresolvedFunction("count", Seq(UnresolvedAttribute("age")), isDistinct = false), + "col5")()) + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val project = Project(evalProjectList, table) + val expectedPlan = Project( + seq( + UnresolvedAttribute("col1"), + UnresolvedAttribute("col2"), + UnresolvedAttribute("col3"), + UnresolvedAttribute("col4"), + UnresolvedAttribute("col5")), + project) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("eval stats functions adding other field list should throw exception") { + val ex = intercept[AnalysisException](sql(s""" + | source = $testTable | eval col1 = max(age), col2 = avg(age), col3 = min(age), col4 = sum(age), col5 = count(age) | fields age, col1, col2, col3, col4, col5 + | """.stripMargin)) + assert(ex.getMessage().contains("UNRESOLVED_COLUMN")) + } + + test("eval stats functions without fields command should throw exception") { + val ex = intercept[AnalysisException](sql(s""" + | source = $testTable | eval col1 = max(age), col2 = avg(age), col3 = min(age), col4 = sum(age), col5 = count(age) + | """.stripMargin)) + assert(ex.getMessage().contains("MISSING_GROUP_BY")) + } + test("test complex eval expressions with fields command") { val frame = sql(s""" | source = $testTable | eval new_name = upper(name) | eval compound_field = concat('Hello ', if(like(new_name, 'HEL%'), 'World', name)) | fields new_name, compound_field @@ -672,8 +727,7 @@ class FlintSparkPPLEvalITSuite comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) } - // Todo excluded fields not support yet - ignore("test single eval expression with excluded fields") { + test("test single eval expression with excluded fields") { val frame = sql(s""" | source = $testTable | eval new_field = "New Field" | fields - age | """.stripMargin) diff --git a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 index 06b3166f0..ab9b4f982 100644 --- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 +++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 @@ -473,6 +473,7 @@ evalFunctionName | systemFunctionName | positionFunctionName | coalesceFunctionName + | statsFunctionName ; functionArgs diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java index bd1785c85..8eb1dbae7 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java @@ -76,6 +76,7 @@ import org.opensearch.sql.ast.tree.TopAggregation; import org.opensearch.sql.ast.tree.UnresolvedPlan; import org.opensearch.sql.common.antlr.SyntaxCheckException; +import org.opensearch.sql.expression.function.BuiltinFunctionName; import org.opensearch.sql.ppl.utils.AggregatorTranslator; import org.opensearch.sql.ppl.utils.BuiltinFunctionTranslator; import org.opensearch.sql.ppl.utils.ComparatorTransformer; @@ -439,9 +440,16 @@ public LogicalPlan visitEval(Eval node, CatalystPlanContext context) { Alias alias = new Alias(let.getVar().getField().toString(), let.getExpression()); aliases.add(alias); } - if (context.getNamedParseExpressions().isEmpty()) { - // Create an UnresolvedStar for all-fields projection - context.getNamedParseExpressions().push(UnresolvedStar$.MODULE$.apply(Option.>empty())); + long statsFunctionsCount = node.getExpressionList().stream().map(Let::getExpression) + .filter(e -> e instanceof Function).map(f -> ((Function) f).getFuncName()) + .filter(n -> BuiltinFunctionName.ofAggregation(n).isPresent()).count(); + // An eval expression equals to add a projection to existing project list. + // So it must start with an UnresolvedStar except all eval expressions are aggregation functions with no fields command + if (statsFunctionsCount == node.getExpressionList().size() && + context.getProjectedFields().stream().noneMatch(f -> f instanceof AllFields)) { + // do nothing + } else { + context.getNamedParseExpressions().push(UnresolvedStar$.MODULE$.apply(Option.empty())); } List expressionList = visitExpressionList(aliases, context); Seq projectExpressions = context.retainAllNamedParseExpressions(p -> (NamedExpression) p); diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanEvalTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanEvalTranslatorTestSuite.scala index 3e2b3cc30..250b66e58 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanEvalTranslatorTestSuite.scala +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanEvalTranslatorTestSuite.scala @@ -14,7 +14,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} import org.apache.spark.sql.catalyst.expressions.{Alias, Descending, ExprId, Literal, NamedExpression, SortOrder} import org.apache.spark.sql.catalyst.plans.PlanTest -import org.apache.spark.sql.catalyst.plans.logical.{Project, Sort} +import org.apache.spark.sql.catalyst.plans.logical.{DataFrameDropColumns, Project, Sort} class PPLLogicalPlanEvalTranslatorTestSuite extends SparkFunSuite @@ -150,6 +150,96 @@ class PPLLogicalPlanEvalTranslatorTestSuite comparePlans(expectedPlan, logPlan, checkAnalysis = false) } + test("test complex eval expressions - stats function") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + "source=t | eval a = max(l) | eval b = avg(l) | eval c = min(l) | eval d = sum(l) | eval e = count(l) | fields a, b, c, d, e"), + context) + + val evalProjectListA = Seq( + Alias(UnresolvedFunction("max", Seq(UnresolvedAttribute("l")), isDistinct = false), "a")()) + val evalProjectListB = Seq( + Alias(UnresolvedFunction("avg", Seq(UnresolvedAttribute("l")), isDistinct = false), "b")()) + val evalProjectListC = Seq( + Alias(UnresolvedFunction("min", Seq(UnresolvedAttribute("l")), isDistinct = false), "c")()) + val evalProjectListD = Seq( + Alias(UnresolvedFunction("sum", Seq(UnresolvedAttribute("l")), isDistinct = false), "d")()) + val evalProjectListE = Seq( + Alias( + UnresolvedFunction("count", Seq(UnresolvedAttribute("l")), isDistinct = false), + "e")()) + val projectA = Project(evalProjectListA, UnresolvedRelation(Seq("t"))) + val projectB = Project(evalProjectListB, projectA) + val projectC = Project(evalProjectListC, projectB) + val projectD = Project(evalProjectListD, projectC) + val projectE = Project(evalProjectListE, projectD) + val expectedPlan = Project( + seq( + UnresolvedAttribute("a"), + UnresolvedAttribute("b"), + UnresolvedAttribute("c"), + UnresolvedAttribute("d"), + UnresolvedAttribute("e")), + projectE) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test complex eval comma separated expressions - stats function") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + "source=t | eval a = max(l), b = avg(l), c = min(l), d = sum(l), e = count(l) | fields a, b, c, d, e"), + context) + + val evalProjectList = Seq( + Alias(UnresolvedFunction("max", Seq(UnresolvedAttribute("l")), isDistinct = false), "a")(), + Alias(UnresolvedFunction("avg", Seq(UnresolvedAttribute("l")), isDistinct = false), "b")(), + Alias(UnresolvedFunction("min", Seq(UnresolvedAttribute("l")), isDistinct = false), "c")(), + Alias(UnresolvedFunction("sum", Seq(UnresolvedAttribute("l")), isDistinct = false), "d")(), + Alias( + UnresolvedFunction("count", Seq(UnresolvedAttribute("l")), isDistinct = false), + "e")()) + val project = Project(evalProjectList, UnresolvedRelation(Seq("t"))) + val expectedPlan = Project( + seq( + UnresolvedAttribute("a"), + UnresolvedAttribute("b"), + UnresolvedAttribute("c"), + UnresolvedAttribute("d"), + UnresolvedAttribute("e")), + project) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test( + "test complex eval comma separated expressions - stats function - without fields command") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + "source=t | eval a = max(l), b = avg(l), c = min(l), d = sum(l), e = count(l)"), + context) + + val evalProjectList = Seq( + UnresolvedStar(None), + Alias(UnresolvedFunction("max", Seq(UnresolvedAttribute("l")), isDistinct = false), "a")(), + Alias(UnresolvedFunction("avg", Seq(UnresolvedAttribute("l")), isDistinct = false), "b")(), + Alias(UnresolvedFunction("min", Seq(UnresolvedAttribute("l")), isDistinct = false), "c")(), + Alias(UnresolvedFunction("sum", Seq(UnresolvedAttribute("l")), isDistinct = false), "d")(), + Alias( + UnresolvedFunction("count", Seq(UnresolvedAttribute("l")), isDistinct = false), + "e")()) + val project = Project(evalProjectList, UnresolvedRelation(Seq("t"))) + val expectedPlan = Project(Seq(UnresolvedStar(None)), project) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + test("test complex eval expressions - compound function") { val context = new CatalystPlanContext val logPlan = @@ -177,27 +267,30 @@ class PPLLogicalPlanEvalTranslatorTestSuite comparePlans(expectedPlan, logPlan, checkAnalysis = false) } - // Todo fields-excluded command not supported - ignore("test eval expressions with fields-excluded command") { + test("test eval expressions with fields-excluded command") { val context = new CatalystPlanContext val logPlan = planTransformer.visit(plan(pplParser, "source=t | eval a = 1, b = 2 | fields - b"), context) val projectList: Seq[NamedExpression] = Seq(UnresolvedStar(None), Alias(Literal(1), "a")(), Alias(Literal(2), "b")()) - val expectedPlan = Project(projectList, UnresolvedRelation(Seq("t"))) + val expectedPlan = Project( + Seq(UnresolvedStar(None)), + DataFrameDropColumns( + Seq(UnresolvedAttribute("b")), + Project(projectList, UnresolvedRelation(Seq("t"))))) comparePlans(expectedPlan, logPlan, checkAnalysis = false) } - // Todo fields-included command not supported - ignore("test eval expressions with fields-included command") { + test("test eval expressions with fields-included command") { val context = new CatalystPlanContext val logPlan = planTransformer.visit(plan(pplParser, "source=t | eval a = 1, b = 2 | fields + b"), context) val projectList: Seq[NamedExpression] = Seq(UnresolvedStar(None), Alias(Literal(1), "a")(), Alias(Literal(2), "b")()) - val expectedPlan = Project(projectList, UnresolvedRelation(Seq("t"))) + val expectedPlan = + Project(Seq(UnresolvedAttribute("b")), Project(projectList, UnresolvedRelation(Seq("t")))) comparePlans(expectedPlan, logPlan, checkAnalysis = false) } }