Adds `functions.explode`, the `ArrayOptType`/`ArrayType` column types, and an `Encoder` for `Seq[A]`.

Review notes:
- Indentation and blank context lines were reconstructed from a whitespace-collapsed copy of this
  patch; verify them against the target tree before applying.
- The `Seq` encoder's `decode` now maps the element decoder over the array value. It previously
  wrapped the decoder function itself in a one-element `Seq` and never looked at `value`.
- The `index` hashes for explode.scala and Encoder.scala are carried over from the original patch
  and do not reflect the revised content.

diff --git a/src/main/api/api.scala b/src/main/api/api.scala
index 13f0be3..edb526e 100644
--- a/src/main/api/api.scala
+++ b/src/main/api/api.scala
@@ -28,7 +28,7 @@ export org.virtuslab.iskra.$
 export org.virtuslab.iskra.{Column, DataFrame, UntypedColumn, UntypedDataFrame, :=, /}
 
 object functions:
-  export org.virtuslab.iskra.functions.{lit, when}
+  export org.virtuslab.iskra.functions.{explode, lit, when}
   export org.virtuslab.iskra.functions.Aggregates.*
 
 export org.apache.spark.sql.SparkSession
diff --git a/src/main/functions/explode.scala b/src/main/functions/explode.scala
new file mode 100644
index 0000000..ac9ed23
--- /dev/null
+++ b/src/main/functions/explode.scala
@@ -0,0 +1,9 @@
+package org.virtuslab.iskra.functions
+
+import org.apache.spark.sql
+import org.virtuslab.iskra.Column
+import org.virtuslab.iskra.types.{ ArrayOptType, DataType }
+
+/** Typed wrapper over Spark's `explode`: turns an array column into a column of its element type. */
+def explode[T <: DataType](c: Column[ArrayOptType[T]]): Column[T] = Column(sql.functions.explode(c.untyped))
+
diff --git a/src/main/types/DataType.scala b/src/main/types/DataType.scala
index 1593d7d..86d141e 100644
--- a/src/main/types/DataType.scala
+++ b/src/main/types/DataType.scala
@@ -21,6 +21,7 @@ object DataType:
     case FloatOptType => FloatOptType
     case DoubleOptType => DoubleOptType
     case StructOptType[schema] => StructOptType[schema]
+    case ArrayOptType[schema] => ArrayOptType[schema]
 
   type NonNullable[T <: DataType] <: DataType = T match
     case BooleanOptType => BooleanType
@@ -32,6 +33,7 @@ object DataType:
     case FloatOptType => FloatOptType
     case DoubleOptType => DoubleOptType
     case StructOptType[schema] => StructOptType[schema]
+    case ArrayOptType[schema] => ArrayType[schema]
 
   type CommonNumericNullableType[T1 <: DataType, T2 <: DataType] <: NumericOptType = (T1, T2) match
     case (DoubleOptType, _) | (_, DoubleOptType) => DoubleOptType
@@ -77,3 +79,6 @@ final class DoubleType extends DoubleOptType, NotNull
 
 sealed class StructOptType[Schema <: Tuple] extends DataType
 final class StructType[Schema <: Tuple] extends StructOptType[Schema], NotNull
+
+sealed class ArrayOptType[T <: DataType] extends DataType
+final class ArrayType[T <: DataType] extends ArrayOptType[T], NotNull
diff --git a/src/main/types/Encoder.scala b/src/main/types/Encoder.scala
index b8cff2c..9ad45f9 100644
--- a/src/main/types/Encoder.scala
+++ b/src/main/types/Encoder.scala
@@ -86,6 +86,19 @@ object Encoder:
     type ColumnType = DoubleOptType
     def catalystType = sql.types.DoubleType
 
+  /** Encoder for `Seq[A]`: a nullable Spark array column whose elements go through `A`'s encoder. */
+  inline given arrayFromMirror[A](using encoder: Encoder[A]): (Encoder[Seq[A]] { type ColumnType = ArrayOptType[encoder.ColumnType] }) =
+    new Encoder[Seq[A]]:
+      override type ColumnType = ArrayOptType[encoder.ColumnType]
+      // A null sequence is written as an empty array.
+      override def encode(value: Seq[A]): Any = if (value == null) Seq() else value.map(encoder.encode)
+      // Decode element by element; a null array is read back as an empty Seq, mirroring `encode`.
+      override def decode(value: Any): Any = value match
+        case null => Seq.empty
+        case elems: scala.collection.Seq[?] => elems.map(encoder.decode).toSeq
+      override def catalystType = sql.types.ArrayType(encoder.catalystType)
+      override def isNullable = true
+
   export StructEncoder.{fromMirror, optFromMirror}
 
 trait StructEncoder[-A] extends Encoder[A]:
diff --git a/src/test/ExplodeTest.scala b/src/test/ExplodeTest.scala
new file mode 100644
index 0000000..23e2cff
--- /dev/null
+++ b/src/test/ExplodeTest.scala
@@ -0,0 +1,24 @@
+package org.virtuslab.iskra.test
+
+class ExplodeTest extends SparkUnitTest:
+  import org.virtuslab.iskra.api.*
+  import functions.explode
+
+  case class Foo(ints: Seq[Int])
+
+  val foos = Seq(
+    Foo(Seq(1)),
+    Foo(Seq(2)),
+    Foo(Seq()),
+    Foo(null),
+    Foo(Seq(3,4))
+  ).toTypedDF
+
+  test("explode") {
+    val result = foos
+      .select(explode($.ints).as("int"))
+      .collectAs[Int]
+
+    result shouldEqual Seq(1,2,3,4)
+  }
+