diff --git a/arrow/array/compare.go b/arrow/array/compare.go index ad3a50b8..dbebbfb3 100644 --- a/arrow/array/compare.go +++ b/arrow/array/compare.go @@ -19,6 +19,7 @@ package array import ( "fmt" "math" + "strings" "github.com/apache/arrow-go/v18/arrow" "github.com/apache/arrow-go/v18/arrow/float16" @@ -487,19 +488,19 @@ func arrayApproxEqual(left, right arrow.Array, opt equalOption) bool { return arrayEqualBinary(l, r) /* String, LargeString and StringView now dispatch to the approximate helpers added further down; the Binary, LargeBinary and BinaryView cases in this hunk still call the exact arrayEqual functions. */ case *String: r := right.(*String) - return arrayEqualString(l, r) + return arrayApproxEqualString(l, r) case *LargeBinary: r := right.(*LargeBinary) return arrayEqualLargeBinary(l, r) case *LargeString: r := right.(*LargeString) - return arrayEqualLargeString(l, r) + return arrayApproxEqualLargeString(l, r) case *BinaryView: r := right.(*BinaryView) return arrayEqualBinaryView(l, r) case *StringView: r := right.(*StringView) - return arrayEqualStringView(l, r) + return arrayApproxEqualStringView(l, r) case *Int8: r := right.(*Int8) return arrayEqualInt8(l, r) @@ -644,6 +645,46 @@ func validityBitmapEqual(left, right arrow.Array) bool { return true } /* arrayApproxEqualString reports whether left and right hold the same strings once trailing NUL (0x00) bytes are stripped from each value. Positions where left is null are skipped; only left.IsNull is consulted and right is indexed up to left.Len(). NOTE(review): presumably the caller has already checked that lengths and validity bitmaps match - confirm. */ +func arrayApproxEqualString(left, right *String) bool { + for i := 0; i < left.Len(); i++ { + if left.IsNull(i) { + continue + } + if stripNulls(left.Value(i)) != stripNulls(right.Value(i)) { + return false + } + } + return true +} + /* arrayApproxEqualLargeString reports whether two LargeString arrays match at every position where left is non-null, ignoring trailing NUL (0x00) bytes in either value. Only left.IsNull is consulted. NOTE(review): presumably length and validity were compared by the caller - confirm. */ +func arrayApproxEqualLargeString(left, right *LargeString) bool { + for i := 0; i < left.Len(); i++ { + if left.IsNull(i) { + continue + } + if stripNulls(left.Value(i)) != stripNulls(right.Value(i)) { + return false + } + } + return true +} + /* arrayApproxEqualStringView reports whether two StringView arrays match at every position where left is non-null, ignoring trailing NUL (0x00) bytes in either value. Only left.IsNull is consulted. NOTE(review): presumably length and validity were compared by the caller - confirm. */ +func arrayApproxEqualStringView(left, right *StringView) bool { + for i := 0; i < left.Len(); i++ { + if left.IsNull(i) { + continue + } + if stripNulls(left.Value(i)) != stripNulls(right.Value(i)) { + return false + } + } + return true +} + /* stripNulls returns s with every trailing NUL (0x00) byte removed; NUL bytes elsewhere in s are kept. */ +func stripNulls(s string) string { + return strings.TrimRight(s, "\x00") +} + func arrayApproxEqualFloat16(left, right *Float16, opt equalOption) bool { for i := 0; i < left.Len(); i++ { if
left.IsNull(i) { diff --git a/arrow/array/compare_test.go b/arrow/array/compare_test.go index 3059ed31..5c569f25 100644 --- a/arrow/array/compare_test.go +++ b/arrow/array/compare_test.go @@ -111,6 +111,94 @@ func TestArrayApproxEqual(t *testing.T) { } } /* TestArrayApproxEqualStrings checks that array.ApproxEqual reports String, LargeString and StringView arrays as equal when their values differ only by a trailing NUL (0x00) byte. NOTE(review): every case has want set to true, so no unequal pair is exercised, and the first three cases share one name. */ +func TestArrayApproxEqualStrings(t *testing.T) { + for _, tc := range []struct { + name string + a1 interface{} + a2 interface{} + want bool + }{ + { + name: "string", + a1: []string{"a", "b", "c", "d", "e", "f"}, + a2: []string{"a", "b", "c", "d", "e", "f"}, + want: true, + }, + { + name: "string", + a1: []string{"a", "b\x00"}, + a2: []string{"a", "b"}, + want: true, + }, + { + name: "string", + a1: []string{"a", "b\x00"}, + a2: []string{"a\x00", "b"}, + want: true, + }, + { + name: "equal large strings", + a1: []string{"a", "b", "c", "d", "e", "f"}, + a2: []string{"a", "b", "c", "d", "e", "f"}, + want: true, + }, + { + name: "equal large strings with nulls", + a1: []string{"a", "b\x00"}, + a2: []string{"a", "b"}, + want: true, + }, + { + name: "equal large strings with nulls in both", + a1: []string{"Apache", "Arrow\x00"}, + a2: []string{"Apache\x00", "Arrow"}, + want: true, + }, + { + name: "equal string views", + a1: []string{"a", "b", "c", "d", "e", "f"}, + a2: []string{"a", "b", "c", "d", "e", "f"}, + want: true, + }, + { + name: "equal string views with nulls", + a1: []string{"Apache", "Arrow\x00"}, + a2: []string{"Apache", "Arrow"}, + want: true, + }, + { + name: "equal string views with nulls in both", + a1: []string{"Apache", "Arrow\x00"}, + a2: []string{"Apache\x00", "Arrow"}, + want: true, + }, + } { + t.Run(tc.name, func(t *testing.T) { + mem := memory.NewCheckedAllocator(memory.NewGoAllocator()) + defer mem.AssertSize(t, 0) + + var a1, a2 arrow.Array /* The array type is chosen by matching tc.name against the literal names listed in the cases of this switch; any other name takes the default branch and goes through arrayOf. */ + switch tc.name { + case "equal large strings", "equal large strings with nulls", "equal large strings with nulls in both": + a1 = arrayOfLargeString(mem, tc.a1.([]string), nil) + a2 = arrayOfLargeString(mem, tc.a2.([]string), nil) + case "equal string 
views", "equal string views with nulls", "equal string views with nulls in both": + a1 = arrayOfStringView(mem, tc.a1.([]string), nil) + a2 = arrayOfStringView(mem, tc.a2.([]string), nil) + default: + a1 = arrayOf(mem, tc.a1, nil) + a2 = arrayOf(mem, tc.a2, nil) + } + defer a1.Release() + defer a2.Release() + + if got, want := array.ApproxEqual(a1, a2), tc.want; got != want { + t.Fatalf("invalid comparison: got=%v, want=%v\na1: %v\na2: %v\n", got, want, a1, a2) + } + }) + } +} + func TestArrayApproxEqualFloats(t *testing.T) { f16sFrom := func(vs []float64) []float16.Num { o := make([]float16.Num, len(vs)) @@ -445,11 +533,34 @@ func arrayOf(mem memory.Allocator, a interface{}, valids []bool) arrow.Array { bldr.AppendValues(a, valids) return bldr.NewFloat64Array() + case []string: /* []string input is appended to a StringBuilder and returned via NewStringArray. */ + bldr := array.NewStringBuilder(mem) + defer bldr.Release() + + bldr.AppendValues(a, valids) + return bldr.NewStringArray() + default: panic(fmt.Errorf("arrdata: invalid data slice type %T", a)) } } /* arrayOfLargeString builds an array from a with a LargeStringBuilder allocated from mem; valids is forwarded to AppendValues unchanged. The builder is released on return; callers in this file release the returned array themselves. */ +func arrayOfLargeString(mem memory.Allocator, a []string, valids []bool) arrow.Array { + bldr := array.NewLargeStringBuilder(mem) + defer bldr.Release() + + bldr.AppendValues(a, valids) + return bldr.NewLargeStringArray() +} + /* arrayOfStringView builds an array from a with a StringViewBuilder allocated from mem; valids is forwarded to AppendValues unchanged. The builder is released on return; callers in this file release the returned array themselves. */ +func arrayOfStringView(mem memory.Allocator, a []string, valids []bool) arrow.Array { + bldr := array.NewStringViewBuilder(mem) + defer bldr.Release() + + bldr.AppendValues(a, valids) + return bldr.NewStringViewArray() +} + func TestArrayEqualBaseArray(t *testing.T) { mem := memory.NewCheckedAllocator(memory.NewGoAllocator()) defer mem.AssertSize(t, 0)