diff --git a/arrow/compute/exprs/exec_test.go b/arrow/compute/exprs/exec_test.go index c2a1c27b..96863280 100644 --- a/arrow/compute/exprs/exec_test.go +++ b/arrow/compute/exprs/exec_test.go @@ -27,8 +27,10 @@ import ( "github.com/apache/arrow-go/v18/arrow/array" "github.com/apache/arrow-go/v18/arrow/compute" "github.com/apache/arrow-go/v18/arrow/compute/exprs" + "github.com/apache/arrow-go/v18/arrow/extensions" "github.com/apache/arrow-go/v18/arrow/memory" "github.com/apache/arrow-go/v18/arrow/scalar" + "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/substrait-io/substrait-go/expr" @@ -135,8 +137,12 @@ func TestComparisons(t *testing.T) { one = scalar.MakeScalar(int32(1)) two = scalar.MakeScalar(int32(2)) - str = scalar.MakeScalar("hello") - bin = scalar.MakeScalar([]byte("hello")) + str = scalar.MakeScalar("hello") + bin = scalar.MakeScalar([]byte("hello")) + exampleUUID = uuid.MustParse("102cb62f-e6f8-4eb0-9973-d9b012ff0967") + uidStorage, _ = scalar.MakeScalarParam(exampleUUID[:], + &arrow.FixedSizeBinaryType{ByteWidth: 16}) + uid = scalar.NewExtensionScalar(uidStorage, extensions.NewUUIDType()) ) getArgType := func(dt arrow.DataType) types.Type { @@ -147,6 +153,8 @@ func TestComparisons(t *testing.T) { return &types.StringType{} case arrow.BINARY: return &types.BinaryType{} + case arrow.EXTENSION: + return &types.UUIDType{} } panic("wtf") } @@ -183,6 +191,7 @@ func TestComparisons(t *testing.T) { expect(t, "equal", one, one, true) expect(t, "equal", one, two, false) + expect(t, "equal", uid, uid, true) expect(t, "less", one, two, true) expect(t, "less", one, zero, false) expect(t, "greater", one, zero, true) diff --git a/arrow/compute/scalar_compare_test.go b/arrow/compute/scalar_compare_test.go index ba7e110f..c769b458 100644 --- a/arrow/compute/scalar_compare_test.go +++ b/arrow/compute/scalar_compare_test.go @@ -30,6 +30,7 @@ import ( "github.com/apache/arrow-go/v18/arrow/compute" "github.com/apache/arrow-go/v18/arrow/compute/exec" "github.com/apache/arrow-go/v18/arrow/compute/internal/kernels" + "github.com/apache/arrow-go/v18/arrow/extensions" "github.com/apache/arrow-go/v18/arrow/internal/testing/gen" "github.com/apache/arrow-go/v18/arrow/memory" "github.com/apache/arrow-go/v18/arrow/scalar" @@ -1289,6 +1290,8 @@ func TestCompareKernelsDispatchBest(t *testing.T) { &arrow.Decimal128Type{Precision: 3, Scale: 2}, &arrow.Decimal128Type{Precision: 21, Scale: 2}}, {arrow.PrimitiveTypes.Int64, &arrow.Decimal128Type{Precision: 3, Scale: 2}, &arrow.Decimal128Type{Precision: 21, Scale: 2}, &arrow.Decimal128Type{Precision: 3, Scale: 2}}, + + {extensions.NewUUIDType(), extensions.NewUUIDType(), &arrow.FixedSizeBinaryType{ByteWidth: 16}, &arrow.FixedSizeBinaryType{ByteWidth: 16}}, } for _, name := range []string{"equal", "not_equal", "less", "less_equal", "greater", "greater_equal"} { diff --git a/parquet/pqarrow/encode_arrow_test.go b/parquet/pqarrow/encode_arrow_test.go index 1ff1710b..b79734e1 100644 --- a/parquet/pqarrow/encode_arrow_test.go +++ b/parquet/pqarrow/encode_arrow_test.go @@ -2057,6 +2057,8 @@ func (ps *ParquetIOTestSuite) TestArrowExtensionTypeRoundTrip() { defer tbl.Release() ps.roundTripTable(mem, tbl, true) + // ensure we get UUID back even without storing the schema + ps.roundTripTable(mem, tbl, false) } func (ps *ParquetIOTestSuite) TestArrowUnknownExtensionTypeRoundTrip() { diff --git a/parquet/pqarrow/schema.go b/parquet/pqarrow/schema.go index 67631ed2..9930a15b 100644 --- a/parquet/pqarrow/schema.go +++ b/parquet/pqarrow/schema.go @@ -987,13 +987,15 @@ func applyOriginalStorageMetadata(origin arrow.Field, inferred *SchemaField) (mo return } - if !arrow.TypeEqual(extType.StorageType(), inferred.Field.Type) { - return modified, fmt.Errorf("%w: mismatch storage type '%s' for extension type '%s'", - arrow.ErrInvalid, inferred.Field.Type, extType) - } + if modified || !arrow.TypeEqual(extType, inferred.Field.Type) { + if !arrow.TypeEqual(extType.StorageType(), inferred.Field.Type) { + return modified, fmt.Errorf("%w: mismatch storage type '%s' for extension type '%s'", + arrow.ErrInvalid, inferred.Field.Type, extType) + } - inferred.Field.Type = extType - modified = true + inferred.Field.Type = extType + modified = true + } case arrow.SPARSE_UNION, arrow.DENSE_UNION: err = xerrors.New("unimplemented type") case arrow.STRUCT: