Skip to content

Commit

Permalink
[CALCITE-6111] Explicit cast from expression to numeric type doesn't …
Browse files Browse the repository at this point in the history
…check overflow

Signed-off-by: Mihai Budiu <mbudiu@feldera.com>
  • Loading branch information
mihaibudiu committed Jan 19, 2024
1 parent d5fa3eb commit e2c84a6
Show file tree
Hide file tree
Showing 5 changed files with 90 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
import org.apache.calcite.sql.SqlWindowTableFunction;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.type.SqlTypeFamily;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.sql.type.SqlTypeUtil;
import org.apache.calcite.sql.validate.SqlConformance;
import org.apache.calcite.util.BuiltInMethod;
Expand Down Expand Up @@ -534,6 +535,19 @@ private Expression getConvertExpression(
return defaultExpression.get();
}

case BIGINT:
case INTEGER:
case TINYINT:
case SMALLINT: {
if (SqlTypeName.NUMERIC_TYPES.contains(sourceType.getSqlTypeName())) {
return Expressions.call(
BuiltInMethod.INTEGER_CAST.method,
Expressions.constant(Primitive.of(typeFactory.getJavaClass(targetType))),
operand);
}
return defaultExpression.get();
}

default:
return defaultExpression.get();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,7 @@ public enum BuiltInMethod {
ENUMERABLE_TO_LIST(ExtendedEnumerable.class, "toList"),
ENUMERABLE_TO_MAP(ExtendedEnumerable.class, "toMap", Function1.class, Function1.class),
AS_LIST(Primitive.class, "asList", Object.class),
INTEGER_CAST(Primitive.class, "integerCast", Primitive.class, Object.class),
MEMORY_GET0(MemoryFactory.Memory.class, "get"),
MEMORY_GET1(MemoryFactory.Memory.class, "get", int.class),
ENUMERATOR_CURRENT(Enumerator.class, "current"),
Expand Down
11 changes: 11 additions & 0 deletions linq4j/src/main/java/org/apache/calcite/linq4j/tree/Primitive.java
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,10 @@ static void checkRoundedRange(Number value, double min, double max) {
}
}

public static @Nullable Object integerCast(Primitive primitive, final Object value) {
return requireNonNull(primitive, "primitive").numberValue((Number) value);
}

/**
* Converts a number into a value of the type specified by this primitive
* using the SQL CAST rules. If the value conversion causes loss of significant digits,
Expand Down Expand Up @@ -424,6 +428,13 @@ static void checkRoundedRange(Number value, double min, double max) {
// longValueExact will throw ArithmeticException if out of range
return decimal.longValueExact();
}
if (value instanceof BigDecimal) {
BigDecimal decimal = ((BigDecimal) value)
// Round to an integer
.setScale(0, RoundingMode.DOWN);
// longValueExact will throw ArithmeticException if out of range
return decimal.longValueExact();
}
throw new AssertionError("Unexpected Number type "
+ value.getClass().getSimpleName());
case FLOAT:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,9 @@ public interface SqlOperatorFixture extends AutoCloseable {
// TODO: Change message
String INVALID_CHAR_MESSAGE = "(?s).*";

String OUT_OF_RANGE_MESSAGE = ".* out of range";
String OUT_OF_RANGE_MESSAGE = ".* out of range.*";

String WRONG_FORMAT_MESSAGE = "Number has wrong format.*";

// TODO: Change message
String DIVISION_BY_ZERO_MESSAGE = "(?s).*";
Expand Down Expand Up @@ -643,8 +645,12 @@ default void checkCastToScalarOkay(String value, String targetType,

default void checkCastFails(String value, String targetType,
String expectedError, boolean runtime, CastType castType) {
final String castString = getCastString(value, targetType, !runtime, castType);
checkFails(castString, expectedError, runtime);
final String query = getCastString(value, targetType, !runtime, castType);
if (castType == CastType.CAST || !runtime) {
checkFails(query, expectedError, runtime);
} else {
checkNull(query);
}
}

default void checkCastToString(String value, @Nullable String type,
Expand Down
70 changes: 55 additions & 15 deletions testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@
import static org.apache.calcite.sql.test.SqlOperatorFixture.INVALID_EXTRACT_UNIT_VALIDATION_ERROR;
import static org.apache.calcite.sql.test.SqlOperatorFixture.LITERAL_OUT_OF_RANGE_MESSAGE;
import static org.apache.calcite.sql.test.SqlOperatorFixture.OUT_OF_RANGE_MESSAGE;
import static org.apache.calcite.sql.test.SqlOperatorFixture.WRONG_FORMAT_MESSAGE;
import static org.apache.calcite.util.DateTimeStringUtils.getDateFormatter;

import static org.hamcrest.CoreMatchers.equalTo;
Expand Down Expand Up @@ -623,13 +624,16 @@ void testCastExactNumericLimits(CastType castType, SqlOperatorFixture f) {

// Overflow test
if (numeric == Numeric.BIGINT) {
// Literal of range
// Calcite cannot even represent a literal so large, so
// for this query even the safe casts fail at compile-time
// (runtime == false).
f.checkCastFails(numeric.maxOverflowNumericString,
type, LITERAL_OUT_OF_RANGE_MESSAGE, false, castType);
f.checkCastFails(numeric.minOverflowNumericString,
type, LITERAL_OUT_OF_RANGE_MESSAGE, false, castType);
} else {
if (numeric != Numeric.DECIMAL5_2 || Bug.CALCITE_2539_FIXED) {
if (numeric != Numeric.DECIMAL5_2) {
// This condition is for bug [CALCITE-6078], not yet fixed
f.checkCastFails(numeric.maxOverflowNumericString,
type, OUT_OF_RANGE_MESSAGE, true, castType);
f.checkCastFails(numeric.minOverflowNumericString,
Expand All @@ -643,11 +647,12 @@ void testCastExactNumericLimits(CastType castType, SqlOperatorFixture f) {
f.checkCastToScalarOkay("'" + numeric.minNumericString + "'",
type, numeric.minNumericString, castType);

if (Bug.CALCITE_2539_FIXED) {
if (numeric != Numeric.DECIMAL5_2) {
// The above condition is for bug CALCITE-6078
f.checkCastFails("'" + numeric.maxOverflowNumericString + "'",
type, OUT_OF_RANGE_MESSAGE, true, castType);
type, WRONG_FORMAT_MESSAGE, true, castType);
f.checkCastFails("'" + numeric.minOverflowNumericString + "'",
type, OUT_OF_RANGE_MESSAGE, true, castType);
type, WRONG_FORMAT_MESSAGE, true, castType);
}

// Convert from type to string
Expand All @@ -657,10 +662,8 @@ void testCastExactNumericLimits(CastType castType, SqlOperatorFixture f) {
f.checkCastToString(numeric.minNumericString, null, null, castType);
f.checkCastToString(numeric.minNumericString, type, null, castType);

if (Bug.CALCITE_2539_FIXED) {
f.checkCastFails("'notnumeric'", type, INVALID_CHAR_MESSAGE, true,
castType);
}
f.checkCastFails("'notnumeric'", type, INVALID_CHAR_MESSAGE, true,
castType);
});
}

Expand Down Expand Up @@ -1128,14 +1131,14 @@ void testCastInvalid(CastType castType, SqlOperatorFixture f) {
// ExceptionInInitializerError.
f.checkScalarExact("cast('15' as integer)", "INTEGER NOT NULL", "15");
if (castType == CastType.CAST) { // Safe casts should not fail
f.checkFails("cast('15.4' as integer)", "Number has wrong format.*", true);
f.checkFails("cast('15.6' as integer)", "Number has wrong format.*", true);
f.checkFails("cast('15.4' as integer)", WRONG_FORMAT_MESSAGE, true);
f.checkFails("cast('15.6' as integer)", WRONG_FORMAT_MESSAGE, true);
f.checkFails("cast('ue' as boolean)", "Invalid character for cast.*", true);
f.checkFails("cast('' as boolean)", "Invalid character for cast.*", true);
f.checkFails("cast('' as integer)", "Number has wrong format.*", true);
f.checkFails("cast('' as real)", "Number has wrong format.*", true);
f.checkFails("cast('' as double)", "Number has wrong format.*", true);
f.checkFails("cast('' as smallint)", "Number has wrong format.*", true);
f.checkFails("cast('' as integer)", WRONG_FORMAT_MESSAGE, true);
f.checkFails("cast('' as real)", WRONG_FORMAT_MESSAGE, true);
f.checkFails("cast('' as double)", WRONG_FORMAT_MESSAGE, true);
f.checkFails("cast('' as smallint)", WRONG_FORMAT_MESSAGE, true);
} else {
f.checkNull("cast('15.4' as integer)");
f.checkNull("cast('15.6' as integer)");
Expand Down Expand Up @@ -13695,6 +13698,43 @@ private static void checkLogicalOrFunc(SqlOperatorFixture f) {
}
}

/**
* Test cases for
* <a href="https://issues.apache.org/jira/browse/CALCITE-6111">[CALCITE-6111]
* Explicit cast from expression to numeric type doesn't check overflow</a>. */
@Test public void testOverflow() {
final SqlOperatorFixture f = fixture();
f.checkFails(String.format(Locale.US, "SELECT cast(%d+30 as tinyint)", Byte.MAX_VALUE),
OUT_OF_RANGE_MESSAGE, true);
f.checkFails(String.format(Locale.US, "SELECT cast(%d+30 as smallint)", Short.MAX_VALUE),
OUT_OF_RANGE_MESSAGE, true);
// We use a long value because otherwise calcite interprets the literal as an integer.
f.checkFails(String.format(Locale.US, "SELECT cast(%d as int)", Long.MAX_VALUE),
OUT_OF_RANGE_MESSAGE, true);

// Casting a floating point value larger than the maximum allowed value.
// 1e60 is larger than the largest BIGINT value allowed.
f.checkFails("SELECT cast(1e60+30 as tinyint)",
OUT_OF_RANGE_MESSAGE, true);
f.checkFails("SELECT cast(1e60+30 as smallint)",
OUT_OF_RANGE_MESSAGE, true);
f.checkFails("SELECT cast(1e60+30 as int)",
OUT_OF_RANGE_MESSAGE, true);
f.checkFails("SELECT cast(1e60+30 as bigint)",
".*Overflow", true);

// Casting a decimal value larger than the maximum allowed value.
// Concatenating .0 to a value makes it decimal.
f.checkFails(String.format(Locale.US, "SELECT cast(%d.0 AS tinyint)", Short.MAX_VALUE),
OUT_OF_RANGE_MESSAGE, true);
f.checkFails(String.format(Locale.US, "SELECT cast(%d.0 AS smallint)", Integer.MAX_VALUE),
OUT_OF_RANGE_MESSAGE, true);
// Dividing Long.MAX_VALUE by 10 ensures that the resulting decimal does not exceed the
// maximum allowed precision for decimals but is still too large for an integer.
f.checkFails(String.format(Locale.US, "SELECT cast(%d.0 AS int)", Long.MAX_VALUE / 10),
OUT_OF_RANGE_MESSAGE, true);
}

@ParameterizedTest
@MethodSource("safeParameters")
void testCastTruncates(CastType castType, SqlOperatorFixture f) {
Expand Down

0 comments on commit e2c84a6

Please sign in to comment.