From f0753a412610eeb705215b2e83e13655c225e8f0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Pa=C5=82ka?= Date: Mon, 8 Jul 2024 15:31:24 +0200 Subject: [PATCH] Use `Repeated` arguments instead of inline varargs --- src/main/CollectColumns.scala | 35 +++++++++++++ src/main/Column.scala | 38 ++++++++++---- src/main/Grouping.scala | 96 ++++++++++++++-------------------- src/main/Repeated.scala | 25 +++++++++ src/main/SchemaView.scala | 56 +++++++++++++------- src/main/Select.scala | 42 +++++++-------- src/main/StructDataFrame.scala | 2 +- src/main/WithColumns.scala | 54 +++++++------------ src/test/WithColumnsTest.scala | 23 ++++++++ src/test/example/Workers.scala | 6 +-- 10 files changed, 232 insertions(+), 145 deletions(-) create mode 100644 src/main/CollectColumns.scala create mode 100644 src/main/Repeated.scala diff --git a/src/main/CollectColumns.scala b/src/main/CollectColumns.scala new file mode 100644 index 0000000..cf24729 --- /dev/null +++ b/src/main/CollectColumns.scala @@ -0,0 +1,35 @@ +package org.virtuslab.iskra + +import scala.compiletime.error + +// TODO should it be covariant or not? +trait CollectColumns[-C]: + type CollectedColumns <: Tuple + def underlyingColumns(c: C): Seq[UntypedColumn] + +// Using `given ... with { ... }` syntax might sometimes break pattern match on `CollectColumns[...] { type CollectedColumns = cc }` + +object CollectColumns extends CollectColumnsLowPrio: + given collectSingle[S <: Tuple]: CollectColumns[NamedColumns[S]] with + type CollectedColumns = S + def underlyingColumns(c: NamedColumns[S]) = c.underlyingColumns + + given collectEmptyTuple[S]: CollectColumns[EmptyTuple] with + type CollectedColumns = EmptyTuple + def underlyingColumns(c: EmptyTuple) = Seq.empty + + given collectMultiCons[S <: Tuple, T <: Tuple](using collectTail: CollectColumns[T]): (CollectColumns[NamedColumns[S] *: T] { type CollectedColumns = Tuple.Concat[S, collectTail.CollectedColumns] }) = + new CollectColumns[NamedColumns[S] *: T]: + type CollectedColumns = Tuple.Concat[S, collectTail.CollectedColumns] + def underlyingColumns(c: NamedColumns[S] *: T) = c.head.underlyingColumns ++ collectTail.underlyingColumns(c.tail) + + // TODO Customize error message for different operations with an explanation + class CannotCollectColumns(typeName: String) + extends Exception(s"Could not find an instance of CollectColumns for ${typeName}") + + +trait CollectColumnsLowPrio: + given collectSingleCons[S, T <: Tuple](using collectTail: CollectColumns[T]): (CollectColumns[NamedColumns[S] *: T] { type CollectedColumns = S *: collectTail.CollectedColumns}) = + new CollectColumns[NamedColumns[S] *: T]: + type CollectedColumns = S *: collectTail.CollectedColumns + def underlyingColumns(c: NamedColumns[S] *: T) = c.head.underlyingColumns ++ collectTail.underlyingColumns(c.tail) diff --git a/src/main/Column.scala b/src/main/Column.scala index af3bb6b..5eb9151 100644 --- a/src/main/Column.scala +++ b/src/main/Column.scala @@ -1,5 +1,7 @@ package org.virtuslab.iskra +import scala.language.implicitConversions + import scala.quoted.* import org.apache.spark.sql.{Column => UntypedColumn} @@ -32,15 +34,33 @@ object Columns: new NamedColumns[s](cols) {} } -abstract class Column(val untyped: UntypedColumn): +class Column(val untyped: UntypedColumn): inline def name(using v: ValueOf[Name]): Name = v.value object Column: + implicit transparent inline def columnToLabeledColumn(inline col: Col[?]): LabeledColumn[?, ?] = + ${ columnToLabeledColumnImpl('col) } + + private def columnToLabeledColumnImpl(col: Expr[Col[?]])(using Quotes): Expr[LabeledColumn[?, ?]] = + import quotes.reflect.* + col match + case '{ ($v: StructuralSchemaView).selectDynamic($nm: Name).$asInstanceOf$[Col[tp]] } => + nm.asTerm.tpe.asType match + case '[Name.Subtype[n]] => + '{ LabeledColumn[n, tp](${ col }.untyped.as(${ nm })) } + case '{ $c: Col[tp] } => + col.asTerm match + case Inlined(_, _, Ident(name)) => + ConstantType(StringConstant(name)).asType match + case '[Name.Subtype[n]] => + val alias = Literal(StringConstant(name)).asExprOf[Name] + '{ LabeledColumn[n, tp](${ col }.untyped.as(${ alias })) } + extension [T <: DataType](col: Col[T]) - inline def as[N <: Name](name: N)(using v: ValueOf[N]): LabeledColumn[N, T] = - LabeledColumn[N, T](col.untyped.as(v.value)) - inline def alias[N <: Name](name: N)(using v: ValueOf[N]): LabeledColumn[N, T] = - LabeledColumn[N, T](col.untyped.as(v.value)) + inline def as[N <: Name](name: N): LabeledColumn[N, T] = + LabeledColumn[N, T](col.untyped.as(name)) + inline def alias[N <: Name](name: N): LabeledColumn[N, T] = + LabeledColumn[N, T](col.untyped.as(name)) extension [T1 <: DataType](col1: Col[T1]) inline def +[T2 <: DataType](col2: Col[T2])(using op: ColumnOp.Plus[T1, T2]): Col[op.Out] = op(col1, col2) @@ -60,15 +80,13 @@ object Column: class Col[+T <: DataType](untyped: UntypedColumn) extends Column(untyped) @annotation.showAsInfix -class :=[L <: LabeledColumn.Label, T <: DataType](untyped: UntypedColumn) - extends Col[T](untyped) - with NamedColumns[(L := T) *: EmptyTuple](Seq(untyped)) +trait :=[L <: LabeledColumn.Label, T <: DataType] @annotation.showAsInfix trait /[+Prefix <: Name, +Suffix <: Name] -type LabeledColumn[L <: LabeledColumn.Label, T <: DataType] = :=[L, T] +class LabeledColumn[L <: Name, T <: DataType](untyped: UntypedColumn) + extends NamedColumns[(L := T) *: EmptyTuple](Seq(untyped)) object LabeledColumn: type Label = Name | (Name / Name) - def apply[L <: LabeledColumn.Label, T <: DataType](untyped: UntypedColumn) = new :=[L, T](untyped) diff --git a/src/main/Grouping.scala b/src/main/Grouping.scala index 395d1fb..fc8d0b2 100644 --- a/src/main/Grouping.scala +++ b/src/main/Grouping.scala @@ -13,45 +13,38 @@ object GroupBy: given groupByOps: {} with extension [View <: SchemaView](groupBy: GroupBy[View]) - transparent inline def apply(inline groupingColumns: View ?=> NamedColumns[?]*) = ${ applyImpl[View]('groupBy, 'groupingColumns) } + transparent inline def apply[C <: Repeated[NamedColumns[?]]](groupingColumns: View ?=> C) = ${ applyImpl[View, C]('groupBy, 'groupingColumns) } - def groupByImpl[S : Type](df: Expr[StructDataFrame[S]])(using Quotes): Expr[GroupBy[?]] = + private def groupByImpl[S : Type](df: Expr[StructDataFrame[S]])(using Quotes): Expr[GroupBy[?]] = import quotes.reflect.asTerm val viewExpr = StructSchemaView.schemaViewExpr[StructDataFrame[S]] viewExpr.asTerm.tpe.asType match case '[SchemaView.Subtype[v]] => '{ GroupBy[v](${ viewExpr }.asInstanceOf[v], ${ df }.untyped) } - def applyImpl[View <: SchemaView : Type](groupBy: Expr[GroupBy[View]], groupingColumns: Expr[Seq[View ?=> NamedColumns[?]]])(using Quotes): Expr[GroupedDataFrame[View]] = + private def applyImpl[View <: SchemaView : Type, C : Type](groupBy: Expr[GroupBy[View]], groupingColumns: Expr[View ?=> C])(using Quotes): Expr[GroupedDataFrame[View]] = import quotes.reflect.* - val columnValuesWithTypes = groupingColumns match - case Varargs(colExprs) => - colExprs.map { arg => - val reduced = Term.betaReduce('{$arg(using ${ groupBy }.view)}.asTerm).get - reduced.asExpr match - case '{ $value: NamedColumns[schema] } => ('{ ${ value }.underlyingColumns }, Type.of[schema]) - } - - val columnsValues = columnValuesWithTypes.map(_._1) - val columnsTypes = columnValuesWithTypes.map(_._2) - - val groupedSchemaTpe = FrameSchema.schemaTypeFromColumnsTypes(columnsTypes) - groupedSchemaTpe match - case '[TupleSubtype[groupingKeys]] => - val groupedViewExpr = StructSchemaView.schemaViewExpr[StructDataFrame[groupingKeys]] - - groupedViewExpr.asTerm.tpe.asType match - case '[SchemaView.Subtype[groupedView]] => - '{ - val groupingCols = ${ Expr.ofSeq(columnsValues) }.flatten - new GroupedDataFrame[View]: - type GroupingKeys = groupingKeys - type GroupedView = groupedView - def underlying = ${ groupBy }.underlying.groupBy(groupingCols*) - def fullView = ${ groupBy }.view - def groupedView = ${ groupedViewExpr }.asInstanceOf[GroupedView] - } + Expr.summon[CollectColumns[C]] match + case Some(collectColumns) => + collectColumns match + case '{ $cc: CollectColumns[?] { type CollectedColumns = collectedColumns } } => + Type.of[collectedColumns] match + case '[TupleSubtype[collectedCols]] => + val groupedViewExpr = StructSchemaView.schemaViewExpr[StructDataFrame[collectedCols]] + groupedViewExpr.asTerm.tpe.asType match + case '[SchemaView.Subtype[groupedView]] => + '{ + val groupingCols = ${ cc }.underlyingColumns(${ groupingColumns }(using ${ groupBy }.view)) + new GroupedDataFrame[View]: + type GroupingKeys = collectedCols + type GroupedView = groupedView + def underlying = ${ groupBy }.underlying.groupBy(groupingCols*) + def fullView = ${ groupBy }.view + def groupedView = ${ groupedViewExpr }.asInstanceOf[GroupedView] + } + case None => + throw CollectColumns.CannotCollectColumns(Type.show[C]) // TODO: Rename to RelationalGroupedDataset and handle other aggregations: cube, rollup (and pivot?) trait GroupedDataFrame[FullView <: SchemaView]: @@ -66,13 +59,12 @@ trait GroupedDataFrame[FullView <: SchemaView]: object GroupedDataFrame: given groupedDataFrameOps: {} with extension [FullView <: SchemaView, GroupKeys <: Tuple, GroupView <: SchemaView](gdf: GroupedDataFrame[FullView]{ type GroupedView = GroupView; type GroupingKeys = GroupKeys }) - transparent inline def agg(inline columns: (Agg { type View = FullView }, GroupView) ?=> NamedColumns[?]*): StructDataFrame[?] = - ${ aggImpl[FullView, GroupKeys, GroupView]('gdf, 'columns) } + transparent inline def agg[C <: Repeated[NamedColumns[?]]](columns: (Agg { type View = FullView }, GroupView) ?=> C): StructDataFrame[?] = + ${ aggImpl[FullView, GroupKeys, GroupView, C]('gdf, 'columns) } - - def aggImpl[FullView <: SchemaView : Type, GroupingKeys <: Tuple : Type, GroupView <: SchemaView : Type]( + private def aggImpl[FullView <: SchemaView : Type, GroupingKeys <: Tuple : Type, GroupView <: SchemaView : Type, C : Type]( gdf: Expr[GroupedDataFrame[FullView] { type GroupedView = GroupView }], - columns: Expr[Seq[(Agg { type View = FullView }, GroupView) ?=> NamedColumns[?]]] + columns: Expr[(Agg { type View = FullView }, GroupView) ?=> C] )(using Quotes): Expr[StructDataFrame[?]] = import quotes.reflect.* @@ -82,27 +74,19 @@ object GroupedDataFrame: val view = ${ gdf }.fullView } - val columnValuesWithTypes = columns match - case Varargs(colExprs) => - colExprs.map { arg => - val reduced = Term.betaReduce('{$arg(using ${ aggWrapper }, ${ gdf }.groupedView)}.asTerm).get - reduced.asExpr match - case '{ $value: NamedColumns[schema] } => ('{ ${ value }.underlyingColumns }, Type.of[schema]) - } - - val columnsValues = columnValuesWithTypes.map(_._1) - val columnsTypes = columnValuesWithTypes.map(_._2) - - val schemaTpe = FrameSchema.schemaTypeFromColumnsTypes(columnsTypes) - schemaTpe match - case '[s] => - '{ - // TODO assert cols is not empty - val cols = ${ Expr.ofSeq(columnsValues) }.flatten - StructDataFrame[FrameSchema.Merge[GroupingKeys, s]]( - ${ gdf }.underlying.agg(cols.head, cols.tail*) - ) - } + Expr.summon[CollectColumns[C]] match + case Some(collectColumns) => + collectColumns match + case '{ $cc: CollectColumns[?] { type CollectedColumns = collectedColumns } } => + '{ + // TODO assert cols is not empty + val cols = ${ cc }.underlyingColumns(${ columns }(using ${ aggWrapper }, ${ gdf }.groupedView)) + StructDataFrame[FrameSchema.Merge[GroupingKeys, collectedColumns]]( + ${ gdf }.underlying.agg(cols.head, cols.tail*) + ) + } + case None => + throw CollectColumns.CannotCollectColumns(Type.show[C]) trait Agg: type View <: SchemaView diff --git a/src/main/Repeated.scala b/src/main/Repeated.scala new file mode 100644 index 0000000..ec45683 --- /dev/null +++ b/src/main/Repeated.scala @@ -0,0 +1,25 @@ +package org.virtuslab.iskra + +type Repeated[A] = + A + | (A, A) + | (A, A, A) + | (A, A, A, A) + | (A, A, A, A, A) + | (A, A, A, A, A, A) + | (A, A, A, A, A, A, A) + | (A, A, A, A, A, A, A, A) + | (A, A, A, A, A, A, A, A, A) + | (A, A, A, A, A, A, A, A, A, A) + | (A, A, A, A, A, A, A, A, A, A, A) + | (A, A, A, A, A, A, A, A, A, A, A, A) + | (A, A, A, A, A, A, A, A, A, A, A, A, A) + | (A, A, A, A, A, A, A, A, A, A, A, A, A, A) + | (A, A, A, A, A, A, A, A, A, A, A, A, A, A, A) + | (A, A, A, A, A, A, A, A, A, A, A, A, A, A, A, A) + | (A, A, A, A, A, A, A, A, A, A, A, A, A, A, A, A, A) + | (A, A, A, A, A, A, A, A, A, A, A, A, A, A, A, A, A, A) + | (A, A, A, A, A, A, A, A, A, A, A, A, A, A, A, A, A, A, A) + | (A, A, A, A, A, A, A, A, A, A, A, A, A, A, A, A, A, A, A, A) + | (A, A, A, A, A, A, A, A, A, A, A, A, A, A, A, A, A, A, A, A, A) + | (A, A, A, A, A, A, A, A, A, A, A, A, A, A, A, A, A, A, A, A, A, A) // 22 is maximal arity diff --git a/src/main/SchemaView.scala b/src/main/SchemaView.scala index e4cde7f..9a7cc1f 100644 --- a/src/main/SchemaView.scala +++ b/src/main/SchemaView.scala @@ -21,17 +21,26 @@ object SchemaView: case '[StructDataFrame.Subtype[df]] => Some(StructSchemaView.schemaViewExpr[df]) -trait StructSchemaView extends SchemaView, Selectable: +trait StructuralSchemaView extends SchemaView, Selectable: + def selectDynamic(name: String): AliasedSchemaView | Column + +trait StructSchemaView extends StructuralSchemaView: def frameAliases: Seq[String] // TODO: get rid of this at runtime // TODO: What should be the semantics of `*`? How to handle ambiguous columns? // type AllColumns <: Tuple // def * : AllColumns - def selectDynamic(name: String): AliasedSchemaView | LabeledColumn[?, ?] = + // def selectDynamic(name: String): AliasedSchemaView | LabeledColumn[?, ?] = + // if frameAliases.contains(name) + // then AliasedSchemaView(name) + // else LabeledColumn(col(Name.escape(name))) + + override def selectDynamic(name: String): AliasedSchemaView | Column = if frameAliases.contains(name) then AliasedSchemaView(name) - else LabeledColumn(col(Name.escape(name))) + else Col[DataType](col(Name.escape(name))) + object StructSchemaView: type Subtype[T <: StructSchemaView] = T @@ -48,14 +57,15 @@ object StructSchemaView: import quotes.reflect.* schemaType match case '[EmptyTuple] => base - case '[LabeledColumn[headLabel, headType] *: tail] => // TODO: get rid of duplicates - val nameType = Type.of[headLabel] match - case '[Name.Subtype[name]] => Type.of[name] - case '[(Name.Subtype[framePrefix], Name.Subtype[name])] => Type.of[name] + case '[(headLabelPrefix / headLabelName := headType) *: tail] => // TODO: get rid of duplicates + val nameType = Type.of[headLabelName] match + case '[Name.Subtype[name]] => + Type.of[name] + case '[(Name.Subtype[framePrefix], Name.Subtype[name])] => + Type.of[name] val name = nameType match case '[n] => Type.valueOfConstant[n].get.toString - val info = TypeRepr.of[LabeledColumn[headLabel, headType]] - val newBase = Refinement(base, name, info) + val newBase = Refinement(base, name, TypeRepr.of[Col[headType]]) schemaViewType(newBase, Type.of[tail]) // private def reifyColumns[T <: Tuple : Type](using Quotes): Expr[Tuple] = reifyCols(Type.of[T]) @@ -64,7 +74,7 @@ object StructSchemaView: // import quotes.reflect.* // schemaType match // case '[EmptyTuple] => '{ EmptyTuple } - // case '[LabeledColumn[headLabel1, headType] *: tail] => + // case '[(headLabel1 := headType) *: tail] => // headLabel1 match // case '[Name.Subtype[name]] => // TODO: handle frame prefixes // val label = Expr(Type.valueOfConstant[name].get.toString) @@ -94,6 +104,7 @@ object StructSchemaView: def allPrefixedColumns(using Quotes)(schemaType: Type[?]): List[(String, (String, quotes.reflect.TypeRepr))] = import quotes.reflect.* + schemaType match case '[EmptyTuple] => List.empty case '[(Name.Subtype[name] := dataType) *: tail] => @@ -101,8 +112,11 @@ object StructSchemaView: case '[(framePrefix / name := dataType) *: tail] => val prefix = Type.valueOfConstant[framePrefix].get.toString val colName = Type.valueOfConstant[name].get.toString - (prefix -> (colName -> TypeRepr.of[name := dataType])) :: allPrefixedColumns(Type.of[tail]) - // TODO: Handle Nothing as schemaType (which might appear as propagation of earlier errors) + (prefix -> (colName -> TypeRepr.of[Col[dataType]])) :: allPrefixedColumns(Type.of[tail]) + + // TODO Show this case to users as propagated error + case _ => + List.empty def frameAliasViewsByName(using Quotes)(schemaType: Type[?]): List[(String, quotes.reflect.TypeRepr)] = import quotes.reflect.* @@ -120,16 +134,20 @@ object StructSchemaView: import quotes.reflect.* schemaType match case '[EmptyTuple] => List.empty - case '[LabeledColumn[Name.Subtype[name], dataType] *: tail] => + case '[(Name.Subtype[name] := dataType) *: tail] => val colName = Type.valueOfConstant[name].get.toString - val namedColumn = colName -> TypeRepr.of[LabeledColumn[name, dataType]] + val namedColumn = colName -> TypeRepr.of[Col[dataType]] namedColumn :: allColumns(Type.of[tail]) - case '[LabeledColumn[Name.Subtype[framePrefix] / Name.Subtype[name], dataType] *: tail] => + case '[((Name.Subtype[framePrefix] / Name.Subtype[name]) := dataType) *: tail] => val colName = Type.valueOfConstant[name].get.toString - val namedColumn = colName -> TypeRepr.of[LabeledColumn[name, dataType]] + val namedColumn = colName -> TypeRepr.of[Col[dataType]] namedColumn :: allColumns(Type.of[tail]) -class AliasedSchemaView(frameAliasName: String) extends Selectable: - def selectDynamic(name: String): LabeledColumn[Name, DataType] = + // TODO Show this case to users as propagated error + case _ => + List.empty + +class AliasedSchemaView(frameAliasName: String) extends StructuralSchemaView: + override def selectDynamic(name: String): Column = val columnName = s"${Name.escape(frameAliasName)}.${Name.escape(name)}" - LabeledColumn[Name, DataType](col(columnName)) \ No newline at end of file + Col[DataType](col(columnName)) \ No newline at end of file diff --git a/src/main/Select.scala b/src/main/Select.scala index 2d3225b..839418f 100644 --- a/src/main/Select.scala +++ b/src/main/Select.scala @@ -27,27 +27,27 @@ object Select: given selectOps: {} with extension [View <: SchemaView](select: Select[View]) - transparent inline def apply(inline columns: View ?=> NamedColumns[?]*): StructDataFrame[?] = - ${ applyImpl[View]('select, 'columns) } + transparent inline def apply[C <: Repeated[NamedColumns[?]]](columns: View ?=> C): StructDataFrame[?] = + ${ applyImpl[View, C]('select, 'columns) } - private def applyImpl[View <: SchemaView : Type](using Quotes)(select: Expr[Select[View]], columns: Expr[Seq[View ?=> NamedColumns[?]]]) = + private def applyImpl[View <: SchemaView : Type, C : Type](using Quotes)(select: Expr[Select[View]], columns: Expr[View ?=> C]) = import quotes.reflect.* - val columnValuesWithTypes = columns match - case Varargs(colExprs) => - colExprs.map { arg => - val reduced = Term.betaReduce('{$arg(using ${ select }.view)}.asTerm).get - reduced.asExpr match - case '{ $value: NamedColumns[schema] } => ('{ ${ value }.underlyingColumns }, Type.of[schema]) - } - - val columnsValues = columnValuesWithTypes.map(_._1) - val columnsTypes = columnValuesWithTypes.map(_._2) - - val schemaTpe = FrameSchema.schemaTypeFromColumnsTypes(columnsTypes) - schemaTpe match - case '[s] => - '{ - val cols = ${ Expr.ofSeq(columnsValues) }.flatten - StructDataFrame[s](${ select }.underlying.select(cols*)) - } + Expr.summon[CollectColumns[C]] match + case Some(collectColumns) => + collectColumns match + case '{ $cc: CollectColumns[?] { type CollectedColumns = collectedColumns } } => + Type.of[collectedColumns] match + case '[head *: EmptyTuple] => + '{ + val cols = ${ cc }.underlyingColumns(${ columns }(using ${ select }.view)) + StructDataFrame[head](${ select }.underlying.select(cols*)) + } + + case '[s] => + '{ + val cols = ${ cc }.underlyingColumns(${ columns }(using ${ select }.view)) + StructDataFrame[s](${ select }.underlying.select(cols*)) + } + case None => + throw CollectColumns.CannotCollectColumns(Type.show[C]) diff --git a/src/main/StructDataFrame.scala b/src/main/StructDataFrame.scala index 86b80d4..a4e1537 100644 --- a/src/main/StructDataFrame.scala +++ b/src/main/StructDataFrame.scala @@ -11,7 +11,7 @@ object StructDataFrame: type Subtype[T <: StructDataFrame[?]] = T type WithAlias[T <: String & Singleton] = StructDataFrame[?] { type Alias = T } - extension [Schema <: Tuple](df: StructDataFrame[Schema]) + extension [Schema](df: StructDataFrame[Schema]) inline def asClass[A]: ClassDataFrame[A] = ${ asClassImpl[Schema, A]('df) } private def asClassImpl[FrameSchema : Type, A : Type](df: Expr[StructDataFrame[FrameSchema]])(using Quotes): Expr[ClassDataFrame[A]] = diff --git a/src/main/WithColumns.scala b/src/main/WithColumns.scala index d4b50a2..f0bc5df 100644 --- a/src/main/WithColumns.scala +++ b/src/main/WithColumns.scala @@ -24,43 +24,27 @@ object WithColumns: given withColumnsApply: {} with extension [Schema <: Tuple, View <: SchemaView](withColumns: WithColumns[Schema, View]) - transparent inline def apply(inline columns: View ?=> NamedColumns[?]*): StructDataFrame[?] = - ${ applyImpl[Schema, View]('withColumns, 'columns) } + transparent inline def apply[C <: Repeated[NamedColumns[?]]](columns: View ?=> C): StructDataFrame[?] = + ${ applyImpl[Schema, View, C]('withColumns, 'columns) } - def applyImpl[Schema <: Tuple : Type, View <: SchemaView : Type]( + private def applyImpl[Schema <: Tuple : Type, View <: SchemaView : Type, C : Type]( withColumns: Expr[WithColumns[Schema, View]], - columns: Expr[Seq[View ?=> NamedColumns[?]]] + columns: Expr[View ?=> C] )(using Quotes): Expr[StructDataFrame[?]] = import quotes.reflect.* - val columnValuesWithTypesWithLabels = columns match - case Varargs(colExprs) => - colExprs.map { arg => - val reduced = Term.betaReduce('{$arg(using ${ withColumns }.view)}.asTerm).get - reduced.asExpr match - case '{ $value: NamedColumns[schema] } => ('{ ${ value }.underlyingColumns }, Type.of[schema], labelsNames(Type.of[schema])) - } - - val columnsValues = columnValuesWithTypesWithLabels.map(_._1) - val columnsTypes = columnValuesWithTypesWithLabels.map(_._2) - val columnsNames = columnValuesWithTypesWithLabels.map(_._3).flatten - - val schemaTpe = FrameSchema.schemaTypeFromColumnsTypes(columnsTypes) - schemaTpe match - case '[TupleSubtype[s]] => - '{ - val cols = ${ Expr.ofSeq(columnsValues) }.flatten - val withColumnsAppended = - ${ Expr(columnsNames) }.zip(cols).foldLeft(${ withColumns }.underlying){ - case (df, (label, col)) => - df.withColumn(label, col) - } - StructDataFrame[Tuple.Concat[Schema, s]](withColumnsAppended) - } - - private def labelsNames(schema: Type[?])(using Quotes): List[String] = - schema match - case '[EmptyTuple] => Nil - case '[(label := column) *: tail] => - val headValue = Type.valueOfConstant[label].get.toString - headValue :: labelsNames(Type.of[tail]) + Expr.summon[CollectColumns[C]] match + case Some(collectColumns) => + collectColumns match + case '{ $cc: CollectColumns[?] { type CollectedColumns = collectedColumns } } => + Type.of[collectedColumns] match + case '[TupleSubtype[collectedCols]] => + '{ + val cols = + org.apache.spark.sql.functions.col("*") +: ${ cc }.underlyingColumns(${ columns }(using ${ withColumns }.view)) + val withColumnsAppended = + ${ withColumns }.underlying.select(cols*) + StructDataFrame[Tuple.Concat[Schema, collectedCols]](withColumnsAppended) + } + case None => + throw CollectColumns.CannotCollectColumns(Type.show[C]) diff --git a/src/test/WithColumnsTest.scala b/src/test/WithColumnsTest.scala index 351c821..201ee50 100644 --- a/src/test/WithColumnsTest.scala +++ b/src/test/WithColumnsTest.scala @@ -21,6 +21,17 @@ class WithColumnsTest extends SparkUnitTest: result shouldEqual List(Bar(1, 2, 3)) } + test("withColumns-single-autoAliased") { + val result = foos + .withColumns { + val c = ($.a + $.b) + c + } + .asClass[Bar].collect().toList + + result shouldEqual List(Bar(1, 2, 3)) + } + test("withColumns-many") { val result = foos .withColumns( @@ -31,3 +42,15 @@ class WithColumnsTest extends SparkUnitTest: result shouldEqual List(Baz(1, 2, 3, -1)) } + + test("withColumns-many-autoAliased") { + val result = foos + .withColumns{ + val c = ($.a + $.b) + val d = ($.a - $.b) + (c, d) + } + .asClass[Baz].collect().toList + + result shouldEqual List(Baz(1, 2, 3, -1)) + } \ No newline at end of file diff --git a/src/test/example/Workers.scala b/src/test/example/Workers.scala index fc0e182..d17bc26 100644 --- a/src/test/example/Workers.scala +++ b/src/test/example/Workers.scala @@ -36,9 +36,9 @@ import functions.lit .leftJoin(supervisions).on($.subordinates.id === $.subordinateId) .leftJoin(workers.as("supervisors")).on($.supervisorId === $.supervisors.id) .select { - val salary = (lit(4732) + $.subordinates.yearsInCompany * lit(214)).as("salary") - val supervisor = ($.supervisors.firstName ++ lit(" ") ++ $.supervisors.lastName).as("supervisor") - Columns($.subordinates.firstName, $.subordinates.lastName, supervisor, salary) + val salary = lit(4732) + $.subordinates.yearsInCompany * lit(214) + val supervisor = $.supervisors.firstName ++ lit(" ") ++ $.supervisors.lastName + ($.subordinates.firstName, $.subordinates.lastName, supervisor, salary) } .where($.salary > lit(5000)) .show()