Skip to content

Commit

Permalink
Separate supertype Column from more specific Col[T] with particular data type
Browse files Browse the repository at this point in the history
  • Loading branch information
prolativ committed Jul 4, 2024
1 parent dd486a8 commit cf90acc
Show file tree
Hide file tree
Showing 9 changed files with 36 additions and 35 deletions.
37 changes: 19 additions & 18 deletions src/main/Column.scala
Original file line number Diff line number Diff line change
Expand Up @@ -32,35 +32,36 @@ object Columns:
new NamedColumns[s](cols) {}
}

class Column[+T <: DataType](val untyped: UntypedColumn):

abstract class Column(val untyped: UntypedColumn):
inline def name(using v: ValueOf[Name]): Name = v.value

object Column:
extension [T <: DataType](col: Column[T])
extension [T <: DataType](col: Col[T])
inline def as[N <: Name](name: N)(using v: ValueOf[N]): LabeledColumn[N, T] =
LabeledColumn[N, T](col.untyped.as(v.value))
inline def alias[N <: Name](name: N)(using v: ValueOf[N]): LabeledColumn[N, T] =
LabeledColumn[N, T](col.untyped.as(v.value))

extension [T1 <: DataType](col1: Column[T1])
inline def +[T2 <: DataType](col2: Column[T2])(using op: ColumnOp.Plus[T1, T2]): Column[op.Out] = op(col1, col2)
inline def -[T2 <: DataType](col2: Column[T2])(using op: ColumnOp.Minus[T1, T2]): Column[op.Out] = op(col1, col2)
inline def *[T2 <: DataType](col2: Column[T2])(using op: ColumnOp.Mult[T1, T2]): Column[op.Out] = op(col1, col2)
inline def /[T2 <: DataType](col2: Column[T2])(using op: ColumnOp.Div[T1, T2]): Column[op.Out] = op(col1, col2)
inline def ++[T2 <: DataType](col2: Column[T2])(using op: ColumnOp.PlusPlus[T1, T2]): Column[op.Out] = op(col1, col2)
inline def <[T2 <: DataType](col2: Column[T2])(using op: ColumnOp.Lt[T1, T2]): Column[op.Out] = op(col1, col2)
inline def <=[T2 <: DataType](col2: Column[T2])(using op: ColumnOp.Le[T1, T2]): Column[op.Out] = op(col1, col2)
inline def >[T2 <: DataType](col2: Column[T2])(using op: ColumnOp.Gt[T1, T2]): Column[op.Out] = op(col1, col2)
inline def >=[T2 <: DataType](col2: Column[T2])(using op: ColumnOp.Ge[T1, T2]): Column[op.Out] = op(col1, col2)
inline def ===[T2 <: DataType](col2: Column[T2])(using op: ColumnOp.Eq[T1, T2]): Column[op.Out] = op(col1, col2)
inline def =!=[T2 <: DataType](col2: Column[T2])(using op: ColumnOp.Ne[T1, T2]): Column[op.Out] = op(col1, col2)
inline def &&[T2 <: DataType](col2: Column[T2])(using op: ColumnOp.And[T1, T2]): Column[op.Out] = op(col1, col2)
inline def ||[T2 <: DataType](col2: Column[T2])(using op: ColumnOp.Or[T1, T2]): Column[op.Out] = op(col1, col2)
extension [T1 <: DataType](col1: Col[T1])
inline def +[T2 <: DataType](col2: Col[T2])(using op: ColumnOp.Plus[T1, T2]): Col[op.Out] = op(col1, col2)
inline def -[T2 <: DataType](col2: Col[T2])(using op: ColumnOp.Minus[T1, T2]): Col[op.Out] = op(col1, col2)
inline def *[T2 <: DataType](col2: Col[T2])(using op: ColumnOp.Mult[T1, T2]): Col[op.Out] = op(col1, col2)
inline def /[T2 <: DataType](col2: Col[T2])(using op: ColumnOp.Div[T1, T2]): Col[op.Out] = op(col1, col2)
inline def ++[T2 <: DataType](col2: Col[T2])(using op: ColumnOp.PlusPlus[T1, T2]): Col[op.Out] = op(col1, col2)
inline def <[T2 <: DataType](col2: Col[T2])(using op: ColumnOp.Lt[T1, T2]): Col[op.Out] = op(col1, col2)
inline def <=[T2 <: DataType](col2: Col[T2])(using op: ColumnOp.Le[T1, T2]): Col[op.Out] = op(col1, col2)
inline def >[T2 <: DataType](col2: Col[T2])(using op: ColumnOp.Gt[T1, T2]): Col[op.Out] = op(col1, col2)
inline def >=[T2 <: DataType](col2: Col[T2])(using op: ColumnOp.Ge[T1, T2]): Col[op.Out] = op(col1, col2)
inline def ===[T2 <: DataType](col2: Col[T2])(using op: ColumnOp.Eq[T1, T2]): Col[op.Out] = op(col1, col2)
inline def =!=[T2 <: DataType](col2: Col[T2])(using op: ColumnOp.Ne[T1, T2]): Col[op.Out] = op(col1, col2)
inline def &&[T2 <: DataType](col2: Col[T2])(using op: ColumnOp.And[T1, T2]): Col[op.Out] = op(col1, col2)
inline def ||[T2 <: DataType](col2: Col[T2])(using op: ColumnOp.Or[T1, T2]): Col[op.Out] = op(col1, col2)

class Col[+T <: DataType](untyped: UntypedColumn) extends Column(untyped)

@annotation.showAsInfix
class :=[L <: LabeledColumn.Label, T <: DataType](untyped: UntypedColumn)
extends Column[T](untyped)
extends Col[T](untyped)
with NamedColumns[(L := T) *: EmptyTuple](Seq(untyped))

@annotation.showAsInfix
Expand Down
2 changes: 1 addition & 1 deletion src/main/ColumnOp.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package org.virtuslab.iskra
import scala.quoted.*
import org.apache.spark.sql
import org.apache.spark.sql.functions.concat
import org.virtuslab.iskra.{Column as Col}
import org.virtuslab.iskra.Col
import org.virtuslab.iskra.UntypedOps.typed
import org.virtuslab.iskra.types.*
import DataType.*
Expand Down
2 changes: 1 addition & 1 deletion src/main/JoinOnCondition.scala
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ object JoinOnCondition:
import quotes.reflect.*

'{ ${ condition }(using ${ joiningView }) } match
case '{ $cond: Column[BooleanOptType] } =>
case '{ $cond: Col[BooleanOptType] } =>
'{
val joined = ${ join }.left.join(${ join }.right, ${ cond }.untyped, JoinType.typeName[T])
StructDataFrame[JoinedSchema](joined)
Expand Down
2 changes: 1 addition & 1 deletion src/main/SchemaView.scala
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ object StructSchemaView:
// headLabel1 match
// case '[Name.Subtype[name]] => // TODO: handle frame prefixes
// val label = Expr(Type.valueOfConstant[name].get.toString)
// '{ Column[Nothing](col(Name.escape(${ label }))) *: ${ reifyCols(Type.of[tail]) } }
// '{ Col[Nothing](col(Name.escape(${ label }))) *: ${ reifyCols(Type.of[tail]) } }

def schemaViewExpr[DF <: StructDataFrame[?] : Type](using Quotes): Expr[StructSchemaView] =
import quotes.reflect.*
Expand Down
2 changes: 1 addition & 1 deletion src/main/UntypedOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import types.{DataType, Encoder, StructType, StructEncoder}

object UntypedOps:
extension (untyped: UntypedColumn)
def typed[A <: DataType] = Column[A](untyped)
def typed[A <: DataType] = Col[A](untyped)

extension (df: UntypedDataFrame)
transparent inline def typed[A](using encoder: StructEncoder[A]): ClassDataFrame[?] = ${ typedDataFrameImpl('df, 'encoder) } // TODO: Check schema at runtime? Check if names of columns match?
Expand Down
10 changes: 5 additions & 5 deletions src/main/When.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@ import org.apache.spark.sql.{functions => f, Column => UntypedColumn}
import org.virtuslab.iskra.types.{Coerce, DataType, BooleanOptType}

object When:
class WhenColumn[T <: DataType](untyped: UntypedColumn) extends Column[DataType.Nullable[T]](untyped):
def when[U <: DataType](condition: Column[BooleanOptType], value: Column[U])(using coerce: Coerce[T, U]): WhenColumn[coerce.Coerced] =
class WhenColumn[T <: DataType](untyped: UntypedColumn) extends Col[DataType.Nullable[T]](untyped):
def when[U <: DataType](condition: Col[BooleanOptType], value: Col[U])(using coerce: Coerce[T, U]): WhenColumn[coerce.Coerced] =
WhenColumn(this.untyped.when(condition.untyped, value.untyped))
def otherwise[U <: DataType](value: Column[U])(using coerce: Coerce[T, U]): Column[coerce.Coerced] =
Column(this.untyped.otherwise(value.untyped))
def otherwise[U <: DataType](value: Col[U])(using coerce: Coerce[T, U]): Col[coerce.Coerced] =
Col(this.untyped.otherwise(value.untyped))

def when[T <: DataType](condition: Column[BooleanOptType], value: Column[T]): WhenColumn[T] =
def when[T <: DataType](condition: Col[BooleanOptType], value: Col[T]): WhenColumn[T] =
WhenColumn(f.when(condition.untyped, value.untyped))
2 changes: 1 addition & 1 deletion src/main/Where.scala
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ object Where:
import quotes.reflect.*

'{ ${ condition }(using ${ where }.view) } match
case '{ $cond: Column[BooleanOptType] } =>
case '{ $cond: Col[BooleanOptType] } =>
'{
val filtered = ${ where }.underlying.where(${ cond }.untyped)
StructDataFrame[Schema](filtered)
Expand Down
10 changes: 5 additions & 5 deletions src/main/functions/aggregates.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,25 @@ package org.virtuslab.iskra.functions

import org.apache.spark.sql
import org.virtuslab.iskra.Agg
import org.virtuslab.iskra.Column
import org.virtuslab.iskra.Col
import org.virtuslab.iskra.UntypedOps.typed
import org.virtuslab.iskra.types.*
import org.virtuslab.iskra.types.DataType.{NumericOptType, Nullable}

class Sum[A <: Agg](val agg: A):
def apply[T <: NumericOptType](column: agg.View ?=> Column[T]): Column[Nullable[T]] =
def apply[T <: NumericOptType](column: agg.View ?=> Col[T]): Col[Nullable[T]] =
sql.functions.sum(column(using agg.view).untyped).typed

class Max[A <: Agg](val agg: A):
def apply[T <: NumericOptType](column: agg.View ?=> Column[T]): Column[Nullable[T]] =
def apply[T <: NumericOptType](column: agg.View ?=> Col[T]): Col[Nullable[T]] =
sql.functions.max(column(using agg.view).untyped).typed

class Min[A <: Agg](val agg: A):
def apply[T <: NumericOptType](column: agg.View ?=> Column[T]): Column[Nullable[T]] =
def apply[T <: NumericOptType](column: agg.View ?=> Col[T]): Col[Nullable[T]] =
sql.functions.min(column(using agg.view).untyped).typed

class Avg[A <: Agg](val agg: A):
def apply(column: agg.View ?=> Column[NumericOptType]): Column[DoubleOptType] =
def apply(column: agg.View ?=> Col[NumericOptType]): Col[DoubleOptType] =
sql.functions.avg(column(using agg.view).untyped).typed

object Aggregates:
Expand Down
4 changes: 2 additions & 2 deletions src/main/functions/lit.scala
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package org.virtuslab.iskra.functions

import org.apache.spark.sql
import org.virtuslab.iskra.Column
import org.virtuslab.iskra.Col
import org.virtuslab.iskra.types.PrimitiveEncoder

def lit[A](value: A)(using encoder: PrimitiveEncoder[A]): Column[encoder.ColumnType] = Column(sql.functions.lit(encoder.encode(value)))
def lit[A](value: A)(using encoder: PrimitiveEncoder[A]): Col[encoder.ColumnType] = Col(sql.functions.lit(encoder.encode(value)))

0 comments on commit cf90acc

Please sign in to comment.