Skip to content

Commit

Permalink
fix(arrow/compute): compare kernels with UUID (#174)
Browse files Browse the repository at this point in the history
Split from #171 

Enable using the comparison kernels (equal, less, less_equal, greater,
greater_equal) with UUID columns and extension types in general.

Tests are added to check the kernel dispatch and to ensure compute via
substrait works for UUID type scalars.
  • Loading branch information
zeroshade authored Oct 26, 2024
1 parent fe4bd93 commit 18d6677
Show file tree
Hide file tree
Showing 7 changed files with 64 additions and 5 deletions.
22 changes: 21 additions & 1 deletion arrow/compute/exprs/exec.go
Original file line number Diff line number Diff line change
Expand Up @@ -571,6 +571,25 @@ func executeScalarBatch(ctx context.Context, input compute.ExecBatch, exp expr.E
return nil, err
}

var newArgs []compute.Datum
// cast arguments if necessary
for i, arg := range args {
if !arrow.TypeEqual(argTypes[i], arg.(compute.ArrayLikeDatum).Type()) {
if newArgs == nil {
newArgs = make([]compute.Datum, len(args))
copy(newArgs, args)
}
newArgs[i], err = compute.CastDatum(ctx, arg, compute.SafeCastOptions(argTypes[i]))
if err != nil {
return nil, err
}
defer newArgs[i].Release()
}
}
if newArgs != nil {
args = newArgs
}

kctx := &exec.KernelCtx{Ctx: ctx, Kernel: k}
init := k.GetInitFn()
kinitArgs := exec.KernelInitArgs{Kernel: k, Inputs: argTypes, Options: opts}
Expand Down Expand Up @@ -611,9 +630,10 @@ func executeScalarBatch(ctx context.Context, input compute.ExecBatch, exp expr.E

if ctx.Err() == context.Canceled && result != nil {
result.Release()
result = nil
}

return result, nil
return result, err
}

return nil, arrow.ErrNotImplemented
Expand Down
23 changes: 21 additions & 2 deletions arrow/compute/exprs/exec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,10 @@ import (
"github.com/apache/arrow-go/v18/arrow/array"
"github.com/apache/arrow-go/v18/arrow/compute"
"github.com/apache/arrow-go/v18/arrow/compute/exprs"
"github.com/apache/arrow-go/v18/arrow/extensions"
"github.com/apache/arrow-go/v18/arrow/memory"
"github.com/apache/arrow-go/v18/arrow/scalar"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/substrait-io/substrait-go/expr"
Expand Down Expand Up @@ -135,8 +137,16 @@ func TestComparisons(t *testing.T) {
one = scalar.MakeScalar(int32(1))
two = scalar.MakeScalar(int32(2))

str = scalar.MakeScalar("hello")
bin = scalar.MakeScalar([]byte("hello"))
str = scalar.MakeScalar("hello")
bin = scalar.MakeScalar([]byte("hello"))
exampleUUID = uuid.MustParse("102cb62f-e6f8-4eb0-9973-d9b012ff0967")
exampleUUID2 = uuid.MustParse("c1b0d8e0-0b0e-4b1e-9b0a-0e0b0d0c0a0b")
uuidStorage, _ = scalar.MakeScalarParam(exampleUUID[:],
&arrow.FixedSizeBinaryType{ByteWidth: 16})
uuidScalar = scalar.NewExtensionScalar(uuidStorage, extensions.NewUUIDType())
uuidStorage2, _ = scalar.MakeScalarParam(exampleUUID2[:],
&arrow.FixedSizeBinaryType{ByteWidth: 16})
uuidScalar2 = scalar.NewExtensionScalar(uuidStorage2, extensions.NewUUIDType())
)

getArgType := func(dt arrow.DataType) types.Type {
Expand All @@ -147,6 +157,8 @@ func TestComparisons(t *testing.T) {
return &types.StringType{}
case arrow.BINARY:
return &types.BinaryType{}
case arrow.EXTENSION:
return &types.UUIDType{}
}
panic("wtf")
}
Expand Down Expand Up @@ -190,6 +202,13 @@ func TestComparisons(t *testing.T) {

expect(t, "equal", str, bin, true)
expect(t, "equal", bin, str, true)

expect(t, "equal", uuidScalar, uuidScalar, true)
expect(t, "equal", uuidScalar, uuidScalar2, false)
expect(t, "less", uuidScalar, uuidScalar2, true)
expect(t, "less", uuidScalar2, uuidScalar, false)
expect(t, "greater", uuidScalar, uuidScalar2, false)
expect(t, "greater", uuidScalar2, uuidScalar, true)
}

func TestExecuteFieldRef(t *testing.T) {
Expand Down
8 changes: 6 additions & 2 deletions arrow/compute/exprs/extension_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ func (ef *simpleExtensionTypeFactory[P]) ExtensionEquals(other arrow.ExtensionTy
return ef.params == rhs.params
}
func (ef *simpleExtensionTypeFactory[P]) ArrayType() reflect.Type {
return reflect.TypeOf(array.ExtensionArrayBase{})
return reflect.TypeOf(simpleExtensionArrayFactory[P]{})
}

func (ef *simpleExtensionTypeFactory[P]) CreateType(params P) arrow.DataType {
Expand All @@ -91,10 +91,14 @@ func (ef *simpleExtensionTypeFactory[P]) CreateType(params P) arrow.DataType {
}
}

type simpleExtensionArrayFactory[P comparable] struct {
array.ExtensionArrayBase
}

type uuidExtParams struct{}

var uuidType = simpleExtensionTypeFactory[uuidExtParams]{
name: "uuid", getStorage: func(uuidExtParams) arrow.DataType {
name: "arrow.uuid", getStorage: func(uuidExtParams) arrow.DataType {
return &arrow.FixedSizeBinaryType{ByteWidth: 16}
}}

Expand Down
1 change: 1 addition & 0 deletions arrow/compute/scalar_compare.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ func (fn *compareFunction) DispatchBest(vals ...arrow.DataType) (exec.Kernel, er
}

ensureDictionaryDecoded(vals...)
ensureNoExtensionType(vals...)
replaceNullWithOtherType(vals...)

if dt := commonNumeric(vals...); dt != nil {
Expand Down
2 changes: 2 additions & 0 deletions arrow/compute/scalar_compare_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import (
"github.com/apache/arrow-go/v18/arrow/compute"
"github.com/apache/arrow-go/v18/arrow/compute/exec"
"github.com/apache/arrow-go/v18/arrow/compute/internal/kernels"
"github.com/apache/arrow-go/v18/arrow/extensions"
"github.com/apache/arrow-go/v18/arrow/internal/testing/gen"
"github.com/apache/arrow-go/v18/arrow/memory"
"github.com/apache/arrow-go/v18/arrow/scalar"
Expand Down Expand Up @@ -1289,6 +1290,7 @@ func TestCompareKernelsDispatchBest(t *testing.T) {
&arrow.Decimal128Type{Precision: 3, Scale: 2}, &arrow.Decimal128Type{Precision: 21, Scale: 2}},
{arrow.PrimitiveTypes.Int64, &arrow.Decimal128Type{Precision: 3, Scale: 2},
&arrow.Decimal128Type{Precision: 21, Scale: 2}, &arrow.Decimal128Type{Precision: 3, Scale: 2}},
{extensions.NewUUIDType(), extensions.NewUUIDType(), &arrow.FixedSizeBinaryType{ByteWidth: 16}, &arrow.FixedSizeBinaryType{ByteWidth: 16}},
}

for _, name := range []string{"equal", "not_equal", "less", "less_equal", "greater", "greater_equal"} {
Expand Down
8 changes: 8 additions & 0 deletions arrow/compute/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,14 @@ func ensureDictionaryDecoded(vals ...arrow.DataType) {
}
}

func ensureNoExtensionType(vals ...arrow.DataType) {
for i, v := range vals {
if v.ID() == arrow.EXTENSION {
vals[i] = v.(arrow.ExtensionType).StorageType()
}
}
}

func replaceNullWithOtherType(vals ...arrow.DataType) {
debug.Assert(len(vals) == 2, "should be length 2")

Expand Down
5 changes: 5 additions & 0 deletions arrow/extensions/uuid.go
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,9 @@ func (*UUIDType) ExtensionName() string {
return "arrow.uuid"
}

func (*UUIDType) Bytes() int { return 16 }
func (*UUIDType) BitWidth() int { return 128 }

func (e *UUIDType) String() string {
return fmt.Sprintf("extension<%s>", e.ExtensionName())
}
Expand Down Expand Up @@ -262,4 +265,6 @@ var (
_ array.CustomExtensionBuilder = (*UUIDType)(nil)
_ array.ExtensionArray = (*UUIDArray)(nil)
_ array.Builder = (*UUIDBuilder)(nil)

_ arrow.FixedWidthDataType = (*UUIDType)(nil)
)

0 comments on commit 18d6677

Please sign in to comment.