diff --git a/arrow/compute/internal/kernels/scalar_comparisons.go b/arrow/compute/internal/kernels/scalar_comparisons.go index b30605bf..e4a50540 100644 --- a/arrow/compute/internal/kernels/scalar_comparisons.go +++ b/arrow/compute/internal/kernels/scalar_comparisons.go @@ -745,3 +745,55 @@ func IsNullNotNullKernels() []exec.ScalarKernel { return results } + +func ConstBoolExec(val bool) func(*exec.KernelCtx, *exec.ExecSpan, *exec.ExecResult) error { + return func(ctx *exec.KernelCtx, batch *exec.ExecSpan, out *exec.ExecResult) error { + bitutil.SetBitsTo(out.Buffers[1].Buf, out.Offset, batch.Len, val) + return nil + } +} + +func isNanKernelExec[T float32 | float64](ctx *exec.KernelCtx, batch *exec.ExecSpan, out *exec.ExecResult) error { + kn := ctx.Kernel.(*exec.ScalarKernel) + knData := kn.Data.(CompareFuncData).Funcs() + + outPrefix := int(out.Offset % 8) + outBuf := out.Buffers[1].Buf[out.Offset/8:] + + inputBytes := getOffsetSpanBytes(&batch.Values[0].Array) + knData.funcAA(inputBytes, inputBytes, outBuf, outPrefix) + return nil +} + +func IsNaNKernels() []exec.ScalarKernel { + outputType := exec.NewOutputType(arrow.FixedWidthTypes.Boolean) + + knFloat32 := exec.NewScalarKernel([]exec.InputType{exec.NewExactInput(arrow.PrimitiveTypes.Float32)}, + outputType, isNanKernelExec[float32], nil) + knFloat32.Data = genCompareKernel[float32](CmpNE) + knFloat32.NullHandling = exec.NullNoOutput + knFloat64 := exec.NewScalarKernel([]exec.InputType{exec.NewExactInput(arrow.PrimitiveTypes.Float64)}, + outputType, isNanKernelExec[float64], nil) + knFloat64.Data = genCompareKernel[float64](CmpNE) + knFloat64.NullHandling = exec.NullNoOutput + + kernels := []exec.ScalarKernel{knFloat32, knFloat64} + + for _, dt := range intTypes { + kn := exec.NewScalarKernel( + []exec.InputType{exec.NewExactInput(dt)}, + outputType, ConstBoolExec(false), nil) + kn.NullHandling = exec.NullNoOutput + kernels = append(kernels, kn) + } + + for _, id := range []arrow.Type{arrow.NULL, arrow.DURATION, arrow.DECIMAL32, arrow.DECIMAL64, arrow.DECIMAL128, arrow.DECIMAL256} { + kn := exec.NewScalarKernel( + []exec.InputType{exec.NewIDInput(id)}, + outputType, ConstBoolExec(false), nil) + kn.NullHandling = exec.NullNoOutput + kernels = append(kernels, kn) + } + + return kernels +} diff --git a/arrow/compute/scalar_compare.go b/arrow/compute/scalar_compare.go index cfead2a8..0e853a65 100644 --- a/arrow/compute/scalar_compare.go +++ b/arrow/compute/scalar_compare.go @@ -20,12 +20,10 @@ package compute import ( "context" - "fmt" "github.com/apache/arrow-go/v18/arrow" "github.com/apache/arrow-go/v18/arrow/compute/exec" "github.com/apache/arrow-go/v18/arrow/compute/internal/kernels" - "github.com/apache/arrow-go/v18/arrow/scalar" ) type compareFunction struct { @@ -152,34 +150,11 @@ func RegisterScalarComparisons(reg FunctionRegistry) { reg.AddFunction(isNullFn, false) reg.AddFunction(isNotNullFn, false) - reg.AddFunction(NewMetaFunction("is_nan", Unary(), EmptyFuncDoc, - func(ctx context.Context, opts FunctionOptions, args ...Datum) (Datum, error) { - type hasType interface { - Type() arrow.DataType - } - - // only Scalar, Array and ChunkedArray have a Type method - arg, ok := args[0].(hasType) - if !ok { - // don't support Table/Record/None kinds - return nil, fmt.Errorf("%w: unsupported type for is_nan %s", - arrow.ErrNotImplemented, args[0]) - } - - switch arg.Type() { - case arrow.PrimitiveTypes.Float32, arrow.PrimitiveTypes.Float64: - return CallFunction(ctx, "not_equal", nil, args[0], args[0]) - default: - if arg, ok := args[0].(ArrayLikeDatum); ok { - result, err := scalar.MakeArrayFromScalar(scalar.NewBooleanScalar(false), - int(arg.Len()), GetAllocator(ctx)) - if err != nil { - return nil, err - } - return NewDatumWithoutOwning(result), nil - } - - return NewDatum(false), nil - } - }), false) + isNaNFn := &compareFunction{*NewScalarFunction("is_nan", Unary(), EmptyFuncDoc)} + for _, k := range kernels.IsNaNKernels() { + if err := isNaNFn.AddKernel(k); err != nil { + panic(err) + } + } + reg.AddFunction(isNaNFn, false) } diff --git a/arrow/compute/scalar_compare_test.go b/arrow/compute/scalar_compare_test.go index b0c9ab91..e45b3afc 100644 --- a/arrow/compute/scalar_compare_test.go +++ b/arrow/compute/scalar_compare_test.go @@ -1497,8 +1497,8 @@ func (sv *ScalarValiditySuite) TestIsNaN() { }{ {`[]`, `[]`}, {`[1]`, `[false]`}, - {`[null]`, `[null]`}, - {`["NaN", 1, 0, null]`, `[true, false, false, null]`}, + {`[null]`, `[false]`}, + {`["NaN", 1, 0, null]`, `[true, false, false, false]`}, } for _, typ := range floatingTypes {