From fd2b51ed596490882c8ba9728cf3d7efad214e98 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Pa=C5=82ka?= Date: Mon, 8 Jul 2024 18:04:37 +0200 Subject: [PATCH] Clean up named columns abstraction --- src/main/CollectColumns.scala | 28 ++++++------ src/main/Column.scala | 86 ++++++++++++++++++----------------- src/main/FrameSchema.scala | 6 +-- src/main/Grouping.scala | 4 +- src/main/SchemaView.scala | 5 -- src/main/Select.scala | 2 +- src/main/WithColumns.scala | 2 +- src/test/ColumnsTest.scala | 24 ++++++++++ 8 files changed, 89 insertions(+), 68 deletions(-) create mode 100644 src/test/ColumnsTest.scala diff --git a/src/main/CollectColumns.scala b/src/main/CollectColumns.scala index cf24729..e18250e 100644 --- a/src/main/CollectColumns.scala +++ b/src/main/CollectColumns.scala @@ -2,6 +2,8 @@ package org.virtuslab.iskra import scala.compiletime.error +import org.virtuslab.iskra.types.DataType + // TODO should it be covariant or not? trait CollectColumns[-C]: type CollectedColumns <: Tuple @@ -9,27 +11,25 @@ trait CollectColumns[-C]: // Using `given ... with { ... }` syntax might sometimes break pattern match on `CollectColumns[...] { type CollectedColumns = cc }` -object CollectColumns extends CollectColumnsLowPrio: - given collectSingle[S <: Tuple]: CollectColumns[NamedColumns[S]] with +object CollectColumns: + given collectNamedColumn[N <: Name, T <: DataType]: CollectColumns[NamedColumn[N, T]] with + type CollectedColumns = (N := T) *: EmptyTuple + def underlyingColumns(c: NamedColumn[N, T]) = Seq(c.untyped) + + given collectColumnsWithSchema[S <: Tuple]: CollectColumns[ColumnsWithSchema[S]] with type CollectedColumns = S - def underlyingColumns(c: NamedColumns[S]) = c.underlyingColumns + def underlyingColumns(c: ColumnsWithSchema[S]) = c.underlyingColumns given collectEmptyTuple[S]: CollectColumns[EmptyTuple] with type CollectedColumns = EmptyTuple def underlyingColumns(c: EmptyTuple) = Seq.empty - given collectMultiCons[S <: Tuple, T <: Tuple](using collectTail: CollectColumns[T]): (CollectColumns[NamedColumns[S] *: T] { type CollectedColumns = Tuple.Concat[S, collectTail.CollectedColumns] }) = - new CollectColumns[NamedColumns[S] *: T]: - type CollectedColumns = Tuple.Concat[S, collectTail.CollectedColumns] - def underlyingColumns(c: NamedColumns[S] *: T) = c.head.underlyingColumns ++ collectTail.underlyingColumns(c.tail) + given collectCons[H, T <: Tuple](using collectHead: CollectColumns[H], collectTail: CollectColumns[T]): (CollectColumns[H *: T] { type CollectedColumns = Tuple.Concat[collectHead.CollectedColumns, collectTail.CollectedColumns] }) = + new CollectColumns[H *: T]: + type CollectedColumns = Tuple.Concat[collectHead.CollectedColumns, collectTail.CollectedColumns] + def underlyingColumns(c: H *: T) = collectHead.underlyingColumns(c.head) ++ collectTail.underlyingColumns(c.tail) + // TODO Customize error message for different operations with an explanation class CannotCollectColumns(typeName: String) extends Exception(s"Could not find an instance of CollectColumns for ${typeName}") - - -trait CollectColumnsLowPrio: - given collectSingleCons[S, T <: Tuple](using collectTail: CollectColumns[T]): (CollectColumns[NamedColumns[S] *: T] { type CollectedColumns = S *: collectTail.CollectedColumns}) = - new CollectColumns[NamedColumns[S] *: T]: - type CollectedColumns = S *: collectTail.CollectedColumns - def underlyingColumns(c: NamedColumns[S] *: T) = c.head.underlyingColumns ++ collectTail.underlyingColumns(c.tail) diff --git a/src/main/Column.scala b/src/main/Column.scala index 5eb9151..11f3c43 100644 --- a/src/main/Column.scala +++ b/src/main/Column.scala @@ -6,61 +6,35 @@ import scala.quoted.* import org.apache.spark.sql.{Column => UntypedColumn} import types.DataType - -sealed trait NamedColumns[Schema](val underlyingColumns: Seq[UntypedColumn]) - -object Columns: - transparent inline def apply(inline columns: NamedColumns[?]*): NamedColumns[?] = ${ applyImpl('columns) } - - private def applyImpl(columns: Expr[Seq[NamedColumns[?]]])(using Quotes): Expr[NamedColumns[?]] = - import quotes.reflect.* - - val columnValuesWithTypes = columns match - case Varargs(colExprs) => - colExprs.map { arg => - arg match - case '{ $value: NamedColumns[schema] } => ('{ ${ value }.underlyingColumns }, Type.of[schema]) - } - - val columnsValues = columnValuesWithTypes.map(_._1) - val columnsTypes = columnValuesWithTypes.map(_._2) - - val schemaTpe = FrameSchema.schemaTypeFromColumnsTypes(columnsTypes) - - schemaTpe match - case '[s] => - '{ - val cols = ${ Expr.ofSeq(columnsValues) }.flatten - new NamedColumns[s](cols) {} - } +import MacroHelpers.TupleSubtype class Column(val untyped: UntypedColumn): inline def name(using v: ValueOf[Name]): Name = v.value object Column: - implicit transparent inline def columnToLabeledColumn(inline col: Col[?]): LabeledColumn[?, ?] = - ${ columnToLabeledColumnImpl('col) } + implicit transparent inline def columnToNamedColumn(inline col: Col[?]): NamedColumn[?, ?] = + ${ columnToNamedColumnImpl('col) } - private def columnToLabeledColumnImpl(col: Expr[Col[?]])(using Quotes): Expr[LabeledColumn[?, ?]] = + private def columnToNamedColumnImpl(col: Expr[Col[?]])(using Quotes): Expr[NamedColumn[?, ?]] = import quotes.reflect.* col match case '{ ($v: StructuralSchemaView).selectDynamic($nm: Name).$asInstanceOf$[Col[tp]] } => nm.asTerm.tpe.asType match case '[Name.Subtype[n]] => - '{ LabeledColumn[n, tp](${ col }.untyped.as(${ nm })) } + '{ NamedColumn[n, tp](${ col }.untyped.as(${ nm })) } case '{ $c: Col[tp] } => col.asTerm match case Inlined(_, _, Ident(name)) => ConstantType(StringConstant(name)).asType match case '[Name.Subtype[n]] => val alias = Literal(StringConstant(name)).asExprOf[Name] - '{ LabeledColumn[n, tp](${ col }.untyped.as(${ alias })) } + '{ NamedColumn[n, tp](${ col }.untyped.as(${ alias })) } extension [T <: DataType](col: Col[T]) - inline def as[N <: Name](name: N): LabeledColumn[N, T] = - LabeledColumn[N, T](col.untyped.as(name)) - inline def alias[N <: Name](name: N): LabeledColumn[N, T] = - LabeledColumn[N, T](col.untyped.as(name)) + inline def as[N <: Name](name: N): NamedColumn[N, T] = + NamedColumn[N, T](col.untyped.as(name)) + inline def alias[N <: Name](name: N): NamedColumn[N, T] = + NamedColumn[N, T](col.untyped.as(name)) extension [T1 <: DataType](col1: Col[T1]) inline def +[T2 <: DataType](col2: Col[T2])(using op: ColumnOp.Plus[T1, T2]): Col[op.Out] = op(col1, col2) @@ -77,16 +51,44 @@ object Column: inline def &&[T2 <: DataType](col2: Col[T2])(using op: ColumnOp.And[T1, T2]): Col[op.Out] = op(col1, col2) inline def ||[T2 <: DataType](col2: Col[T2])(using op: ColumnOp.Or[T1, T2]): Col[op.Out] = op(col1, col2) + class Col[+T <: DataType](untyped: UntypedColumn) extends Column(untyped) + +object Columns: + transparent inline def apply[C <: NamedColumns](columns: C): ColumnsWithSchema[?] = ${ applyImpl('columns) } + + private def applyImpl[C : Type](columns: Expr[C])(using Quotes): Expr[ColumnsWithSchema[?]] = + import quotes.reflect.* + + Expr.summon[CollectColumns[C]] match + case Some(collectColumns) => + collectColumns match + case '{ $cc: CollectColumns[?] { type CollectedColumns = collectedColumns } } => + Type.of[collectedColumns] match + case '[TupleSubtype[collectedCols]] => + '{ + val cols = ${ cc }.underlyingColumns(${ columns }) + ColumnsWithSchema[collectedCols](cols) + } + case None => + throw CollectColumns.CannotCollectColumns(Type.show[C]) + + +trait NamedColumnOrColumnsLike + +type NamedColumns = Repeated[NamedColumnOrColumnsLike] + +class NamedColumn[N <: Name, T <: DataType](val untyped: UntypedColumn) + extends NamedColumnOrColumnsLike + +class ColumnsWithSchema[Schema <: Tuple](val underlyingColumns: Seq[UntypedColumn]) extends NamedColumnOrColumnsLike + + @annotation.showAsInfix -trait :=[L <: LabeledColumn.Label, T <: DataType] +trait :=[L <: ColumnLabel, T <: DataType] @annotation.showAsInfix trait /[+Prefix <: Name, +Suffix <: Name] -class LabeledColumn[L <: Name, T <: DataType](untyped: UntypedColumn) - extends NamedColumns[(L := T) *: EmptyTuple](Seq(untyped)) - -object LabeledColumn: - type Label = Name | (Name / Name) +type ColumnLabel = Name | (Name / Name) diff --git a/src/main/FrameSchema.scala b/src/main/FrameSchema.scala index 0c1375d..f6188db 100644 --- a/src/main/FrameSchema.scala +++ b/src/main/FrameSchema.scala @@ -14,12 +14,12 @@ object FrameSchema: case TupleSubtype[s2] => S1 *: s2 case _ => S1 *: S2 *: EmptyTuple - type NullableLabeledColumn[T] = T match + type NullableLabeledDataType[T] = T match case label := tpe => label := DataType.Nullable[tpe] type NullableSchema[T] = T match - case TupleSubtype[s] => Tuple.Map[s, NullableLabeledColumn] - case _ => NullableLabeledColumn[T] + case TupleSubtype[s] => Tuple.Map[s, NullableLabeledDataType] + case _ => NullableLabeledDataType[T] def reownType[Owner <: Name : Type](schema: Type[?])(using Quotes): Type[?] = schema match diff --git a/src/main/Grouping.scala b/src/main/Grouping.scala index fc8d0b2..0a074b0 100644 --- a/src/main/Grouping.scala +++ b/src/main/Grouping.scala @@ -13,7 +13,7 @@ object GroupBy: given groupByOps: {} with extension [View <: SchemaView](groupBy: GroupBy[View]) - transparent inline def apply[C <: Repeated[NamedColumns[?]]](groupingColumns: View ?=> C) = ${ applyImpl[View, C]('groupBy, 'groupingColumns) } + transparent inline def apply[C <: NamedColumns](groupingColumns: View ?=> C) = ${ applyImpl[View, C]('groupBy, 'groupingColumns) } private def groupByImpl[S : Type](df: Expr[StructDataFrame[S]])(using Quotes): Expr[GroupBy[?]] = import quotes.reflect.asTerm @@ -59,7 +59,7 @@ trait GroupedDataFrame[FullView <: SchemaView]: object GroupedDataFrame: given groupedDataFrameOps: {} with extension [FullView <: SchemaView, GroupKeys <: Tuple, GroupView <: SchemaView](gdf: GroupedDataFrame[FullView]{ type GroupedView = GroupView; type GroupingKeys = GroupKeys }) - transparent inline def agg[C <: Repeated[NamedColumns[?]]](columns: (Agg { type View = FullView }, GroupView) ?=> C): StructDataFrame[?] = + transparent inline def agg[C <: NamedColumns](columns: (Agg { type View = FullView }, GroupView) ?=> C): StructDataFrame[?] = ${ aggImpl[FullView, GroupKeys, GroupView, C]('gdf, 'columns) } private def aggImpl[FullView <: SchemaView : Type, GroupingKeys <: Tuple : Type, GroupView <: SchemaView : Type, C : Type]( diff --git a/src/main/SchemaView.scala b/src/main/SchemaView.scala index 9a7cc1f..04bda8b 100644 --- a/src/main/SchemaView.scala +++ b/src/main/SchemaView.scala @@ -30,11 +30,6 @@ trait StructSchemaView extends StructuralSchemaView: // TODO: What should be the semantics of `*`? How to handle ambiguous columns? // type AllColumns <: Tuple // def * : AllColumns - - // def selectDynamic(name: String): AliasedSchemaView | LabeledColumn[?, ?] = - // if frameAliases.contains(name) - // then AliasedSchemaView(name) - // else LabeledColumn(col(Name.escape(name))) override def selectDynamic(name: String): AliasedSchemaView | Column = if frameAliases.contains(name) diff --git a/src/main/Select.scala b/src/main/Select.scala index 839418f..f2b111b 100644 --- a/src/main/Select.scala +++ b/src/main/Select.scala @@ -27,7 +27,7 @@ object Select: given selectOps: {} with extension [View <: SchemaView](select: Select[View]) - transparent inline def apply[C <: Repeated[NamedColumns[?]]](columns: View ?=> C): StructDataFrame[?] = + transparent inline def apply[C <: NamedColumns](columns: View ?=> C): StructDataFrame[?] = ${ applyImpl[View, C]('select, 'columns) } private def applyImpl[View <: SchemaView : Type, C : Type](using Quotes)(select: Expr[Select[View]], columns: Expr[View ?=> C]) = diff --git a/src/main/WithColumns.scala b/src/main/WithColumns.scala index f0bc5df..aa0b60f 100644 --- a/src/main/WithColumns.scala +++ b/src/main/WithColumns.scala @@ -24,7 +24,7 @@ object WithColumns: given withColumnsApply: {} with extension [Schema <: Tuple, View <: SchemaView](withColumns: WithColumns[Schema, View]) - transparent inline def apply[C <: Repeated[NamedColumns[?]]](columns: View ?=> C): StructDataFrame[?] = + transparent inline def apply[C <: NamedColumns](columns: View ?=> C): StructDataFrame[?] = ${ applyImpl[Schema, View, C]('withColumns, 'columns) } private def applyImpl[Schema <: Tuple : Type, View <: SchemaView : Type, C : Type]( diff --git a/src/test/ColumnsTest.scala b/src/test/ColumnsTest.scala new file mode 100644 index 0000000..13f030e --- /dev/null +++ b/src/test/ColumnsTest.scala @@ -0,0 +1,24 @@ +package org.virtuslab.iskra.test + +import org.scalatest.funsuite.AnyFunSuite +import org.scalatest.BeforeAndAfterAll +import org.scalatest.matchers.should.Matchers.shouldEqual + +class ColumnsTest extends SparkUnitTest: + import org.virtuslab.iskra.api.* + + case class Foo(x1: Int, x2: Int, x3: Int, x4: Int) + + val foos = Seq( + Foo(1, 2, 3, 4) + ).toDF.asStruct + + test("plus") { + val result = foos.select { + val cols1 = Columns($.x1) + val cols2 = Columns($.x2, $.x3) + (cols1, cols2, $.x4) + }.asClass[Foo].collect().toList + + result shouldEqual List(Foo(1, 2, 3, 4)) + }