Skip to content

Commit

Permalink
update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
zeroshade committed Oct 25, 2024
1 parent c779c38 commit 2189067
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 8 deletions.
13 changes: 11 additions & 2 deletions arrow/compute/exprs/exec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,10 @@ import (
"github.com/apache/arrow-go/v18/arrow/array"
"github.com/apache/arrow-go/v18/arrow/compute"
"github.com/apache/arrow-go/v18/arrow/compute/exprs"
"github.com/apache/arrow-go/v18/arrow/extensions"
"github.com/apache/arrow-go/v18/arrow/memory"
"github.com/apache/arrow-go/v18/arrow/scalar"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/substrait-io/substrait-go/expr"
Expand Down Expand Up @@ -135,8 +137,12 @@ func TestComparisons(t *testing.T) {
one = scalar.MakeScalar(int32(1))
two = scalar.MakeScalar(int32(2))

str = scalar.MakeScalar("hello")
bin = scalar.MakeScalar([]byte("hello"))
str = scalar.MakeScalar("hello")
bin = scalar.MakeScalar([]byte("hello"))
exampleUUID = uuid.MustParse("102cb62f-e6f8-4eb0-9973-d9b012ff0967")
uidStorage, _ = scalar.MakeScalarParam(exampleUUID[:],
&arrow.FixedSizeBinaryType{ByteWidth: 16})
uid = scalar.NewExtensionScalar(uidStorage, extensions.NewUUIDType())
)

getArgType := func(dt arrow.DataType) types.Type {
Expand All @@ -147,6 +153,8 @@ func TestComparisons(t *testing.T) {
return &types.StringType{}
case arrow.BINARY:
return &types.BinaryType{}
case arrow.EXTENSION:
return &types.UUIDType{}
}
panic("wtf")
}
Expand Down Expand Up @@ -183,6 +191,7 @@ func TestComparisons(t *testing.T) {

expect(t, "equal", one, one, true)
expect(t, "equal", one, two, false)
expect(t, "equal", uid, uid, true)
expect(t, "less", one, two, true)
expect(t, "less", one, zero, false)
expect(t, "greater", one, zero, true)
Expand Down
3 changes: 3 additions & 0 deletions arrow/compute/scalar_compare_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import (
"github.com/apache/arrow-go/v18/arrow/compute"
"github.com/apache/arrow-go/v18/arrow/compute/exec"
"github.com/apache/arrow-go/v18/arrow/compute/internal/kernels"
"github.com/apache/arrow-go/v18/arrow/extensions"
"github.com/apache/arrow-go/v18/arrow/internal/testing/gen"
"github.com/apache/arrow-go/v18/arrow/memory"
"github.com/apache/arrow-go/v18/arrow/scalar"
Expand Down Expand Up @@ -1289,6 +1290,8 @@ func TestCompareKernelsDispatchBest(t *testing.T) {
&arrow.Decimal128Type{Precision: 3, Scale: 2}, &arrow.Decimal128Type{Precision: 21, Scale: 2}},
{arrow.PrimitiveTypes.Int64, &arrow.Decimal128Type{Precision: 3, Scale: 2},
&arrow.Decimal128Type{Precision: 21, Scale: 2}, &arrow.Decimal128Type{Precision: 3, Scale: 2}},

{extensions.NewUUIDType(), extensions.NewUUIDType(), &arrow.FixedSizeBinaryType{ByteWidth: 16}, &arrow.FixedSizeBinaryType{ByteWidth: 16}},
}

for _, name := range []string{"equal", "not_equal", "less", "less_equal", "greater", "greater_equal"} {
Expand Down
2 changes: 2 additions & 0 deletions parquet/pqarrow/encode_arrow_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2057,6 +2057,8 @@ func (ps *ParquetIOTestSuite) TestArrowExtensionTypeRoundTrip() {
defer tbl.Release()

ps.roundTripTable(mem, tbl, true)
// ensure we get UUID back even without storing the schema
ps.roundTripTable(mem, tbl, false)
}

func (ps *ParquetIOTestSuite) TestArrowUnknownExtensionTypeRoundTrip() {
Expand Down
14 changes: 8 additions & 6 deletions parquet/pqarrow/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -987,13 +987,15 @@ func applyOriginalStorageMetadata(origin arrow.Field, inferred *SchemaField) (mo
return
}

if !arrow.TypeEqual(extType.StorageType(), inferred.Field.Type) {
return modified, fmt.Errorf("%w: mismatch storage type '%s' for extension type '%s'",
arrow.ErrInvalid, inferred.Field.Type, extType)
}
if modified || !arrow.TypeEqual(extType, inferred.Field.Type) {
if !arrow.TypeEqual(extType.StorageType(), inferred.Field.Type) {
return modified, fmt.Errorf("%w: mismatch storage type '%s' for extension type '%s'",
arrow.ErrInvalid, inferred.Field.Type, extType)
}

inferred.Field.Type = extType
modified = true
inferred.Field.Type = extType
modified = true
}
case arrow.SPARSE_UNION, arrow.DENSE_UNION:
err = xerrors.New("unimplemented type")
case arrow.STRUCT:
Expand Down

0 comments on commit 2189067

Please sign in to comment.