Skip to content

Commit

Permalink
[CALCITE-6380] Casts from INTERVAL and STRING to DECIMAL are incorrect
Browse files Browse the repository at this point in the history
Signed-off-by: Mihai Budiu <mbudiu@feldera.com>
  • Loading branch information
mihaibudiu committed Jun 20, 2024
1 parent a419a12 commit c0a53f6
Show file tree
Hide file tree
Showing 5 changed files with 146 additions and 47 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -449,6 +449,35 @@ private Expression getConvertExpression(
return defaultExpression.get();
}

case DECIMAL: {
int precision = targetType.getPrecision();
int scale = targetType.getScale();
if (precision != RelDataType.PRECISION_NOT_SPECIFIED
&& scale != RelDataType.SCALE_NOT_SPECIFIED) {
if (sourceType.getFamily() == SqlTypeFamily.CHARACTER) {
return Expressions.call(
BuiltInMethod.CHAR_DECIMAL_CAST.method,
operand,
Expressions.constant(precision),
Expressions.constant(scale));
} else if (sourceType.getFamily() == SqlTypeFamily.INTERVAL_DAY_TIME) {
return Expressions.call(
BuiltInMethod.SHORT_INTERVAL_DECIMAL_CAST.method,
operand,
Expressions.constant(precision),
Expressions.constant(scale),
Expressions.constant(sourceType.getSqlTypeName().getEndUnit().multiplier));
} else if (sourceType.getFamily() == SqlTypeFamily.INTERVAL_YEAR_MONTH) {
return Expressions.call(
BuiltInMethod.LONG_INTERVAL_DECIMAL_CAST.method,
operand,
Expressions.constant(precision),
Expressions.constant(scale),
Expressions.constant(sourceType.getSqlTypeName().getEndUnit().multiplier));
}
}
return defaultExpression.get();
}
case BIGINT:
case INTEGER:
case TINYINT:
Expand Down Expand Up @@ -1067,14 +1096,19 @@ public Expression getRoot() {

/** If an expression is a {@code NUMERIC} derived from an {@code INTERVAL},
* scales it appropriately; returns the operand unchanged if the conversion
* is not from {@code INTERVAL} to {@code NUMERIC}. */
* is not from {@code INTERVAL} to {@code NUMERIC}.
* Does <b>not</b> scale values of type DECIMAL, these are expected
* to be already scaled. */
private static Expression scaleValue(
RelDataType sourceType,
RelDataType targetType,
Expression operand) {
final SqlTypeFamily targetFamily = targetType.getSqlTypeName().getFamily();
final SqlTypeFamily sourceFamily = sourceType.getSqlTypeName().getFamily();
if (targetFamily == SqlTypeFamily.NUMERIC
// multiplyDivide cannot handle DECIMALs, but for DECIMAL
// destination types the result is already scaled.
&& targetType.getSqlTypeName() != SqlTypeName.DECIMAL
&& (sourceFamily == SqlTypeFamily.INTERVAL_YEAR_MONTH
|| sourceFamily == SqlTypeFamily.INTERVAL_DAY_TIME)) {
// Scale to the given field.
Expand Down
5 changes: 5 additions & 0 deletions core/src/main/java/org/apache/calcite/util/BuiltInMethod.java
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,11 @@ public enum BuiltInMethod {
QueryProvider.class, SchemaPlus.class, String.class),
AS_QUERYABLE(Enumerable.class, "asQueryable"),
ABSTRACT_ENUMERABLE_CTOR(AbstractEnumerable.class),
CHAR_DECIMAL_CAST(Primitive.class, "charToDecimalCast", String.class, int.class, int.class),
SHORT_INTERVAL_DECIMAL_CAST(Primitive.class, "shortIntervalToDecimalCast",
Long.class, int.class, int.class, BigDecimal.class),
LONG_INTERVAL_DECIMAL_CAST(Primitive.class, "longIntervalToDecimalCast",
Integer.class, int.class, int.class, BigDecimal.class),
INTO(ExtendedEnumerable.class, "into", Collection.class),
REMOVE_ALL(ExtendedEnumerable.class, "removeAll", Collection.class),
SCHEMA_GET_SUB_SCHEMA(Schema.class, "getSubSchema", String.class),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6327,7 +6327,7 @@ LogicalProject(I=[$0])
<Resource name="plan">
<![CDATA[
LogicalAggregate(group=[{}], EXPR$0=[SUM($0)])
LogicalProject($f0=[0.1:DECIMAL(19, 9)])
LogicalProject($f0=[0.100000000:DECIMAL(19, 9)])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
]]>
</Resource>
Expand Down
59 changes: 59 additions & 0 deletions linq4j/src/main/java/org/apache/calcite/linq4j/tree/Primitive.java
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,65 @@ static void checkRoundedRange(Number value, double min, double max) {
return requireNonNull(primitive, "primitive").numberValue((Number) value);
}

static BigDecimal checkOverflow(BigDecimal value, int precision, int scale) {
// The rounding mode is not specified in any Calcite docs,
// but elsewhere Calcite is rounding down. For example, Calcite is frequently
// calling BigDecimal.longValue(), which is rounding down, by ignoring all
// digits after the decimal point.
BigDecimal result = value.setScale(scale, RoundingMode.DOWN);
result = result.stripTrailingZeros();
if (result.scale() < scale) {
// stripTrailingZeros also removes zeros if there is no
// decimal point, converting 1000 to 1e+3, using a negative scale.
// Here we undo this change.
result = result.setScale(scale, RoundingMode.DOWN);
}
int actualPrecision = result.precision();
if (actualPrecision > precision) {
throw new ArithmeticException("Value " + value
+ " cannot be represented as a DECIMAL(" + precision + ", " + scale + ")");
}
return result;
}

/** Called from BuiltInMethod.CHAR_DECIMAL_CAST */
public static @Nullable Object charToDecimalCast(
@Nullable String value, int precision, int scale) {
if (value == null) {
return null;
}
BigDecimal result = new BigDecimal(value.trim());
return checkOverflow(result, precision, scale);
}

/**
* Convert a short time interval to a decimal value.
* Called from BuiltInMethod.SHORT_INTERVAL_DECIMAL_CAST.
* @param unitScale Scale describing source interval type */
public static @Nullable Object shortIntervalToDecimalCast(
@Nullable Long value, int precision, int scale, BigDecimal unitScale) {
if (value == null) {
return null;
}
// Divide with the scale expected of the result
BigDecimal result = new BigDecimal(value).divide(unitScale, scale, RoundingMode.DOWN);
return checkOverflow(result, precision, scale);
}

/**
* Convert a long time interval to a decimal value.
* Called from BuiltInMethod.LONG_INTERVAL_DECIMAL_CAST.
* @param unitScale Scale describing source interval type */
public static @Nullable Object longIntervalToDecimalCast(
@Nullable Integer value, int precision, int scale, BigDecimal unitScale) {
if (value == null) {
return null;
}
// Divide with the scale expected of the result
BigDecimal result = new BigDecimal(value).divide(unitScale, scale, RoundingMode.DOWN);
return checkOverflow(result, precision, scale);
}

/**
* Converts a number into a value of the type specified by this primitive
* using the SQL CAST rules. If the value conversion causes loss of significant digits,
Expand Down
91 changes: 46 additions & 45 deletions testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@
import static org.apache.calcite.rel.type.RelDataTypeImpl.NON_NULLABLE_SUFFIX;
import static org.apache.calcite.sql.fun.SqlStdOperatorTable.PI;
import static org.apache.calcite.sql.fun.SqlStdOperatorTable.QUANTIFY_OPERATORS;
import static org.apache.calcite.sql.test.ResultCheckers.isDecimal;
import static org.apache.calcite.sql.test.ResultCheckers.isExactDateTime;
import static org.apache.calcite.sql.test.ResultCheckers.isExactTime;
import static org.apache.calcite.sql.test.ResultCheckers.isExactly;
Expand Down Expand Up @@ -740,73 +741,73 @@ void testCastToExactNumeric(CastType castType, SqlOperatorFixture f) {
@MethodSource("safeParameters")
void testCastStringToDecimal(CastType castType, SqlOperatorFixture f) {
f.setFor(SqlStdOperatorTable.CAST, VmName.EXPAND);
if (!DECIMAL) {
return;
}
// string to decimal
f.checkScalarExact("cast('1.29' as decimal(2,1))",
"DECIMAL(2, 1) NOT NULL",
"1.3");
"1.2");
f.checkScalarExact("cast(' 1.25 ' as decimal(2,1))",
"DECIMAL(2, 1) NOT NULL",
"1.3");
"1.2");
f.checkScalarExact("cast('1.21' as decimal(2,1))",
"DECIMAL(2, 1) NOT NULL",
"1.2");
f.checkScalarExact("cast(' -1.29 ' as decimal(2,1))",
"DECIMAL(2, 1) NOT NULL",
"-1.3");
"-1.2");
f.checkScalarExact("cast('-1.25' as decimal(2,1))",
"DECIMAL(2, 1) NOT NULL",
"-1.3");
"-1.2");
f.checkScalarExact("cast(' -1.21 ' as decimal(2,1))",
"DECIMAL(2, 1) NOT NULL",
"-1.2");
f.checkFails("cast(' -1.21e' as decimal(2,1))", INVALID_CHAR_MESSAGE,
true);
String shouldFail = "cast(' -1.21e' as decimal(2,1))";
if (castType == CastType.CAST) {
f.checkFails(shouldFail, INVALID_CHAR_MESSAGE, true);
} else {
// safe casts never fail
f.checkNull(shouldFail);
}
}

@ParameterizedTest
@MethodSource("safeParameters")
void testCastIntervalToNumeric(CastType castType, SqlOperatorFixture f) {
f.setFor(SqlStdOperatorTable.CAST, VmName.EXPAND);

// interval to decimal
if (DECIMAL) {
f.checkScalarExact("cast(INTERVAL '1.29' second(1,2) as decimal(2,1))",
"DECIMAL(2, 1) NOT NULL",
"1.3");
f.checkScalarExact("cast(INTERVAL '1.25' second as decimal(2,1))",
"DECIMAL(2, 1) NOT NULL",
"1.3");
f.checkScalarExact("cast(INTERVAL '-1.29' second as decimal(2,1))",
"DECIMAL(2, 1) NOT NULL",
"-1.3");
f.checkScalarExact("cast(INTERVAL '-1.25' second as decimal(2,1))",
"DECIMAL(2, 1) NOT NULL",
"-1.3");
f.checkScalarExact("cast(INTERVAL '-1.21' second as decimal(2,1))",
"DECIMAL(2, 1) NOT NULL",
"-1.2");
f.checkScalarExact("cast(INTERVAL '5' minute as decimal(2,1))",
"DECIMAL(2, 1) NOT NULL",
"5.0");
f.checkScalarExact("cast(INTERVAL '5' hour as decimal(2,1))",
"DECIMAL(2, 1) NOT NULL",
"5.0");
f.checkScalarExact("cast(INTERVAL '5' day as decimal(2,1))",
"DECIMAL(2, 1) NOT NULL",
"5.0");
f.checkScalarExact("cast(INTERVAL '5' month as decimal(2,1))",
"DECIMAL(2, 1) NOT NULL",
"5.0");
f.checkScalarExact("cast(INTERVAL '5' year as decimal(2,1))",
"DECIMAL(2, 1) NOT NULL",
"5.0");
f.checkScalarExact("cast(INTERVAL '-5' day as decimal(2,1))",
"DECIMAL(2, 1) NOT NULL",
"-5.0");
}
// Interval to Decimal
f.checkScalarExact("cast(INTERVAL '1.29' second(1,2) as decimal(2,1))",
"DECIMAL(2, 1) NOT NULL",
"1.2");
f.checkScalarExact("cast(INTERVAL '1.25' second as decimal(2,1))",
"DECIMAL(2, 1) NOT NULL",
"1.2");
f.checkScalarExact("cast(INTERVAL '-1.29' second as decimal(2,1))",
"DECIMAL(2, 1) NOT NULL",
"-1.2");
f.checkScalarExact("cast(INTERVAL '-1.25' second as decimal(2,1))",
"DECIMAL(2, 1) NOT NULL",
"-1.2");
f.checkScalarExact("cast(INTERVAL '-1.21' second as decimal(2,1))",
"DECIMAL(2, 1) NOT NULL",
"-1.2");
f.checkScalarExact("cast(INTERVAL '5' minute as decimal(2,1))",
"DECIMAL(2, 1) NOT NULL",
isDecimal("5.0"));
f.checkScalarExact("cast(INTERVAL '5' hour as decimal(2,1))",
"DECIMAL(2, 1) NOT NULL",
isDecimal("5.0"));
f.checkScalarExact("cast(INTERVAL '5' day as decimal(2,1))",
"DECIMAL(2, 1) NOT NULL",
isDecimal("5.0"));
f.checkScalarExact("cast(INTERVAL '5' month as decimal(2,1))",
"DECIMAL(2, 1) NOT NULL",
isDecimal("5.0"));
f.checkScalarExact("cast(INTERVAL '5' year as decimal(2,1))",
"DECIMAL(2, 1) NOT NULL",
isDecimal("5.0"));
f.checkScalarExact("cast(INTERVAL '-5' day as decimal(2,1))",
"DECIMAL(2, 1) NOT NULL",
isDecimal("-5.0"));

// Interval to bigint
f.checkScalarExact("cast(INTERVAL '1.25' second as bigint)",
Expand Down

0 comments on commit c0a53f6

Please sign in to comment.