From c0a53f6b17daaca9d057e70d7fae0a0e9c2cd02a Mon Sep 17 00:00:00 2001 From: Mihai Budiu Date: Mon, 3 Jun 2024 17:43:36 -0700 Subject: [PATCH] [CALCITE-6380] Casts from INTERVAL and STRING to DECIMAL are incorrect Signed-off-by: Mihai Budiu --- .../enumerable/RexToLixTranslator.java | 36 +++++++- .../apache/calcite/util/BuiltInMethod.java | 5 + .../calcite/test/SqlToRelConverterTest.xml | 2 +- .../apache/calcite/linq4j/tree/Primitive.java | 59 ++++++++++++ .../apache/calcite/test/SqlOperatorTest.java | 91 ++++++++++--------- 5 files changed, 146 insertions(+), 47 deletions(-) diff --git a/core/src/main/java/org/apache/calcite/adapter/enumerable/RexToLixTranslator.java b/core/src/main/java/org/apache/calcite/adapter/enumerable/RexToLixTranslator.java index 7886c5353d77..64db857bb503 100644 --- a/core/src/main/java/org/apache/calcite/adapter/enumerable/RexToLixTranslator.java +++ b/core/src/main/java/org/apache/calcite/adapter/enumerable/RexToLixTranslator.java @@ -449,6 +449,35 @@ private Expression getConvertExpression( return defaultExpression.get(); } + case DECIMAL: { + int precision = targetType.getPrecision(); + int scale = targetType.getScale(); + if (precision != RelDataType.PRECISION_NOT_SPECIFIED + && scale != RelDataType.SCALE_NOT_SPECIFIED) { + if (sourceType.getFamily() == SqlTypeFamily.CHARACTER) { + return Expressions.call( + BuiltInMethod.CHAR_DECIMAL_CAST.method, + operand, + Expressions.constant(precision), + Expressions.constant(scale)); + } else if (sourceType.getFamily() == SqlTypeFamily.INTERVAL_DAY_TIME) { + return Expressions.call( + BuiltInMethod.SHORT_INTERVAL_DECIMAL_CAST.method, + operand, + Expressions.constant(precision), + Expressions.constant(scale), + Expressions.constant(sourceType.getSqlTypeName().getEndUnit().multiplier)); + } else if (sourceType.getFamily() == SqlTypeFamily.INTERVAL_YEAR_MONTH) { + return Expressions.call( + BuiltInMethod.LONG_INTERVAL_DECIMAL_CAST.method, + operand, + Expressions.constant(precision), + Expressions.constant(scale), + Expressions.constant(sourceType.getSqlTypeName().getEndUnit().multiplier)); + } + } + return defaultExpression.get(); + } case BIGINT: case INTEGER: case TINYINT: @@ -1067,7 +1096,9 @@ public Expression getRoot() { /** If an expression is a {@code NUMERIC} derived from an {@code INTERVAL}, * scales it appropriately; returns the operand unchanged if the conversion - * is not from {@code INTERVAL} to {@code NUMERIC}. */ + * is not from {@code INTERVAL} to {@code NUMERIC}. + * Does not scale values of type DECIMAL, these are expected + * to be already scaled. */ private static Expression scaleValue( RelDataType sourceType, RelDataType targetType, @@ -1075,6 +1106,9 @@ private static Expression scaleValue( final SqlTypeFamily targetFamily = targetType.getSqlTypeName().getFamily(); final SqlTypeFamily sourceFamily = sourceType.getSqlTypeName().getFamily(); if (targetFamily == SqlTypeFamily.NUMERIC + // multiplyDivide cannot handle DECIMALs, but for DECIMAL + // destination types the result is already scaled. + && targetType.getSqlTypeName() != SqlTypeName.DECIMAL && (sourceFamily == SqlTypeFamily.INTERVAL_YEAR_MONTH || sourceFamily == SqlTypeFamily.INTERVAL_DAY_TIME)) { // Scale to the given field. diff --git a/core/src/main/java/org/apache/calcite/util/BuiltInMethod.java b/core/src/main/java/org/apache/calcite/util/BuiltInMethod.java index 07bc45ab5c5a..1982a4ec135f 100644 --- a/core/src/main/java/org/apache/calcite/util/BuiltInMethod.java +++ b/core/src/main/java/org/apache/calcite/util/BuiltInMethod.java @@ -157,6 +157,11 @@ public enum BuiltInMethod { QueryProvider.class, SchemaPlus.class, String.class), AS_QUERYABLE(Enumerable.class, "asQueryable"), ABSTRACT_ENUMERABLE_CTOR(AbstractEnumerable.class), + CHAR_DECIMAL_CAST(Primitive.class, "charToDecimalCast", String.class, int.class, int.class), + SHORT_INTERVAL_DECIMAL_CAST(Primitive.class, "shortIntervalToDecimalCast", + Long.class, int.class, int.class, BigDecimal.class), + LONG_INTERVAL_DECIMAL_CAST(Primitive.class, "longIntervalToDecimalCast", + Integer.class, int.class, int.class, BigDecimal.class), INTO(ExtendedEnumerable.class, "into", Collection.class), REMOVE_ALL(ExtendedEnumerable.class, "removeAll", Collection.class), SCHEMA_GET_SUB_SCHEMA(Schema.class, "getSubSchema", String.class), diff --git a/core/src/test/resources/org/apache/calcite/test/SqlToRelConverterTest.xml b/core/src/test/resources/org/apache/calcite/test/SqlToRelConverterTest.xml index 2a029cfd09c2..43f8ae050600 100644 --- a/core/src/test/resources/org/apache/calcite/test/SqlToRelConverterTest.xml +++ b/core/src/test/resources/org/apache/calcite/test/SqlToRelConverterTest.xml @@ -6327,7 +6327,7 @@ LogicalProject(I=[$0]) diff --git a/linq4j/src/main/java/org/apache/calcite/linq4j/tree/Primitive.java b/linq4j/src/main/java/org/apache/calcite/linq4j/tree/Primitive.java index 6aa575c795cf..94578647e303 100644 --- a/linq4j/src/main/java/org/apache/calcite/linq4j/tree/Primitive.java +++ b/linq4j/src/main/java/org/apache/calcite/linq4j/tree/Primitive.java @@ -388,6 +388,65 @@ static void checkRoundedRange(Number value, double min, double max) { return requireNonNull(primitive, "primitive").numberValue((Number) value); } + static BigDecimal checkOverflow(BigDecimal value, int precision, int scale) { + // The rounding mode is not specified in any Calcite docs, + // but elsewhere Calcite is rounding down. For example, Calcite is frequently + // calling BigDecimal.longValue(), which is rounding down, by ignoring all + // digits after the decimal point. + BigDecimal result = value.setScale(scale, RoundingMode.DOWN); + result = result.stripTrailingZeros(); + if (result.scale() < scale) { + // stripTrailingZeros also removes zeros if there is no + // decimal point, converting 1000 to 1e+3, using a negative scale. + // Here we undo this change. + result = result.setScale(scale, RoundingMode.DOWN); + } + int actualPrecision = result.precision(); + if (actualPrecision > precision) { + throw new ArithmeticException("Value " + value + + " cannot be represented as a DECIMAL(" + precision + ", " + scale + ")"); + } + return result; + } + + /** Called from BuiltInMethod.CHAR_DECIMAL_CAST */ + public static @Nullable Object charToDecimalCast( + @Nullable String value, int precision, int scale) { + if (value == null) { + return null; + } + BigDecimal result = new BigDecimal(value.trim()); + return checkOverflow(result, precision, scale); + } + + /** + * Convert a short time interval to a decimal value. + * Called from BuiltInMethod.SHORT_INTERVAL_DECIMAL_CAST. + * @param unitScale Scale describing source interval type */ + public static @Nullable Object shortIntervalToDecimalCast( + @Nullable Long value, int precision, int scale, BigDecimal unitScale) { + if (value == null) { + return null; + } + // Divide with the scale expected of the result + BigDecimal result = new BigDecimal(value).divide(unitScale, scale, RoundingMode.DOWN); + return checkOverflow(result, precision, scale); + } + + /** + * Convert a long time interval to a decimal value. + * Called from BuiltInMethod.LONG_INTERVAL_DECIMAL_CAST. + * @param unitScale Scale describing source interval type */ + public static @Nullable Object longIntervalToDecimalCast( + @Nullable Integer value, int precision, int scale, BigDecimal unitScale) { + if (value == null) { + return null; + } + // Divide with the scale expected of the result + BigDecimal result = new BigDecimal(value).divide(unitScale, scale, RoundingMode.DOWN); + return checkOverflow(result, precision, scale); + } + /** * Converts a number into a value of the type specified by this primitive * using the SQL CAST rules. If the value conversion causes loss of significant digits, diff --git a/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java b/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java index d9576e1a252b..d018b5de7673 100644 --- a/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java +++ b/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java @@ -120,6 +120,7 @@ import static org.apache.calcite.rel.type.RelDataTypeImpl.NON_NULLABLE_SUFFIX; import static org.apache.calcite.sql.fun.SqlStdOperatorTable.PI; import static org.apache.calcite.sql.fun.SqlStdOperatorTable.QUANTIFY_OPERATORS; +import static org.apache.calcite.sql.test.ResultCheckers.isDecimal; import static org.apache.calcite.sql.test.ResultCheckers.isExactDateTime; import static org.apache.calcite.sql.test.ResultCheckers.isExactTime; import static org.apache.calcite.sql.test.ResultCheckers.isExactly; @@ -740,30 +741,32 @@ void testCastToExactNumeric(CastType castType, SqlOperatorFixture f) { @MethodSource("safeParameters") void testCastStringToDecimal(CastType castType, SqlOperatorFixture f) { f.setFor(SqlStdOperatorTable.CAST, VmName.EXPAND); - if (!DECIMAL) { - return; - } // string to decimal f.checkScalarExact("cast('1.29' as decimal(2,1))", "DECIMAL(2, 1) NOT NULL", - "1.3"); + "1.2"); f.checkScalarExact("cast(' 1.25 ' as decimal(2,1))", "DECIMAL(2, 1) NOT NULL", - "1.3"); + "1.2"); f.checkScalarExact("cast('1.21' as decimal(2,1))", "DECIMAL(2, 1) NOT NULL", "1.2"); f.checkScalarExact("cast(' -1.29 ' as decimal(2,1))", "DECIMAL(2, 1) NOT NULL", - "-1.3"); + "-1.2"); f.checkScalarExact("cast('-1.25' as decimal(2,1))", "DECIMAL(2, 1) NOT NULL", - "-1.3"); + "-1.2"); f.checkScalarExact("cast(' -1.21 ' as decimal(2,1))", "DECIMAL(2, 1) NOT NULL", "-1.2"); - f.checkFails("cast(' -1.21e' as decimal(2,1))", INVALID_CHAR_MESSAGE, - true); + String shouldFail = "cast(' -1.21e' as decimal(2,1))"; + if (castType == CastType.CAST) { + f.checkFails(shouldFail, INVALID_CHAR_MESSAGE, true); + } else { + // safe casts never fail + f.checkNull(shouldFail); + } } @ParameterizedTest @@ -771,42 +774,40 @@ void testCastStringToDecimal(CastType castType, SqlOperatorFixture f) { void testCastIntervalToNumeric(CastType castType, SqlOperatorFixture f) { f.setFor(SqlStdOperatorTable.CAST, VmName.EXPAND); - // interval to decimal - if (DECIMAL) { - f.checkScalarExact("cast(INTERVAL '1.29' second(1,2) as decimal(2,1))", - "DECIMAL(2, 1) NOT NULL", - "1.3"); - f.checkScalarExact("cast(INTERVAL '1.25' second as decimal(2,1))", - "DECIMAL(2, 1) NOT NULL", - "1.3"); - f.checkScalarExact("cast(INTERVAL '-1.29' second as decimal(2,1))", - "DECIMAL(2, 1) NOT NULL", - "-1.3"); - f.checkScalarExact("cast(INTERVAL '-1.25' second as decimal(2,1))", - "DECIMAL(2, 1) NOT NULL", - "-1.3"); - f.checkScalarExact("cast(INTERVAL '-1.21' second as decimal(2,1))", - "DECIMAL(2, 1) NOT NULL", - "-1.2"); - f.checkScalarExact("cast(INTERVAL '5' minute as decimal(2,1))", - "DECIMAL(2, 1) NOT NULL", - "5.0"); - f.checkScalarExact("cast(INTERVAL '5' hour as decimal(2,1))", - "DECIMAL(2, 1) NOT NULL", - "5.0"); - f.checkScalarExact("cast(INTERVAL '5' day as decimal(2,1))", - "DECIMAL(2, 1) NOT NULL", - "5.0"); - f.checkScalarExact("cast(INTERVAL '5' month as decimal(2,1))", - "DECIMAL(2, 1) NOT NULL", - "5.0"); - f.checkScalarExact("cast(INTERVAL '5' year as decimal(2,1))", - "DECIMAL(2, 1) NOT NULL", - "5.0"); - f.checkScalarExact("cast(INTERVAL '-5' day as decimal(2,1))", - "DECIMAL(2, 1) NOT NULL", - "-5.0"); - } + // Interval to Decimal + f.checkScalarExact("cast(INTERVAL '1.29' second(1,2) as decimal(2,1))", + "DECIMAL(2, 1) NOT NULL", + "1.2"); + f.checkScalarExact("cast(INTERVAL '1.25' second as decimal(2,1))", + "DECIMAL(2, 1) NOT NULL", + "1.2"); + f.checkScalarExact("cast(INTERVAL '-1.29' second as decimal(2,1))", + "DECIMAL(2, 1) NOT NULL", + "-1.2"); + f.checkScalarExact("cast(INTERVAL '-1.25' second as decimal(2,1))", + "DECIMAL(2, 1) NOT NULL", + "-1.2"); + f.checkScalarExact("cast(INTERVAL '-1.21' second as decimal(2,1))", + "DECIMAL(2, 1) NOT NULL", + "-1.2"); + f.checkScalarExact("cast(INTERVAL '5' minute as decimal(2,1))", + "DECIMAL(2, 1) NOT NULL", + isDecimal("5.0")); + f.checkScalarExact("cast(INTERVAL '5' hour as decimal(2,1))", + "DECIMAL(2, 1) NOT NULL", + isDecimal("5.0")); + f.checkScalarExact("cast(INTERVAL '5' day as decimal(2,1))", + "DECIMAL(2, 1) NOT NULL", + isDecimal("5.0")); + f.checkScalarExact("cast(INTERVAL '5' month as decimal(2,1))", + "DECIMAL(2, 1) NOT NULL", + isDecimal("5.0")); + f.checkScalarExact("cast(INTERVAL '5' year as decimal(2,1))", + "DECIMAL(2, 1) NOT NULL", + isDecimal("5.0")); + f.checkScalarExact("cast(INTERVAL '-5' day as decimal(2,1))", + "DECIMAL(2, 1) NOT NULL", + isDecimal("-5.0")); // Interval to bigint f.checkScalarExact("cast(INTERVAL '1.25' second as bigint)",