Skip to content

Commit

Permalink
Use Repeated arguments instead of inline varargs
Browse files Browse the repository at this point in the history
  • Loading branch information
prolativ committed Jul 8, 2024
1 parent cf90acc commit f0753a4
Show file tree
Hide file tree
Showing 10 changed files with 232 additions and 145 deletions.
35 changes: 35 additions & 0 deletions src/main/CollectColumns.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package org.virtuslab.iskra

import scala.compiletime.error

/** Type class describing how a value of type `C` is flattened into a sequence of `UntypedColumn`s.
  *
  * TODO: revisit the variance of `C` (it is currently contravariant) - should it be covariant or not?
  *
  * @tparam C the type of the value the columns are collected from
  */
trait CollectColumns[-C]:
  /** Tuple type computed for `C` by the selected instance. */
  type CollectedColumns <: Tuple

  /** Returns the untyped columns backing `c`. */
  def underlyingColumns(c: C): Seq[UntypedColumn]

// Using `given ... with { ... }` syntax might sometimes break pattern match on `CollectColumns[...] { type CollectedColumns = cc }`

/** Given instances of [[CollectColumns]] for a single `NamedColumns` value and for tuples
  * whose elements are `NamedColumns`.
  *
  * Instances declared directly in this object are preferred by given resolution over the
  * ones inherited from [[CollectColumnsLowPrio]].
  */
object CollectColumns extends CollectColumnsLowPrio:
  /** A lone `NamedColumns[S]` contributes its tuple `S` unchanged. */
  given collectSingle[S <: Tuple]: CollectColumns[NamedColumns[S]] with
    type CollectedColumns = S
    def underlyingColumns(c: NamedColumns[S]): Seq[UntypedColumn] = c.underlyingColumns

  /** Base case of the tuple recursion: an empty tuple yields no columns.
    *
    * The instance takes no type parameter: nothing in its type or body depends on one.
    */
  given collectEmptyTuple: CollectColumns[EmptyTuple] with
    type CollectedColumns = EmptyTuple
    def underlyingColumns(c: EmptyTuple): Seq[UntypedColumn] = Seq.empty

  /** Recursive case for a head `NamedColumns[S]` where `S` is itself a tuple:
    * `S` is concatenated with whatever the tail collects.
    *
    * The result type is spelled out as a refinement so that `CollectedColumns` stays visible
    * to callers that inspect the summoned instance's type.
    */
  given collectMultiCons[S <: Tuple, T <: Tuple](using collectTail: CollectColumns[T]): (CollectColumns[NamedColumns[S] *: T] { type CollectedColumns = Tuple.Concat[S, collectTail.CollectedColumns] }) =
    new CollectColumns[NamedColumns[S] *: T]:
      type CollectedColumns = Tuple.Concat[S, collectTail.CollectedColumns]
      def underlyingColumns(c: NamedColumns[S] *: T): Seq[UntypedColumn] =
        // Head columns first, then the tail's, preserving argument order.
        c.head.underlyingColumns ++ collectTail.underlyingColumns(c.tail)

  // TODO Customize error message for different operations with an explanation
  /** Raised when no [[CollectColumns]] instance can be found for the type named `typeName`. */
  class CannotCollectColumns(typeName: String)
    extends Exception(s"Could not find an instance of CollectColumns for ${typeName}")


/** Fallback [[CollectColumns]] instances.
  *
  * They live in a parent trait so that the instances declared directly in the
  * `CollectColumns` object win whenever both apply.
  */
trait CollectColumnsLowPrio:
  /** Recursive case for a head `NamedColumns[S]` with an arbitrary `S`:
    * `S` is prepended as a single element to whatever the tail collects.
    */
  given collectSingleCons[S, T <: Tuple](using collectTail: CollectColumns[T]): (CollectColumns[NamedColumns[S] *: T] { type CollectedColumns = S *: collectTail.CollectedColumns}) =
    new CollectColumns[NamedColumns[S] *: T]:
      type CollectedColumns = S *: collectTail.CollectedColumns
      def underlyingColumns(c: NamedColumns[S] *: T) =
        val headColumns = c.head.underlyingColumns
        val tailColumns = collectTail.underlyingColumns(c.tail)
        headColumns ++ tailColumns
38 changes: 28 additions & 10 deletions src/main/Column.scala
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package org.virtuslab.iskra

import scala.language.implicitConversions

import scala.quoted.*

import org.apache.spark.sql.{Column => UntypedColumn}
Expand Down Expand Up @@ -32,15 +34,33 @@ object Columns:
new NamedColumns[s](cols) {}
}

abstract class Column(val untyped: UntypedColumn):
class Column(val untyped: UntypedColumn):
inline def name(using v: ValueOf[Name]): Name = v.value

object Column:
/** Implicit conversion from an unlabeled `Col` to a `LabeledColumn`.
  *
  * The expansion is produced at compile time by `columnToLabeledColumnImpl`; because the
  * method is `transparent`, the call site sees the precise type of the generated expression
  * rather than `LabeledColumn[?, ?]`.
  */
implicit transparent inline def columnToLabeledColumn(inline col: Col[?]): LabeledColumn[?, ?] =
${ columnToLabeledColumnImpl('col) }

/** Macro implementation behind `columnToLabeledColumn`.
  *
  * Two shapes of `col` are supported:
  *  - a member selected from a `StructuralSchemaView` via `selectDynamic`: the selected
  *    name becomes the label;
  *  - a plain identifier: the identifier's own name becomes the label.
  *
  * Any other expression has no name to derive a label from. Instead of crashing the macro
  * expansion with a `MatchError`, a compile error positioned at the offending expression
  * is reported.
  *
  * @param col the column expression to label
  * @return an expression building a `LabeledColumn` whose untyped column is aliased to the label
  */
private def columnToLabeledColumnImpl(col: Expr[Col[?]])(using Quotes): Expr[LabeledColumn[?, ?]] =
  import quotes.reflect.*

  // Single failure path shared by every unsupported shape; aborts the expansion.
  def cannotInferLabel(): Nothing =
    report.errorAndAbort(
      s"Cannot infer a label for column expression `${col.show}`; name it explicitly with `.as(...)`",
      col
    )

  col match
    // Column selected from a schema view: reuse the selected member's name.
    case '{ ($v: StructuralSchemaView).selectDynamic($nm: Name).$asInstanceOf$[Col[tp]] } =>
      nm.asTerm.tpe.asType match
        case '[Name.Subtype[n]] =>
          '{ LabeledColumn[n, tp](${ col }.untyped.as(${ nm })) }
        case _ =>
          cannotInferLabel()
    case '{ $c: Col[tp] } =>
      col.asTerm match
        // Plain identifier: its name is lifted to a literal type and used as the alias.
        case Inlined(_, _, Ident(name)) =>
          ConstantType(StringConstant(name)).asType match
            case '[Name.Subtype[n]] =>
              val alias = Literal(StringConstant(name)).asExprOf[Name]
              '{ LabeledColumn[n, tp](${ col }.untyped.as(${ alias })) }
            case _ =>
              cannotInferLabel()
        case _ =>
          cannotInferLabel()
    case _ =>
      cannotInferLabel()

extension [T <: DataType](col: Col[T])
inline def as[N <: Name](name: N)(using v: ValueOf[N]): LabeledColumn[N, T] =
LabeledColumn[N, T](col.untyped.as(v.value))
inline def alias[N <: Name](name: N)(using v: ValueOf[N]): LabeledColumn[N, T] =
LabeledColumn[N, T](col.untyped.as(v.value))
inline def as[N <: Name](name: N): LabeledColumn[N, T] =
LabeledColumn[N, T](col.untyped.as(name))
inline def alias[N <: Name](name: N): LabeledColumn[N, T] =
LabeledColumn[N, T](col.untyped.as(name))

extension [T1 <: DataType](col1: Col[T1])
inline def +[T2 <: DataType](col2: Col[T2])(using op: ColumnOp.Plus[T1, T2]): Col[op.Out] = op(col1, col2)
Expand All @@ -60,15 +80,13 @@ object Column:
class Col[+T <: DataType](untyped: UntypedColumn) extends Column(untyped)

@annotation.showAsInfix
class :=[L <: LabeledColumn.Label, T <: DataType](untyped: UntypedColumn)
extends Col[T](untyped)
with NamedColumns[(L := T) *: EmptyTuple](Seq(untyped))
trait :=[L <: LabeledColumn.Label, T <: DataType]

@annotation.showAsInfix
trait /[+Prefix <: Name, +Suffix <: Name]

type LabeledColumn[L <: LabeledColumn.Label, T <: DataType] = :=[L, T]
class LabeledColumn[L <: Name, T <: DataType](untyped: UntypedColumn)
extends NamedColumns[(L := T) *: EmptyTuple](Seq(untyped))

object LabeledColumn:
type Label = Name | (Name / Name)
def apply[L <: LabeledColumn.Label, T <: DataType](untyped: UntypedColumn) = new :=[L, T](untyped)
96 changes: 40 additions & 56 deletions src/main/Grouping.scala
Original file line number Diff line number Diff line change
Expand Up @@ -13,45 +13,38 @@ object GroupBy:

given groupByOps: {} with
extension [View <: SchemaView](groupBy: GroupBy[View])
transparent inline def apply(inline groupingColumns: View ?=> NamedColumns[?]*) = ${ applyImpl[View]('groupBy, 'groupingColumns) }
transparent inline def apply[C <: Repeated[NamedColumns[?]]](groupingColumns: View ?=> C) = ${ applyImpl[View, C]('groupBy, 'groupingColumns) }

def groupByImpl[S : Type](df: Expr[StructDataFrame[S]])(using Quotes): Expr[GroupBy[?]] =
private def groupByImpl[S : Type](df: Expr[StructDataFrame[S]])(using Quotes): Expr[GroupBy[?]] =
import quotes.reflect.asTerm
val viewExpr = StructSchemaView.schemaViewExpr[StructDataFrame[S]]
viewExpr.asTerm.tpe.asType match
case '[SchemaView.Subtype[v]] =>
'{ GroupBy[v](${ viewExpr }.asInstanceOf[v], ${ df }.untyped) }

def applyImpl[View <: SchemaView : Type](groupBy: Expr[GroupBy[View]], groupingColumns: Expr[Seq[View ?=> NamedColumns[?]]])(using Quotes): Expr[GroupedDataFrame[View]] =
private def applyImpl[View <: SchemaView : Type, C : Type](groupBy: Expr[GroupBy[View]], groupingColumns: Expr[View ?=> C])(using Quotes): Expr[GroupedDataFrame[View]] =
import quotes.reflect.*

val columnValuesWithTypes = groupingColumns match
case Varargs(colExprs) =>
colExprs.map { arg =>
val reduced = Term.betaReduce('{$arg(using ${ groupBy }.view)}.asTerm).get
reduced.asExpr match
case '{ $value: NamedColumns[schema] } => ('{ ${ value }.underlyingColumns }, Type.of[schema])
}

val columnsValues = columnValuesWithTypes.map(_._1)
val columnsTypes = columnValuesWithTypes.map(_._2)

val groupedSchemaTpe = FrameSchema.schemaTypeFromColumnsTypes(columnsTypes)
groupedSchemaTpe match
case '[TupleSubtype[groupingKeys]] =>
val groupedViewExpr = StructSchemaView.schemaViewExpr[StructDataFrame[groupingKeys]]

groupedViewExpr.asTerm.tpe.asType match
case '[SchemaView.Subtype[groupedView]] =>
'{
val groupingCols = ${ Expr.ofSeq(columnsValues) }.flatten
new GroupedDataFrame[View]:
type GroupingKeys = groupingKeys
type GroupedView = groupedView
def underlying = ${ groupBy }.underlying.groupBy(groupingCols*)
def fullView = ${ groupBy }.view
def groupedView = ${ groupedViewExpr }.asInstanceOf[GroupedView]
}
Expr.summon[CollectColumns[C]] match
case Some(collectColumns) =>
collectColumns match
case '{ $cc: CollectColumns[?] { type CollectedColumns = collectedColumns } } =>
Type.of[collectedColumns] match
case '[TupleSubtype[collectedCols]] =>
val groupedViewExpr = StructSchemaView.schemaViewExpr[StructDataFrame[collectedCols]]
groupedViewExpr.asTerm.tpe.asType match
case '[SchemaView.Subtype[groupedView]] =>
'{
val groupingCols = ${ cc }.underlyingColumns(${ groupingColumns }(using ${ groupBy }.view))
new GroupedDataFrame[View]:
type GroupingKeys = collectedCols
type GroupedView = groupedView
def underlying = ${ groupBy }.underlying.groupBy(groupingCols*)
def fullView = ${ groupBy }.view
def groupedView = ${ groupedViewExpr }.asInstanceOf[GroupedView]
}
case None =>
throw CollectColumns.CannotCollectColumns(Type.show[C])

// TODO: Rename to RelationalGroupedDataset and handle other aggregations: cube, rollup (and pivot?)
trait GroupedDataFrame[FullView <: SchemaView]:
Expand All @@ -66,13 +59,12 @@ trait GroupedDataFrame[FullView <: SchemaView]:
object GroupedDataFrame:
given groupedDataFrameOps: {} with
extension [FullView <: SchemaView, GroupKeys <: Tuple, GroupView <: SchemaView](gdf: GroupedDataFrame[FullView]{ type GroupedView = GroupView; type GroupingKeys = GroupKeys })
transparent inline def agg(inline columns: (Agg { type View = FullView }, GroupView) ?=> NamedColumns[?]*): StructDataFrame[?] =
${ aggImpl[FullView, GroupKeys, GroupView]('gdf, 'columns) }
transparent inline def agg[C <: Repeated[NamedColumns[?]]](columns: (Agg { type View = FullView }, GroupView) ?=> C): StructDataFrame[?] =
${ aggImpl[FullView, GroupKeys, GroupView, C]('gdf, 'columns) }


def aggImpl[FullView <: SchemaView : Type, GroupingKeys <: Tuple : Type, GroupView <: SchemaView : Type](
private def aggImpl[FullView <: SchemaView : Type, GroupingKeys <: Tuple : Type, GroupView <: SchemaView : Type, C : Type](
gdf: Expr[GroupedDataFrame[FullView] { type GroupedView = GroupView }],
columns: Expr[Seq[(Agg { type View = FullView }, GroupView) ?=> NamedColumns[?]]]
columns: Expr[(Agg { type View = FullView }, GroupView) ?=> C]
)(using Quotes): Expr[StructDataFrame[?]] =
import quotes.reflect.*

Expand All @@ -82,27 +74,19 @@ object GroupedDataFrame:
val view = ${ gdf }.fullView
}

val columnValuesWithTypes = columns match
case Varargs(colExprs) =>
colExprs.map { arg =>
val reduced = Term.betaReduce('{$arg(using ${ aggWrapper }, ${ gdf }.groupedView)}.asTerm).get
reduced.asExpr match
case '{ $value: NamedColumns[schema] } => ('{ ${ value }.underlyingColumns }, Type.of[schema])
}

val columnsValues = columnValuesWithTypes.map(_._1)
val columnsTypes = columnValuesWithTypes.map(_._2)

val schemaTpe = FrameSchema.schemaTypeFromColumnsTypes(columnsTypes)
schemaTpe match
case '[s] =>
'{
// TODO assert cols is not empty
val cols = ${ Expr.ofSeq(columnsValues) }.flatten
StructDataFrame[FrameSchema.Merge[GroupingKeys, s]](
${ gdf }.underlying.agg(cols.head, cols.tail*)
)
}
Expr.summon[CollectColumns[C]] match
case Some(collectColumns) =>
collectColumns match
case '{ $cc: CollectColumns[?] { type CollectedColumns = collectedColumns } } =>
'{
// TODO assert cols is not empty
val cols = ${ cc }.underlyingColumns(${ columns }(using ${ aggWrapper }, ${ gdf }.groupedView))
StructDataFrame[FrameSchema.Merge[GroupingKeys, collectedColumns]](
${ gdf }.underlying.agg(cols.head, cols.tail*)
)
}
case None =>
throw CollectColumns.CannotCollectColumns(Type.show[C])

trait Agg:
type View <: SchemaView
Expand Down
25 changes: 25 additions & 0 deletions src/main/Repeated.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package org.virtuslab.iskra

/** Stand-in for a variable-length argument list passed as a single value:
  * either one `A`, or a homogeneous tuple of 2 up to 22 `A`s.
  *
  * @tparam A the element type of every position
  */
type Repeated[A] =
  A
  | (A, A)
  | (A, A, A)
  | (A, A, A, A)
  | (A, A, A, A, A)
  | (A, A, A, A, A, A)
  | (A, A, A, A, A, A, A)
  | (A, A, A, A, A, A, A, A)
  | (A, A, A, A, A, A, A, A, A)
  | (A, A, A, A, A, A, A, A, A, A)
  | (A, A, A, A, A, A, A, A, A, A, A)
  | (A, A, A, A, A, A, A, A, A, A, A, A)
  | (A, A, A, A, A, A, A, A, A, A, A, A, A)
  | (A, A, A, A, A, A, A, A, A, A, A, A, A, A)
  | (A, A, A, A, A, A, A, A, A, A, A, A, A, A, A)
  | (A, A, A, A, A, A, A, A, A, A, A, A, A, A, A, A)
  | (A, A, A, A, A, A, A, A, A, A, A, A, A, A, A, A, A)
  | (A, A, A, A, A, A, A, A, A, A, A, A, A, A, A, A, A, A)
  | (A, A, A, A, A, A, A, A, A, A, A, A, A, A, A, A, A, A, A)
  | (A, A, A, A, A, A, A, A, A, A, A, A, A, A, A, A, A, A, A, A)
  | (A, A, A, A, A, A, A, A, A, A, A, A, A, A, A, A, A, A, A, A, A)
  | (A, A, A, A, A, A, A, A, A, A, A, A, A, A, A, A, A, A, A, A, A, A) // 22 is the maximal arity listed
56 changes: 37 additions & 19 deletions src/main/SchemaView.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,26 @@ object SchemaView:
case '[StructDataFrame.Subtype[df]] =>
Some(StructSchemaView.schemaViewExpr[df])

trait StructSchemaView extends SchemaView, Selectable:
trait StructuralSchemaView extends SchemaView, Selectable:
def selectDynamic(name: String): AliasedSchemaView | Column

trait StructSchemaView extends StructuralSchemaView:
def frameAliases: Seq[String] // TODO: get rid of this at runtime

// TODO: What should be the semantics of `*`? How to handle ambiguous columns?
// type AllColumns <: Tuple
// def * : AllColumns

def selectDynamic(name: String): AliasedSchemaView | LabeledColumn[?, ?] =
// def selectDynamic(name: String): AliasedSchemaView | LabeledColumn[?, ?] =
// if frameAliases.contains(name)
// then AliasedSchemaView(name)
// else LabeledColumn(col(Name.escape(name)))

override def selectDynamic(name: String): AliasedSchemaView | Column =
if frameAliases.contains(name)
then AliasedSchemaView(name)
else LabeledColumn(col(Name.escape(name)))
else Col[DataType](col(Name.escape(name)))


object StructSchemaView:
type Subtype[T <: StructSchemaView] = T
Expand All @@ -48,14 +57,15 @@ object StructSchemaView:
import quotes.reflect.*
schemaType match
case '[EmptyTuple] => base
case '[LabeledColumn[headLabel, headType] *: tail] => // TODO: get rid of duplicates
val nameType = Type.of[headLabel] match
case '[Name.Subtype[name]] => Type.of[name]
case '[(Name.Subtype[framePrefix], Name.Subtype[name])] => Type.of[name]
case '[(headLabelPrefix / headLabelName := headType) *: tail] => // TODO: get rid of duplicates
val nameType = Type.of[headLabelName] match
case '[Name.Subtype[name]] =>
Type.of[name]
case '[(Name.Subtype[framePrefix], Name.Subtype[name])] =>
Type.of[name]
val name = nameType match
case '[n] => Type.valueOfConstant[n].get.toString
val info = TypeRepr.of[LabeledColumn[headLabel, headType]]
val newBase = Refinement(base, name, info)
val newBase = Refinement(base, name, TypeRepr.of[Col[headType]])
schemaViewType(newBase, Type.of[tail])

// private def reifyColumns[T <: Tuple : Type](using Quotes): Expr[Tuple] = reifyCols(Type.of[T])
Expand All @@ -64,7 +74,7 @@ object StructSchemaView:
// import quotes.reflect.*
// schemaType match
// case '[EmptyTuple] => '{ EmptyTuple }
// case '[LabeledColumn[headLabel1, headType] *: tail] =>
// case '[(headLabel1 := headType) *: tail] =>
// headLabel1 match
// case '[Name.Subtype[name]] => // TODO: handle frame prefixes
// val label = Expr(Type.valueOfConstant[name].get.toString)
Expand Down Expand Up @@ -94,15 +104,19 @@ object StructSchemaView:

def allPrefixedColumns(using Quotes)(schemaType: Type[?]): List[(String, (String, quotes.reflect.TypeRepr))] =
import quotes.reflect.*

schemaType match
case '[EmptyTuple] => List.empty
case '[(Name.Subtype[name] := dataType) *: tail] =>
allPrefixedColumns(Type.of[tail])
case '[(framePrefix / name := dataType) *: tail] =>
val prefix = Type.valueOfConstant[framePrefix].get.toString
val colName = Type.valueOfConstant[name].get.toString
(prefix -> (colName -> TypeRepr.of[name := dataType])) :: allPrefixedColumns(Type.of[tail])
// TODO: Handle Nothing as schemaType (which might appear as propagation of earlier errors)
(prefix -> (colName -> TypeRepr.of[Col[dataType]])) :: allPrefixedColumns(Type.of[tail])

// TODO Show this case to users as propagated error
case _ =>
List.empty

def frameAliasViewsByName(using Quotes)(schemaType: Type[?]): List[(String, quotes.reflect.TypeRepr)] =
import quotes.reflect.*
Expand All @@ -120,16 +134,20 @@ object StructSchemaView:
import quotes.reflect.*
schemaType match
case '[EmptyTuple] => List.empty
case '[LabeledColumn[Name.Subtype[name], dataType] *: tail] =>
case '[(Name.Subtype[name] := dataType) *: tail] =>
val colName = Type.valueOfConstant[name].get.toString
val namedColumn = colName -> TypeRepr.of[LabeledColumn[name, dataType]]
val namedColumn = colName -> TypeRepr.of[Col[dataType]]
namedColumn :: allColumns(Type.of[tail])
case '[LabeledColumn[Name.Subtype[framePrefix] / Name.Subtype[name], dataType] *: tail] =>
case '[((Name.Subtype[framePrefix] / Name.Subtype[name]) := dataType) *: tail] =>
val colName = Type.valueOfConstant[name].get.toString
val namedColumn = colName -> TypeRepr.of[LabeledColumn[name, dataType]]
val namedColumn = colName -> TypeRepr.of[Col[dataType]]
namedColumn :: allColumns(Type.of[tail])

class AliasedSchemaView(frameAliasName: String) extends Selectable:
def selectDynamic(name: String): LabeledColumn[Name, DataType] =
// TODO Show this case to users as propagated error
case _ =>
List.empty

class AliasedSchemaView(frameAliasName: String) extends StructuralSchemaView:
override def selectDynamic(name: String): Column =
val columnName = s"${Name.escape(frameAliasName)}.${Name.escape(name)}"
LabeledColumn[Name, DataType](col(columnName))
Col[DataType](col(columnName))
Loading

0 comments on commit f0753a4

Please sign in to comment.