diff --git a/docs/ppl-lang/PPL-Example-Commands.md b/docs/ppl-lang/PPL-Example-Commands.md index 3efca3205..561b5b27b 100644 --- a/docs/ppl-lang/PPL-Example-Commands.md +++ b/docs/ppl-lang/PPL-Example-Commands.md @@ -122,6 +122,15 @@ Assumptions: `a`, `b`, `c`, `d`, `e` are existing fields in `table` - `source = table | fillnull using a = 101, b = 102` - `source = table | fillnull using a = concat(b, c), d = 2 * pi() * e` +### Flatten +[See additional command details](ppl-flatten-command.md) +Assumptions: `bridges`, `coor` are existing fields in `table`, and the field's types are `struct` or `array>` +- `source = table | flatten bridges` +- `source = table | flatten coor` +- `source = table | flatten bridges | flatten coor` +- `source = table | fields bridges | flatten bridges` +- `source = table | fields country, bridges | flatten bridges | fields country, length | stats avg(length) as avg by country` + ```sql source = table | eval e = eval status_category = case(a >= 200 AND a < 300, 'Success', diff --git a/docs/ppl-lang/README.md b/docs/ppl-lang/README.md index 8d9b86eda..43e9579aa 100644 --- a/docs/ppl-lang/README.md +++ b/docs/ppl-lang/README.md @@ -31,6 +31,8 @@ For additional examples see the next [documentation](PPL-Example-Commands.md). - [`describe command`](PPL-Example-Commands.md/#describe) - [`fillnull command`](ppl-fillnull-command.md) + + - [`flatten command`](ppl-flatten-command.md) - [`eval command`](ppl-eval-command.md) diff --git a/docs/ppl-lang/ppl-flatten-command.md b/docs/ppl-lang/ppl-flatten-command.md new file mode 100644 index 000000000..4c1ae5d0d --- /dev/null +++ b/docs/ppl-lang/ppl-flatten-command.md @@ -0,0 +1,90 @@ +## PPL `flatten` command + +### Description +Using `flatten` command to flatten a field of type: +- `struct` +- `array>` + + +### Syntax +`flatten ` + +* field: to be flattened. The field must be of supported type. + +### Test table +#### Schema +| col\_name | data\_type | +|-----------|-------------------------------------------------| +| \_time | string | +| bridges | array\\> | +| city | string | +| coor | struct\ | +| country | string | +#### Data +| \_time | bridges | city | coor | country | +|---------------------|----------------------------------------------|---------|------------------------|---------------| +| 2024-09-13T12:00:00 | [{801, Tower Bridge}, {928, London Bridge}] | London | {35, 51.5074, -0.1278} | England | +| 2024-09-13T12:00:00 | [{232, Pont Neuf}, {160, Pont Alexandre III}]| Paris | {35, 48.8566, 2.3522} | France | +| 2024-09-13T12:00:00 | [{48, Rialto Bridge}, {11, Bridge of Sighs}] | Venice | {2, 45.4408, 12.3155} | Italy | +| 2024-09-13T12:00:00 | [{516, Charles Bridge}, {343, Legion Bridge}]| Prague | {200, 50.0755, 14.4378}| Czech Republic| +| 2024-09-13T12:00:00 | [{375, Chain Bridge}, {333, Liberty Bridge}] | Budapest| {96, 47.4979, 19.0402} | Hungary | +| 1990-09-13T12:00:00 | NULL | Warsaw | NULL | Poland | + + + +### Example 1: flatten struct +This example shows how to flatten a struct field. +PPL query: + - `source=table | flatten coor` + +| \_time | bridges | city | country | alt | lat | long | +|---------------------|----------------------------------------------|---------|---------------|-----|--------|--------| +| 2024-09-13T12:00:00 | [{801, Tower Bridge}, {928, London Bridge}] | London | England | 35 | 51.5074| -0.1278| +| 2024-09-13T12:00:00 | [{232, Pont Neuf}, {160, Pont Alexandre III}]| Paris | France | 35 | 48.8566| 2.3522 | +| 2024-09-13T12:00:00 | [{48, Rialto Bridge}, {11, Bridge of Sighs}] | Venice | Italy | 2 | 45.4408| 12.3155| +| 2024-09-13T12:00:00 | [{516, Charles Bridge}, {343, Legion Bridge}]| Prague | Czech Republic| 200 | 50.0755| 14.4378| +| 2024-09-13T12:00:00 | [{375, Chain Bridge}, {333, Liberty Bridge}] | Budapest| Hungary | 96 | 47.4979| 19.0402| +| 1990-09-13T12:00:00 | NULL | Warsaw | Poland | NULL| NULL | NULL | + + + +### Example 2: flatten array + +The example shows how to flatten an array of struct fields. + +PPL query: + - `source=table | flatten bridges` + +| \_time | city | coor | country | length | name | +|---------------------|---------|------------------------|---------------|--------|-------------------| +| 2024-09-13T12:00:00 | London | {35, 51.5074, -0.1278} | England | 801 | Tower Bridge | +| 2024-09-13T12:00:00 | London | {35, 51.5074, -0.1278} | England | 928 | London Bridge | +| 2024-09-13T12:00:00 | Paris | {35, 48.8566, 2.3522} | France | 232 | Pont Neuf | +| 2024-09-13T12:00:00 | Paris | {35, 48.8566, 2.3522} | France | 160 | Pont Alexandre III| +| 2024-09-13T12:00:00 | Venice | {2, 45.4408, 12.3155} | Italy | 48 | Rialto Bridge | +| 2024-09-13T12:00:00 | Venice | {2, 45.4408, 12.3155} | Italy | 11 | Bridge of Sighs | +| 2024-09-13T12:00:00 | Prague | {200, 50.0755, 14.4378}| Czech Republic| 516 | Charles Bridge | +| 2024-09-13T12:00:00 | Prague | {200, 50.0755, 14.4378}| Czech Republic| 343 | Legion Bridge | +| 2024-09-13T12:00:00 | Budapest| {96, 47.4979, 19.0402} | Hungary | 375 | Chain Bridge | +| 2024-09-13T12:00:00 | Budapest| {96, 47.4979, 19.0402} | Hungary | 333 | Liberty Bridge | +| 1990-09-13T12:00:00 | Warsaw | NULL | Poland | NULL | NULL | + + +### Example 3: flatten array and struct +This example shows how to flatten multiple fields. +PPL query: + - `source=table | flatten bridges | flatten coor` + +| \_time | city | country | length | name | alt | lat | long | +|---------------------|---------|---------------|--------|-------------------|------|--------|--------| +| 2024-09-13T12:00:00 | London | England | 801 | Tower Bridge | 35 | 51.5074| -0.1278| +| 2024-09-13T12:00:00 | London | England | 928 | London Bridge | 35 | 51.5074| -0.1278| +| 2024-09-13T12:00:00 | Paris | France | 232 | Pont Neuf | 35 | 48.8566| 2.3522 | +| 2024-09-13T12:00:00 | Paris | France | 160 | Pont Alexandre III| 35 | 48.8566| 2.3522 | +| 2024-09-13T12:00:00 | Venice | Italy | 48 | Rialto Bridge | 2 | 45.4408| 12.3155| +| 2024-09-13T12:00:00 | Venice | Italy | 11 | Bridge of Sighs | 2 | 45.4408| 12.3155| +| 2024-09-13T12:00:00 | Prague | Czech Republic| 516 | Charles Bridge | 200 | 50.0755| 14.4378| +| 2024-09-13T12:00:00 | Prague | Czech Republic| 343 | Legion Bridge | 200 | 50.0755| 14.4378| +| 2024-09-13T12:00:00 | Budapest| Hungary | 375 | Chain Bridge | 96 | 47.4979| 19.0402| +| 2024-09-13T12:00:00 | Budapest| Hungary | 333 | Liberty Bridge | 96 | 47.4979| 19.0402| +| 1990-09-13T12:00:00 | Warsaw | Poland | NULL | NULL | NULL | NULL | NULL | \ No newline at end of file diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSuite.scala index c8c902294..079b8fcae 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSuite.scala @@ -5,7 +5,7 @@ package org.opensearch.flint.spark -import java.nio.file.{Files, Paths} +import java.nio.file.{Files, Path, Paths} import java.util.Comparator import java.util.concurrent.{ScheduledExecutorService, ScheduledFuture} @@ -534,6 +534,28 @@ trait FlintSparkSuite extends QueryTest with FlintSuite with OpenSearchSuite wit |""".stripMargin) } + protected def createMultiValueStructTable(testTable: String): Unit = { + // CSV doesn't support struct field + sql(s""" + | CREATE TABLE $testTable + | ( + | int_col INT, + | multi_value Array> + | ) + | USING JSON + |""".stripMargin) + + sql(s""" + | INSERT INTO $testTable + | SELECT /*+ COALESCE(1) */ * + | FROM VALUES + | ( 1, array(STRUCT("1_one", 1), STRUCT(null, 11), STRUCT("1_three", null)) ), + | ( 2, array(STRUCT("2_Monday", 2), null) ), + | ( 3, array(STRUCT("3_third", 3), STRUCT("3_4th", 4)) ), + | ( 4, null ) + |""".stripMargin) + } + protected def createTableIssue112(testTable: String): Unit = { sql(s""" | CREATE TABLE $testTable ( @@ -695,4 +717,100 @@ trait FlintSparkSuite extends QueryTest with FlintSuite with OpenSearchSuite wit | (9, '2001:db8::ff00:12:', true, false) | """.stripMargin) } + + protected def createNestedJsonContentTable(tempFile: Path, testTable: String): Unit = { + val json = + """ + |[ + | { + | "_time": "2024-09-13T12:00:00", + | "bridges": [ + | {"name": "Tower Bridge", "length": 801}, + | {"name": "London Bridge", "length": 928} + | ], + | "city": "London", + | "country": "England", + | "coor": { + | "lat": 51.5074, + | "long": -0.1278, + | "alt": 35 + | } + | }, + | { + | "_time": "2024-09-13T12:00:00", + | "bridges": [ + | {"name": "Pont Neuf", "length": 232}, + | {"name": "Pont Alexandre III", "length": 160} + | ], + | "city": "Paris", + | "country": "France", + | "coor": { + | "lat": 48.8566, + | "long": 2.3522, + | "alt": 35 + | } + | }, + | { + | "_time": "2024-09-13T12:00:00", + | "bridges": [ + | {"name": "Rialto Bridge", "length": 48}, + | {"name": "Bridge of Sighs", "length": 11} + | ], + | "city": "Venice", + | "country": "Italy", + | "coor": { + | "lat": 45.4408, + | "long": 12.3155, + | "alt": 2 + | } + | }, + | { + | "_time": "2024-09-13T12:00:00", + | "bridges": [ + | {"name": "Charles Bridge", "length": 516}, + | {"name": "Legion Bridge", "length": 343} + | ], + | "city": "Prague", + | "country": "Czech Republic", + | "coor": { + | "lat": 50.0755, + | "long": 14.4378, + | "alt": 200 + | } + | }, + | { + | "_time": "2024-09-13T12:00:00", + | "bridges": [ + | {"name": "Chain Bridge", "length": 375}, + | {"name": "Liberty Bridge", "length": 333} + | ], + | "city": "Budapest", + | "country": "Hungary", + | "coor": { + | "lat": 47.4979, + | "long": 19.0402, + | "alt": 96 + | } + | }, + | { + | "_time": "1990-09-13T12:00:00", + | "bridges": null, + | "city": "Warsaw", + | "country": "Poland", + | "coor": null + | } + |] + |""".stripMargin + val tempFile = Files.createTempFile("jsonTestData", ".json") + val absolutPath = tempFile.toAbsolutePath.toString; + Files.write(tempFile, json.getBytes) + sql(s""" + | CREATE TEMPORARY VIEW $testTable + | USING org.apache.spark.sql.json + | OPTIONS ( + | path "$absolutPath", + | multiLine true + | ); + |""".stripMargin) + } } diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFlattenITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFlattenITSuite.scala new file mode 100644 index 000000000..e714a5f7e --- /dev/null +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFlattenITSuite.scala @@ -0,0 +1,350 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.flint.spark.ppl + +import java.nio.file.Files + +import org.opensearch.flint.spark.FlattenGenerator +import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq + +import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.{Alias, EqualTo, GeneratorOuter, Literal, Or} +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.streaming.StreamTest + +class FlintSparkPPLFlattenITSuite + extends QueryTest + with LogicalPlanTestUtils + with FlintPPLSuite + with StreamTest { + + private val testTable = "flint_ppl_test" + private val structNestedTable = "spark_catalog.default.flint_ppl_struct_nested_test" + private val structTable = "spark_catalog.default.flint_ppl_struct_test" + private val multiValueTable = "spark_catalog.default.flint_ppl_multi_value_test" + private val tempFile = Files.createTempFile("jsonTestData", ".json") + + override def beforeAll(): Unit = { + super.beforeAll() + + // Create test table + createNestedJsonContentTable(tempFile, testTable) + createStructNestedTable(structNestedTable) + createStructTable(structTable) + createMultiValueStructTable(multiValueTable) + } + + protected override def afterEach(): Unit = { + super.afterEach() + // Stop all streaming jobs if any + spark.streams.active.foreach { job => + job.stop() + job.awaitTermination() + } + } + + override def afterAll(): Unit = { + super.afterAll() + Files.deleteIfExists(tempFile) + } + + test("flatten for structs") { + val frame = sql(s""" + | source = $testTable + | | where country = 'England' or country = 'Poland' + | | fields coor + | | flatten coor + | """.stripMargin) + + assert(frame.columns.sameElements(Array("alt", "lat", "long"))) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array(Row(35, 51.5074, -0.1278), Row(null, null, null)) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Double](_.getAs[Double](1)) + assert(results.sorted.sameElements(expectedResults.sorted)) + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("flint_ppl_test")) + val filter = Filter( + Or( + EqualTo(UnresolvedAttribute("country"), Literal("England")), + EqualTo(UnresolvedAttribute("country"), Literal("Poland"))), + table) + val projectCoor = Project(Seq(UnresolvedAttribute("coor")), filter) + val flattenCoor = flattenPlanFor("coor", projectCoor) + val expectedPlan = Project(Seq(UnresolvedStar(None)), flattenCoor) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + private def flattenPlanFor(flattenedColumn: String, parentPlan: LogicalPlan): LogicalPlan = { + val flattenGenerator = new FlattenGenerator(UnresolvedAttribute(flattenedColumn)) + val outerGenerator = GeneratorOuter(flattenGenerator) + val generate = Generate(outerGenerator, seq(), outer = true, None, seq(), parentPlan) + val dropSourceColumn = + DataFrameDropColumns(Seq(UnresolvedAttribute(flattenedColumn)), generate) + dropSourceColumn + } + + test("flatten for arrays") { + val frame = sql(s""" + | source = $testTable + | | fields bridges + | | flatten bridges + | """.stripMargin) + + assert(frame.columns.sameElements(Array("length", "name"))) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array( + Row(null, null), + Row(11L, "Bridge of Sighs"), + Row(48L, "Rialto Bridge"), + Row(160L, "Pont Alexandre III"), + Row(232L, "Pont Neuf"), + Row(801L, "Tower Bridge"), + Row(928L, "London Bridge"), + Row(343L, "Legion Bridge"), + Row(516L, "Charles Bridge"), + Row(333L, "Liberty Bridge"), + Row(375L, "Chain Bridge")) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Long](_.getAs[Long](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("flint_ppl_test")) + val projectCoor = Project(Seq(UnresolvedAttribute("bridges")), table) + val flattenBridges = flattenPlanFor("bridges", projectCoor) + val expectedPlan = Project(Seq(UnresolvedStar(None)), flattenBridges) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("flatten for structs and arrays") { + val frame = sql(s""" + | source = $testTable | flatten bridges | flatten coor + | """.stripMargin) + + assert( + frame.columns.sameElements( + Array("_time", "city", "country", "length", "name", "alt", "lat", "long"))) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array( + Row("1990-09-13T12:00:00", "Warsaw", "Poland", null, null, null, null, null), + Row( + "2024-09-13T12:00:00", + "Venice", + "Italy", + 11L, + "Bridge of Sighs", + 2, + 45.4408, + 12.3155), + Row("2024-09-13T12:00:00", "Venice", "Italy", 48L, "Rialto Bridge", 2, 45.4408, 12.3155), + Row( + "2024-09-13T12:00:00", + "Paris", + "France", + 160L, + "Pont Alexandre III", + 35, + 48.8566, + 2.3522), + Row("2024-09-13T12:00:00", "Paris", "France", 232L, "Pont Neuf", 35, 48.8566, 2.3522), + Row( + "2024-09-13T12:00:00", + "London", + "England", + 801L, + "Tower Bridge", + 35, + 51.5074, + -0.1278), + Row( + "2024-09-13T12:00:00", + "London", + "England", + 928L, + "London Bridge", + 35, + 51.5074, + -0.1278), + Row( + "2024-09-13T12:00:00", + "Prague", + "Czech Republic", + 343L, + "Legion Bridge", + 200, + 50.0755, + 14.4378), + Row( + "2024-09-13T12:00:00", + "Prague", + "Czech Republic", + 516L, + "Charles Bridge", + 200, + 50.0755, + 14.4378), + Row( + "2024-09-13T12:00:00", + "Budapest", + "Hungary", + 333L, + "Liberty Bridge", + 96, + 47.4979, + 19.0402), + Row( + "2024-09-13T12:00:00", + "Budapest", + "Hungary", + 375L, + "Chain Bridge", + 96, + 47.4979, + 19.0402)) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Long](_.getAs[Long](3)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("flint_ppl_test")) + val flattenBridges = flattenPlanFor("bridges", table) + val flattenCoor = flattenPlanFor("coor", flattenBridges) + val expectedPlan = Project(Seq(UnresolvedStar(None)), flattenCoor) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test flatten and stats") { + val frame = sql(s""" + | source = $testTable + | | fields country, bridges + | | flatten bridges + | | fields country, length + | | stats avg(length) as avg by country + | """.stripMargin) + + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array( + Row(null, "Poland"), + Row(196d, "France"), + Row(429.5, "Czech Republic"), + Row(864.5, "England"), + Row(29.5, "Italy"), + Row(354.0, "Hungary")) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Double](_.getAs[Double](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("flint_ppl_test")) + val projectCountryBridges = + Project(Seq(UnresolvedAttribute("country"), UnresolvedAttribute("bridges")), table) + val flattenBridges = flattenPlanFor("bridges", projectCountryBridges) + val projectCountryLength = + Project(Seq(UnresolvedAttribute("country"), UnresolvedAttribute("length")), flattenBridges) + val average = Alias( + UnresolvedFunction( + seq("AVG"), + seq(UnresolvedAttribute("length")), + isDistinct = false, + None, + ignoreNulls = false), + "avg")() + val country = Alias(UnresolvedAttribute("country"), "country")() + val grouping = Alias(UnresolvedAttribute("country"), "country")() + val aggregate = Aggregate(Seq(grouping), Seq(average, country), projectCountryLength) + val expectedPlan = Project(Seq(UnresolvedStar(None)), aggregate) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("flatten struct table") { + val frame = sql(s""" + | source = $structTable + | | flatten struct_col + | | flatten field1 + | """.stripMargin) + + assert(frame.columns.sameElements(Array("int_col", "field2", "subfield"))) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array(Row(30, 123, "value1"), Row(40, 456, "value2"), Row(50, 789, "value3")) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Int](_.getAs[Int](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_struct_test")) + val flattenStructCol = flattenPlanFor("struct_col", table) + val flattenField1 = flattenPlanFor("field1", flattenStructCol) + val expectedPlan = Project(Seq(UnresolvedStar(None)), flattenField1) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("flatten struct nested table") { + val frame = sql(s""" + | source = $structNestedTable + | | flatten struct_col + | | flatten field1 + | | flatten struct_col2 + | | flatten field1 + | """.stripMargin) + + assert( + frame.columns.sameElements(Array("int_col", "field2", "subfield", "field2", "subfield"))) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array( + Row(30, 123, "value1", 23, "valueA"), + Row(40, 123, "value5", 33, "valueB"), + Row(30, 823, "value4", 83, "valueC"), + Row(40, 456, "value2", 46, "valueD"), + Row(50, 789, "value3", 89, "valueE")) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Int](_.getAs[Int](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = + UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_struct_nested_test")) + val flattenStructCol = flattenPlanFor("struct_col", table) + val flattenField1 = flattenPlanFor("field1", flattenStructCol) + val flattenStructCol2 = flattenPlanFor("struct_col2", flattenField1) + val flattenField1Again = flattenPlanFor("field1", flattenStructCol2) + val expectedPlan = Project(Seq(UnresolvedStar(None)), flattenField1Again) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("flatten multi value nullable") { + val frame = sql(s""" + | source = $multiValueTable + | | flatten multi_value + | """.stripMargin) + + assert(frame.columns.sameElements(Array("int_col", "name", "value"))) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array( + Row(1, "1_one", 1), + Row(1, null, 11), + Row(1, "1_three", null), + Row(2, "2_Monday", 2), + Row(2, null, null), + Row(3, "3_third", 3), + Row(3, "3_4th", 4), + Row(4, null, null)) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Int](_.getAs[Int](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_multi_value_test")) + val flattenMultiValue = flattenPlanFor("multi_value", table) + val expectedPlan = Project(Seq(UnresolvedStar(None)), flattenMultiValue) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } +} diff --git a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 index bf6989b7c..58d10a560 100644 --- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 +++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 @@ -37,6 +37,7 @@ KMEANS: 'KMEANS'; AD: 'AD'; ML: 'ML'; FILLNULL: 'FILLNULL'; +FLATTEN: 'FLATTEN'; //Native JOIN KEYWORDS JOIN: 'JOIN'; diff --git a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 index aaf807a7b..8bb93567b 100644 --- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 +++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 @@ -53,6 +53,7 @@ commands | renameCommand | fillnullCommand | fieldsummaryCommand + | flattenCommand ; commandName @@ -82,6 +83,7 @@ commandName | RENAME | FILLNULL | FIELDSUMMARY + | FLATTEN ; searchCommand @@ -89,7 +91,7 @@ searchCommand | (SEARCH)? fromClause logicalExpression # searchFromFilter | (SEARCH)? logicalExpression fromClause # searchFilterFrom ; - + fieldsummaryCommand : FIELDSUMMARY (fieldsummaryParameter)* ; @@ -246,6 +248,10 @@ fillnullCommand : expression ; +flattenCommand + : FLATTEN fieldExpression + ; + kmeansCommand : KMEANS (kmeansParameter)* diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java index 03c40fcd2..00db5b675 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java @@ -326,4 +326,8 @@ public T visitWindow(Window node, C context) { public T visitCidr(Cidr node, C context) { return visitChildren(node, context); } + + public T visitFlatten(Flatten flatten, C context) { + return visitChildren(flatten, context); + } } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Flatten.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Flatten.java new file mode 100644 index 000000000..e31fbb6e3 --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Flatten.java @@ -0,0 +1,34 @@ +package org.opensearch.sql.ast.tree; + +import lombok.Getter; +import lombok.RequiredArgsConstructor; +import org.opensearch.sql.ast.AbstractNodeVisitor; +import org.opensearch.sql.ast.Node; +import org.opensearch.sql.ast.expression.Field; + +import java.util.List; + +@RequiredArgsConstructor +public class Flatten extends UnresolvedPlan { + + private UnresolvedPlan child; + + @Getter + private final Field fieldToBeFlattened; + + @Override + public UnresolvedPlan attach(UnresolvedPlan child) { + this.child = child; + return this; + } + + @Override + public List getChild() { + return child == null ? List.of() : List.of(child); + } + + @Override + public T accept(AbstractNodeVisitor nodeVisitor, C context) { + return nodeVisitor.visitFlatten(this, context); + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java index 5d2fe986b..3ad1b95cb 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java @@ -15,6 +15,7 @@ import org.apache.spark.sql.catalyst.expressions.Descending$; import org.apache.spark.sql.catalyst.expressions.Exists$; import org.apache.spark.sql.catalyst.expressions.Expression; +import org.apache.spark.sql.catalyst.expressions.GeneratorOuter; import org.apache.spark.sql.catalyst.expressions.In$; import org.apache.spark.sql.catalyst.expressions.GreaterThanOrEqual; import org.apache.spark.sql.catalyst.expressions.InSubquery$; @@ -33,6 +34,7 @@ import org.apache.spark.sql.execution.command.ExplainCommand; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.util.CaseInsensitiveStringMap; +import org.opensearch.flint.spark.FlattenGenerator; import org.opensearch.sql.ast.AbstractNodeVisitor; import org.opensearch.sql.ast.expression.AggregateFunction; import org.opensearch.sql.ast.expression.Alias; @@ -74,6 +76,7 @@ import org.opensearch.sql.ast.tree.FieldSummary; import org.opensearch.sql.ast.tree.FillNull; import org.opensearch.sql.ast.tree.Filter; +import org.opensearch.sql.ast.tree.Flatten; import org.opensearch.sql.ast.tree.Head; import org.opensearch.sql.ast.tree.Join; import org.opensearch.sql.ast.tree.Kmeans; @@ -98,6 +101,7 @@ import org.opensearch.sql.ppl.utils.ParseTransformer; import org.opensearch.sql.ppl.utils.SortUtils; import org.opensearch.sql.ppl.utils.WindowSpecTransformer; +import scala.None$; import scala.Option; import scala.Tuple2; import scala.collection.IterableLike; @@ -453,6 +457,20 @@ public LogicalPlan visitFillNull(FillNull fillNull, CatalystPlanContext context) return Objects.requireNonNull(resultWithoutDuplicatedColumns, "FillNull operation failed"); } + @Override + public LogicalPlan visitFlatten(Flatten flatten, CatalystPlanContext context) { + flatten.getChild().get(0).accept(this, context); + if (context.getNamedParseExpressions().isEmpty()) { + // Create an UnresolvedStar for all-fields projection + context.getNamedParseExpressions().push(UnresolvedStar$.MODULE$.apply(Option.>empty())); + } + Expression field = visitExpression(flatten.getFieldToBeFlattened(), context); + context.retainAllNamedParseExpressions(p -> (NamedExpression) p); + FlattenGenerator flattenGenerator = new FlattenGenerator(field); + context.apply(p -> new Generate(new GeneratorOuter(flattenGenerator), seq(), true, (Option) None$.MODULE$, seq(), p)); + return context.apply(logicalPlan -> DataFrameDropColumns$.MODULE$.apply(seq(field), logicalPlan)); + } + private void visitFieldList(List fieldList, CatalystPlanContext context) { fieldList.forEach(field -> visitExpression(field, context)); } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java index c69e9541e..09db8b126 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java @@ -562,6 +562,12 @@ public UnresolvedPlan visitFillnullCommand(OpenSearchPPLParser.FillnullCommandCo } } + @Override + public UnresolvedPlan visitFlattenCommand(OpenSearchPPLParser.FlattenCommandContext ctx) { + Field unresolvedExpression = (Field) internalVisitExpression(ctx.fieldExpression()); + return new Flatten(unresolvedExpression); + } + /** AD command. */ @Override public UnresolvedPlan visitAdCommand(OpenSearchPPLParser.AdCommandContext ctx) { diff --git a/ppl-spark-integration/src/main/scala/org/opensearch/flint/spark/FlattenGenerator.scala b/ppl-spark-integration/src/main/scala/org/opensearch/flint/spark/FlattenGenerator.scala new file mode 100644 index 000000000..23b545826 --- /dev/null +++ b/ppl-spark-integration/src/main/scala/org/opensearch/flint/spark/FlattenGenerator.scala @@ -0,0 +1,33 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.flint.spark + +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.expressions.{CollectionGenerator, CreateArray, Expression, GenericInternalRow, Inline, UnaryExpression} +import org.apache.spark.sql.types.{ArrayType, StructType} + +class FlattenGenerator(override val child: Expression) + extends Inline(child) + with CollectionGenerator { + + override def checkInputDataTypes(): TypeCheckResult = child.dataType match { + case st: StructType => TypeCheckResult.TypeCheckSuccess + case _ => super.checkInputDataTypes() + } + + override def elementSchema: StructType = child.dataType match { + case st: StructType => st + case _ => super.elementSchema + } + + override protected def withNewChildInternal(newChild: Expression): FlattenGenerator = { + newChild.dataType match { + case ArrayType(st: StructType, _) => new FlattenGenerator(newChild) + case st: StructType => withNewChildInternal(CreateArray(Seq(newChild), false)) + case _ => + throw new IllegalArgumentException(s"Unexpected input type ${newChild.dataType}") + } + } +} diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanFlattenCommandTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanFlattenCommandTranslatorTestSuite.scala new file mode 100644 index 000000000..58a6c04b3 --- /dev/null +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanFlattenCommandTranslatorTestSuite.scala @@ -0,0 +1,156 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl + +import org.opensearch.flint.spark.FlattenGenerator +import org.opensearch.flint.spark.ppl.PlaneUtils.plan +import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor} +import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq +import org.scalatest.matchers.should.Matchers + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.{Alias, Descending, GeneratorOuter, Literal, NullsLast, RegExpExtract, SortOrder} +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, DataFrameDropColumns, Generate, GlobalLimit, LocalLimit, Project, Sort} +import org.apache.spark.sql.types.IntegerType + +class PPLLogicalPlanFlattenCommandTranslatorTestSuite + extends SparkFunSuite + with PlanTest + with LogicalPlanTestUtils + with Matchers { + + private val planTransformer = new CatalystQueryPlanVisitor() + private val pplParser = new PPLSyntaxParser() + + test("test flatten only field") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan(pplParser, "source=relation | flatten field_with_array"), + context) + + val relation = UnresolvedRelation(Seq("relation")) + val flattenGenerator = new FlattenGenerator(UnresolvedAttribute("field_with_array")) + val outerGenerator = GeneratorOuter(flattenGenerator) + val generate = Generate(outerGenerator, seq(), true, None, seq(), relation) + val dropSourceColumn = + DataFrameDropColumns(Seq(UnresolvedAttribute("field_with_array")), generate) + val expectedPlan = Project(seq(UnresolvedStar(None)), dropSourceColumn) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test flatten and stats") { + val context = new CatalystPlanContext + val query = + "source = relation | fields state, company, employee | flatten employee | fields state, company, salary | stats max(salary) as max by state, company" + val logPlan = + planTransformer.visit(plan(pplParser, query), context) + val table = UnresolvedRelation(Seq("relation")) + val projectStateCompanyEmployee = + Project( + Seq( + UnresolvedAttribute("state"), + UnresolvedAttribute("company"), + UnresolvedAttribute("employee")), + table) + val generate = Generate( + GeneratorOuter(new FlattenGenerator(UnresolvedAttribute("employee"))), + seq(), + true, + None, + seq(), + projectStateCompanyEmployee) + val dropSourceColumn = DataFrameDropColumns(Seq(UnresolvedAttribute("employee")), generate) + val projectStateCompanySalary = Project( + Seq( + UnresolvedAttribute("state"), + UnresolvedAttribute("company"), + UnresolvedAttribute("salary")), + dropSourceColumn) + val average = Alias( + UnresolvedFunction(seq("MAX"), seq(UnresolvedAttribute("salary")), false, None, false), + "max")() + val state = Alias(UnresolvedAttribute("state"), "state")() + val company = Alias(UnresolvedAttribute("company"), "company")() + val groupingState = Alias(UnresolvedAttribute("state"), "state")() + val groupingCompany = Alias(UnresolvedAttribute("company"), "company")() + val aggregate = Aggregate( + Seq(groupingState, groupingCompany), + Seq(average, state, company), + projectStateCompanySalary) + val expectedPlan = Project(Seq(UnresolvedStar(None)), aggregate) + + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test flatten and eval") { + val context = new CatalystPlanContext + val query = "source = relation | flatten employee | eval bonus = salary * 3" + val logPlan = planTransformer.visit(plan(pplParser, query), context) + val table = UnresolvedRelation(Seq("relation")) + val generate = Generate( + GeneratorOuter(new FlattenGenerator(UnresolvedAttribute("employee"))), + seq(), + true, + None, + seq(), + table) + val dropSourceColumn = DataFrameDropColumns(Seq(UnresolvedAttribute("employee")), generate) + val bonusProject = Project( + Seq( + UnresolvedStar(None), + Alias( + UnresolvedFunction( + "*", + Seq(UnresolvedAttribute("salary"), Literal(3, IntegerType)), + isDistinct = false), + "bonus")()), + dropSourceColumn) + val expectedPlan = Project(Seq(UnresolvedStar(None)), bonusProject) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test flatten and parse and flatten") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + "source=relation | flatten employee | parse description '(?.+@.+)' | flatten roles"), + context) + val table = UnresolvedRelation(Seq("relation")) + val generateEmployee = Generate( + GeneratorOuter(new FlattenGenerator(UnresolvedAttribute("employee"))), + seq(), + true, + None, + seq(), + table) + val dropSourceColumnEmployee = + DataFrameDropColumns(Seq(UnresolvedAttribute("employee")), generateEmployee) + val emailAlias = + Alias( + RegExpExtract(UnresolvedAttribute("description"), Literal("(?.+@.+)"), Literal(1)), + "email")() + val parseProject = Project( + Seq(UnresolvedAttribute("description"), emailAlias, UnresolvedStar(None)), + dropSourceColumnEmployee) + val generateRoles = Generate( + GeneratorOuter(new FlattenGenerator(UnresolvedAttribute("roles"))), + seq(), + true, + None, + seq(), + parseProject) + val dropSourceColumnRoles = + DataFrameDropColumns(Seq(UnresolvedAttribute("roles")), generateRoles) + val expectedPlan = Project(Seq(UnresolvedStar(None)), dropSourceColumnRoles) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + +}