Skip to content

Commit

Permalink
fix: compute fixes for extension types
Browse files Browse the repository at this point in the history
  • Loading branch information
zeroshade committed Oct 24, 2024
1 parent 56b794f commit c779c38
Show file tree
Hide file tree
Showing 6 changed files with 43 additions and 4 deletions.
21 changes: 20 additions & 1 deletion arrow/compute/exprs/exec.go
Original file line number Diff line number Diff line change
Expand Up @@ -571,6 +571,25 @@ func executeScalarBatch(ctx context.Context, input compute.ExecBatch, exp expr.E
return nil, err
}

var newArgs []compute.Datum
// cast arguments if necessary
for i, arg := range args {
if !arrow.TypeEqual(argTypes[i], arg.(compute.ArrayLikeDatum).Type()) {
if newArgs == nil {
newArgs = make([]compute.Datum, len(args))
copy(newArgs, args)
}
newArgs[i], err = compute.CastDatum(ctx, arg, compute.SafeCastOptions(argTypes[i]))
if err != nil {
return nil, err
}
defer newArgs[i].Release()
}
}
if newArgs != nil {
args = newArgs
}

kctx := &exec.KernelCtx{Ctx: ctx, Kernel: k}
init := k.GetInitFn()
kinitArgs := exec.KernelInitArgs{Kernel: k, Inputs: argTypes, Options: opts}
Expand Down Expand Up @@ -613,7 +632,7 @@ func executeScalarBatch(ctx context.Context, input compute.ExecBatch, exp expr.E
result.Release()
}

return result, nil
return result, err
}

return nil, arrow.ErrNotImplemented
Expand Down
8 changes: 6 additions & 2 deletions arrow/compute/exprs/extension_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ func (ef *simpleExtensionTypeFactory[P]) ExtensionEquals(other arrow.ExtensionTy
return ef.params == rhs.params
}
func (ef *simpleExtensionTypeFactory[P]) ArrayType() reflect.Type {
return reflect.TypeOf(array.ExtensionArrayBase{})
return reflect.TypeOf(simpleExtensionArrayFactory[P]{})
}

func (ef *simpleExtensionTypeFactory[P]) CreateType(params P) arrow.DataType {
Expand All @@ -91,10 +91,14 @@ func (ef *simpleExtensionTypeFactory[P]) CreateType(params P) arrow.DataType {
}
}

// simpleExtensionArrayFactory is the array struct whose reflect.Type is
// reported by simpleExtensionTypeFactory[P].ArrayType. It declares no fields
// of its own beyond the embedded array.ExtensionArrayBase; the type parameter
// P is not referenced by any field.
type simpleExtensionArrayFactory[P comparable] struct {
array.ExtensionArrayBase
}

// uuidExtParams is the parameter type used to instantiate the uuid
// simpleExtensionTypeFactory; it is an empty struct and carries no values.
type uuidExtParams struct{}

var uuidType = simpleExtensionTypeFactory[uuidExtParams]{
name: "uuid", getStorage: func(uuidExtParams) arrow.DataType {
name: "arrow.uuid", getStorage: func(uuidExtParams) arrow.DataType {
return &arrow.FixedSizeBinaryType{ByteWidth: 16}
}}

Expand Down
1 change: 1 addition & 0 deletions arrow/compute/scalar_compare.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ func (fn *compareFunction) DispatchBest(vals ...arrow.DataType) (exec.Kernel, er
}

ensureDictionaryDecoded(vals...)
ensureNotExtensionType(vals...)
replaceNullWithOtherType(vals...)

if dt := commonNumeric(vals...); dt != nil {
Expand Down
8 changes: 8 additions & 0 deletions arrow/compute/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,14 @@ func ensureDictionaryDecoded(vals ...arrow.DataType) {
}
}

// ensureNotExtensionType rewrites vals in place, swapping every entry whose
// type ID is arrow.EXTENSION for the storage type it reports. Entries with any
// other ID are left untouched.
func ensureNotExtensionType(vals ...arrow.DataType) {
	for idx := range vals {
		if vals[idx].ID() != arrow.EXTENSION {
			continue
		}

		ext := vals[idx].(arrow.ExtensionType)
		vals[idx] = ext.StorageType()
	}
}

func replaceNullWithOtherType(vals ...arrow.DataType) {
debug.Assert(len(vals) == 2, "should be length 2")

Expand Down
4 changes: 4 additions & 0 deletions arrow/extensions/uuid.go
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,10 @@ func (*UUIDType) ExtensionName() string {
return "arrow.uuid"
}

// Bytes reports the size of a UUID value in bytes, which is always 16.
func (*UUIDType) Bytes() int {
	const uuidByteLen = 16
	return uuidByteLen
}

// BitWidth reports the size of a UUID value in bits, which is always 128.
func (*UUIDType) BitWidth() int {
	const uuidBitLen = 128
	return uuidBitLen
}

// String renders the type as "extension<NAME>", where NAME is the value
// returned by ExtensionName.
func (e *UUIDType) String() string {
	name := e.ExtensionName()
	return fmt.Sprintf("extension<%s>", name)
}
Expand Down
5 changes: 4 additions & 1 deletion parquet/pqarrow/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (

"github.com/apache/arrow-go/v18/arrow"
"github.com/apache/arrow-go/v18/arrow/decimal128"
"github.com/apache/arrow-go/v18/arrow/extensions"
"github.com/apache/arrow-go/v18/arrow/flight"
"github.com/apache/arrow-go/v18/arrow/memory"
"github.com/apache/arrow-go/v18/parquet"
Expand Down Expand Up @@ -514,8 +515,10 @@ func arrowFromFLBA(logical schema.LogicalType, length int) (arrow.DataType, erro
switch logtype := logical.(type) {
case schema.DecimalLogicalType:
return arrowDecimal(logtype), nil
case schema.NoLogicalType, schema.IntervalLogicalType, schema.UUIDLogicalType:
case schema.NoLogicalType, schema.IntervalLogicalType:
return &arrow.FixedSizeBinaryType{ByteWidth: int(length)}, nil
case schema.UUIDLogicalType:
return extensions.NewUUIDType(), nil
case schema.Float16LogicalType:
return &arrow.Float16Type{}, nil
default:
Expand Down

0 comments on commit c779c38

Please sign in to comment.