Skip to content

Commit

Permalink
feat(arrow/compute): make is_nan dispatchable (#177)
Browse files Browse the repository at this point in the history
Currently, the `is_nan` compute kernel is a `MetaFunction`, which cannot
be dispatched via kernel dispatch, so it is only usable by calling it
directly with `CallFunction`. Turning it into a proper scalar function
instead of a `MetaFunction` improves its compatibility: it can now be
dispatched, and therefore called through the Substrait interface in
the `exprs` package.
  • Loading branch information
zeroshade authored Oct 29, 2024
1 parent 8e2553a commit f0c5d99
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 34 deletions.
52 changes: 52 additions & 0 deletions arrow/compute/internal/kernels/scalar_comparisons.go
Original file line number Diff line number Diff line change
Expand Up @@ -745,3 +745,55 @@ func IsNullNotNullKernels() []exec.ScalarKernel {

return results
}

// ConstBoolExec returns a kernel exec function that ignores the input values
// and writes the constant val into the output data bitmap.
//
// The returned function sets batch.Len bits of out.Buffers[1].Buf, starting at
// bit position out.Offset, to val. It never reads the kernel context and
// always returns a nil error.
func ConstBoolExec(val bool) func(*exec.KernelCtx, *exec.ExecSpan, *exec.ExecResult) error {
	fill := func(_ *exec.KernelCtx, span *exec.ExecSpan, res *exec.ExecResult) error {
		// Buffers[1] is the buffer this kernel fills; only the span's length
		// is taken from the input, never its contents.
		bitmap := res.Buffers[1].Buf
		bitutil.SetBitsTo(bitmap, res.Offset, span.Len, val)
		return nil
	}
	return fill
}

// isNanKernelExec is the exec function for the floating point is_nan kernels.
//
// It fetches the comparison functions stored in the kernel's Data field and
// runs the array/array variant with the input bytes supplied as both operands.
// IsNaNKernels registers that data with CmpNE, so each value is compared
// against itself for inequality; NaN is the only floating point value that is
// not equal to itself.
//
// The type parameter T is not referenced in the body; it only selects the
// instantiation that is paired with the matching compare kernel data.
//
// The type assertions are unchecked, so this panics if ctx.Kernel is not a
// *exec.ScalarKernel or its Data is not a CompareFuncData.
func isNanKernelExec[T float32 | float64](ctx *exec.KernelCtx, batch *exec.ExecSpan, out *exec.ExecResult) error {
	fns := ctx.Kernel.(*exec.ScalarKernel).Data.(CompareFuncData).Funcs()

	const bitsPerByte = 8
	// Split the output bit offset into a whole-byte part, used to slice the
	// bitmap, and the leftover bit count that is handed to the compare func.
	bitPrefix := int(out.Offset % bitsPerByte)
	bitmap := out.Buffers[1].Buf[out.Offset/bitsPerByte:]

	// The same bytes serve as both operands of the comparison.
	values := getOffsetSpanBytes(&batch.Values[0].Array)
	fns.funcAA(values, values, bitmap, bitPrefix)
	return nil
}

// IsNaNKernels returns the scalar kernels that implement the is_nan function.
//
// The returned slice contains, in order:
//   - a Float32 and a Float64 kernel that run isNanKernelExec with a CmpNE
//     compare kernel as their data,
//   - one kernel per entry of intTypes, and
//   - one kernel per type ID in NULL, DURATION, DECIMAL32, DECIMAL64,
//     DECIMAL128 and DECIMAL256,
//
// where every non-floating-point kernel uses ConstBoolExec(false). All kernels
// produce a Boolean output and use exec.NullNoOutput null handling.
func IsNaNKernels() []exec.ScalarKernel {
	boolOut := exec.NewOutputType(arrow.FixedWidthTypes.Boolean)

	// constFalse builds the kernel used for every input kind that is answered
	// with a constant false instead of inspecting the values.
	constFalse := func(in exec.InputType) exec.ScalarKernel {
		k := exec.NewScalarKernel([]exec.InputType{in}, boolOut, ConstBoolExec(false), nil)
		k.NullHandling = exec.NullNoOutput
		return k
	}

	f32 := exec.NewScalarKernel(
		[]exec.InputType{exec.NewExactInput(arrow.PrimitiveTypes.Float32)},
		boolOut, isNanKernelExec[float32], nil)
	f32.Data = genCompareKernel[float32](CmpNE)
	f32.NullHandling = exec.NullNoOutput

	f64 := exec.NewScalarKernel(
		[]exec.InputType{exec.NewExactInput(arrow.PrimitiveTypes.Float64)},
		boolOut, isNanKernelExec[float64], nil)
	f64.Data = genCompareKernel[float64](CmpNE)
	f64.NullHandling = exec.NullNoOutput

	result := []exec.ScalarKernel{f32, f64}

	// Integer inputs are matched by exact data type.
	for _, dt := range intTypes {
		result = append(result, constFalse(exec.NewExactInput(dt)))
	}

	// The remaining inputs are matched by type ID rather than by exact type.
	constFalseIDs := []arrow.Type{
		arrow.NULL,
		arrow.DURATION,
		arrow.DECIMAL32,
		arrow.DECIMAL64,
		arrow.DECIMAL128,
		arrow.DECIMAL256,
	}
	for _, id := range constFalseIDs {
		result = append(result, constFalse(exec.NewIDInput(id)))
	}

	return result
}
39 changes: 7 additions & 32 deletions arrow/compute/scalar_compare.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,10 @@ package compute

import (
"context"
"fmt"

"github.com/apache/arrow-go/v18/arrow"
"github.com/apache/arrow-go/v18/arrow/compute/exec"
"github.com/apache/arrow-go/v18/arrow/compute/internal/kernels"
"github.com/apache/arrow-go/v18/arrow/scalar"
)

type compareFunction struct {
Expand Down Expand Up @@ -152,34 +150,11 @@ func RegisterScalarComparisons(reg FunctionRegistry) {
reg.AddFunction(isNullFn, false)
reg.AddFunction(isNotNullFn, false)

reg.AddFunction(NewMetaFunction("is_nan", Unary(), EmptyFuncDoc,
func(ctx context.Context, opts FunctionOptions, args ...Datum) (Datum, error) {
type hasType interface {
Type() arrow.DataType
}

// only Scalar, Array and ChunkedArray have a Type method
arg, ok := args[0].(hasType)
if !ok {
// don't support Table/Record/None kinds
return nil, fmt.Errorf("%w: unsupported type for is_nan %s",
arrow.ErrNotImplemented, args[0])
}

switch arg.Type() {
case arrow.PrimitiveTypes.Float32, arrow.PrimitiveTypes.Float64:
return CallFunction(ctx, "not_equal", nil, args[0], args[0])
default:
if arg, ok := args[0].(ArrayLikeDatum); ok {
result, err := scalar.MakeArrayFromScalar(scalar.NewBooleanScalar(false),
int(arg.Len()), GetAllocator(ctx))
if err != nil {
return nil, err
}
return NewDatumWithoutOwning(result), nil
}

return NewDatum(false), nil
}
}), false)
isNaNFn := &compareFunction{*NewScalarFunction("is_nan", Unary(), EmptyFuncDoc)}
for _, k := range kernels.IsNaNKernels() {
if err := isNaNFn.AddKernel(k); err != nil {
panic(err)
}
}
reg.AddFunction(isNaNFn, false)
}
4 changes: 2 additions & 2 deletions arrow/compute/scalar_compare_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1497,8 +1497,8 @@ func (sv *ScalarValiditySuite) TestIsNaN() {
}{
{`[]`, `[]`},
{`[1]`, `[false]`},
{`[null]`, `[null]`},
{`["NaN", 1, 0, null]`, `[true, false, false, null]`},
{`[null]`, `[false]`},
{`["NaN", 1, 0, null]`, `[true, false, false, false]`},
}

for _, typ := range floatingTypes {
Expand Down

0 comments on commit f0c5d99

Please sign in to comment.