Skip to content

Commit

Permalink
Enable decimal-to-decimal casts and allow CastOptions to be passed through
Browse files Browse the repository at this point in the history
Use a regex to match the Arrow "Invalid argument" error message.
  • Loading branch information
himadripal committed Feb 11, 2025
1 parent f099e6e commit a2689a3
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 12 deletions.
9 changes: 8 additions & 1 deletion native/spark-expr/src/conversion_funcs/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -872,6 +872,13 @@ fn cast_array(
let array = array_with_timezone(array, cast_options.timezone.clone(), Some(to_type))?;
let from_type = array.data_type().clone();

let native_cast_options: CastOptions = CastOptions {
safe: !matches!(cast_options.eval_mode, EvalMode::Ansi), // take safe mode from cast_options passed
format_options: FormatOptions::new()
.with_timestamp_tz_format(TIMESTAMP_FORMAT)
.with_timestamp_format(TIMESTAMP_FORMAT),
};

let array = match &from_type {
Dictionary(key_type, value_type)
if key_type.as_ref() == &Int32
Expand Down Expand Up @@ -963,7 +970,7 @@ fn cast_array(
|| is_datafusion_spark_compatible(from_type, to_type, cast_options.allow_incompat) =>
{
// use DataFusion cast only when we know that it is compatible with Spark
Ok(cast_with_options(&array, to_type, &CAST_OPTIONS)?)
Ok(cast_with_options(&array, to_type, &native_cast_options)?)
}
_ => {
// we should never reach this code because the Scala code should be checking
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,13 +70,8 @@ object CometCast {
case _ =>
Unsupported
}
case (from: DecimalType, to: DecimalType) =>
if (to.precision < from.precision) {
// https://github.com/apache/datafusion/issues/13492
Incompatible(Some("Casting to smaller precision is not supported"))
} else {
Compatible()
}
case (_: DecimalType, _: DecimalType) =>
Compatible()
case (DataTypes.StringType, _) =>
canCastFromString(toType, timeZoneId, evalMode)
case (_, DataTypes.StringType) =>
Expand Down
17 changes: 13 additions & 4 deletions spark/src/test/scala/org/apache/comet/CometCastSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -913,8 +913,8 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
val values = Seq(BigDecimal("12345.6789"), BigDecimal("9876.5432"), BigDecimal("123.4567"))
val df = withNulls(values)
.toDF("b")
.withColumn("a", col("b").cast(DecimalType(6, 2)))
checkSparkAnswer(df)
.withColumn("a", col("b").cast(DecimalType(38, 28)))
castTest(df, DecimalType(6, 2))
}

test("cast between decimals with higher precision than source") {
Expand Down Expand Up @@ -1133,8 +1133,17 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
.replace(".WITH_SUGGESTION] ", "]")
.startsWith(cometMessage))
} else if (CometSparkSessionExtensions.isSpark34Plus) {
// for Spark 3.4 we expect to reproduce the error message exactly
assert(cometMessage == sparkMessage)
// for comet decimal conversion throws ArrowError(string) from arrow
if (cometMessage.contains("Invalid argument error")) {
val regex =
"\\[\\[?(NUMERIC_VALUE_OUT_OF_RANGE|Invalid argument error)\\]?:? .*? (\\d+(\\.\\d+)?) .*? Decimal\\(?(\\d+),?\\s?(\\d+)\\)?.*?\\]?"
assert(cometMessage.matches(regex) == sparkMessage.matches(regex))

} else {
// for Spark 3.4 we expect to reproduce the error message exactly
assert(cometMessage == sparkMessage)

}
} else {
// for Spark 3.3 we just need to strip the prefix from the Comet message
// before comparing
Expand Down

0 comments on commit a2689a3

Please sign in to comment.