GpuInsertIntoHiveTable supports parquet
Signed-off-by: Firestarman <firestarmanllc@gmail.com>
firestarman committed Feb 23, 2024
1 parent b1a6335 commit 4be7993
Showing 3 changed files with 110 additions and 34 deletions.
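For context, a minimal sketch of the kind of write this commit enables on the GPU. The table name, schema, and data below are illustrative (not taken from the commit), and assume a SparkSession with Hive support and the RAPIDS Accelerator enabled:

// Illustrative only: an INSERT into a Hive table stored as Parquet, which
// this change allows the plugin to plan on the GPU instead of falling back
// to the CPU. Previously only Hive delimited-text writes were supported.
spark.sql("CREATE TABLE demo_tbl (id INT, name STRING) STORED AS PARQUET")
spark.sql("INSERT INTO demo_tbl VALUES (1, 'a'), (2, 'b')")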
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
* Copyright (c) 2023-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -17,8 +17,9 @@
package org.apache.spark.sql.hive.rapids

import java.nio.charset.Charset
import java.util.Locale

import ai.rapids.cudf.{CSVWriterOptions, DType, QuoteStyle, Scalar, Table, TableWriter => CudfTableWriter}
import ai.rapids.cudf.{CompressionType, CSVWriterOptions, DType, ParquetWriterOptions, QuoteStyle, Scalar, Table, TableWriter => CudfTableWriter}
import com.google.common.base.Charsets
import com.nvidia.spark.rapids._
import com.nvidia.spark.rapids.Arm.withResource
@@ -27,14 +28,65 @@ import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext}

import org.apache.spark.internal.Logging
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.datasources.parquet.ParquetOptions
import org.apache.spark.sql.hive.rapids.GpuHiveTextFileUtils._
import org.apache.spark.sql.hive.rapids.shims.GpuInsertIntoHiveTableMeta
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{DataType, StringType, StructType}
import org.apache.spark.sql.vectorized.ColumnarBatch

object GpuHiveTextFileFormat extends Logging {
object GpuHiveFileFormat extends Logging {

private def checkIfEnabled(meta: GpuInsertIntoHiveTableMeta): Unit = {
def tagGpuSupport(meta: GpuInsertIntoHiveTableMeta): Option[ColumnarFileFormat] = {
val insertCmd = meta.wrapped
// Bucketing write
if (insertCmd.table.bucketSpec.isDefined) {
meta.willNotWorkOnGpu("bucketed tables are not supported yet")
}

// Infer the file format from the serde string, similar to what Spark does in
// RelationConversions for Hive.
val serde = insertCmd.table.storage.serde.getOrElse("").toLowerCase(Locale.ROOT)
val tempFileFormat = if (serde.contains("parquet")) {
// Parquet specific tagging
tagGpuSupportForParquet(meta)
} else {
// Default to text file format
tagGpuSupportForText(meta)
}

if (meta.canThisBeReplaced) {
Some(tempFileFormat)
} else {
None
}
}

private def tagGpuSupportForParquet(meta: GpuInsertIntoHiveTableMeta): ColumnarFileFormat = {
val insertCmd = meta.wrapped
// Configs check for Parquet write enabling/disabling
// FIXME Need to check serde and output format classes?

// FIXME Need a new format type for Hive Parquet write?
FileFormatChecks.tag(meta, insertCmd.table.schema, ParquetFormatType, WriteFileOp)

// Compression
var compType = CompressionType.NONE
if (isCompressionEnabled(insertCmd.conf)) {
val parquetOptions = new ParquetOptions(insertCmd.table.properties, insertCmd.conf)
val compressionType =
GpuParquetFileFormat.parseCompressionType(parquetOptions.compressionCodecClassName)
if (compressionType.nonEmpty) {
compType = compressionType.get
} else {
meta.willNotWorkOnGpu(
s"compression codec ${parquetOptions.compressionCodecClassName} is not supported")
}
}
new GpuHiveParquetFileFormat(compType)
}

private def tagGpuSupportForText(meta: GpuInsertIntoHiveTableMeta): ColumnarFileFormat = {
if (!meta.conf.isHiveDelimitedTextEnabled) {
meta.willNotWorkOnGpu("Hive text I/O has been disabled. To enable this, " +
s"set ${RapidsConf.ENABLE_HIVE_TEXT} to true")
@@ -43,21 +95,16 @@ object GpuHiveTextFileFormat extends Logging {
meta.willNotWorkOnGpu("writing Hive delimited text tables has been disabled, " +
s"to enable this, set ${RapidsConf.ENABLE_HIVE_TEXT_WRITE} to true")
}
}

def tagGpuSupport(meta: GpuInsertIntoHiveTableMeta)
: Option[ColumnarFileFormat] = {
checkIfEnabled(meta)

val insertCommand = meta.wrapped
val storage = insertCommand.table.storage
if (storage.outputFormat.getOrElse("") != textOutputFormat) {
meta.willNotWorkOnGpu(s"unsupported output-format found: ${storage.outputFormat}, " +
s"only $textOutputFormat is currently supported")
s"only $textOutputFormat is currently supported for text")
}
if (storage.serde.getOrElse("") != lazySimpleSerDe) {
meta.willNotWorkOnGpu(s"unsupported serde found: ${storage.serde}, " +
s"only $lazySimpleSerDe is currently supported")
s"only $lazySimpleSerDe is currently supported for text")
}

val serializationFormat = storage.properties.getOrElse(serializationKey, "1")
@@ -86,22 +133,49 @@ object GpuHiveTextFileFormat extends Logging {
meta.willNotWorkOnGpu("only UTF-8 is supported as the charset")
}

if (insertCommand.table.bucketSpec.isDefined) {
meta.willNotWorkOnGpu("bucketed tables are not supported")
}

if (insertCommand.conf.getConfString("hive.exec.compress.output", "false").toLowerCase
!= "false") {
if (isCompressionEnabled(insertCommand.conf)) {
meta.willNotWorkOnGpu("compressed output is not supported, " +
"set hive.exec.compress.output to false to enable writing Hive text via GPU")
}

FileFormatChecks.tag(meta,
insertCommand.table.schema,
HiveDelimitedTextFormatType,
WriteFileOp)
FileFormatChecks.tag(meta, insertCommand.table.schema, HiveDelimitedTextFormatType,
WriteFileOp)

Some(new GpuHiveTextFileFormat())
new GpuHiveTextFileFormat()
}

private def isCompressionEnabled(conf: SQLConf): Boolean = {
conf.getConfString("hive.exec.compress.output", "false").toBoolean
}
}

class GpuHiveParquetFileFormat(compType: CompressionType) extends ColumnarFileFormat {

override def prepareWrite(sparkSession: SparkSession, job: Job,
options: Map[String, String], dataSchema: StructType): ColumnarOutputWriterFactory = {
new ColumnarOutputWriterFactory {
override def getFileExtension(context: TaskAttemptContext): String = ".parquet"

override def newInstance(path: String,
dataSchema: StructType,
context: TaskAttemptContext): ColumnarOutputWriter = {
new GpuHiveParquetWriter(path, dataSchema, context, compType)
}
}
}
}

class GpuHiveParquetWriter(override val path: String, dataSchema: StructType,
context: TaskAttemptContext, compType: CompressionType)
extends ColumnarOutputWriter(context, dataSchema, "HiveParquet", false) {

override protected val tableWriter: CudfTableWriter = {
// TODO How to set INT96 and FieldIDEnabled?
val writeOptions = SchemaUtils
.writerOptionsFromSchema(ParquetWriterOptions.builder(), dataSchema)
.withCompressionType(compType)
.build()
Table.writeParquetChunked(writeOptions, this)
}
}
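The dispatch in tagGpuSupport above keys off the table's serde string, mirroring Spark's RelationConversions for Hive. Below is a standalone sketch of that decision; the object, trait, and method names are illustrative, not the plugin's API:

import java.util.Locale

object SerdeDispatchSketch {
  sealed trait HiveWritePath
  case object ParquetPath extends HiveWritePath
  case object TextPath extends HiveWritePath

  // Any serde whose class name mentions "parquet" takes the Parquet path;
  // everything else falls back to the delimited-text path.
  def pathForSerde(serde: Option[String]): HiveWritePath = {
    val s = serde.getOrElse("").toLowerCase(Locale.ROOT)
    if (s.contains("parquet")) ParquetPath else TextPath
  }

  def main(args: Array[String]): Unit = {
    // ParquetHiveSerDe -> ParquetPath
    println(pathForSerde(Some("org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe")))
    // LazySimpleSerDe -> TextPath
    println(pathForSerde(Some("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe")))
  }
}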

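Compression for the Parquet path is resolved in two steps: hive.exec.compress.output gates compression entirely, and only then is the Parquet codec name consulted. A self-contained sketch of that flow follows; the codec table is illustrative, since the real code defers to GpuParquetFileFormat.parseCompressionType:

object CompressionResolutionSketch {
  // Returns the codec to use, or None to signal an unsupported codec
  // (which the plugin reports via willNotWorkOnGpu, falling back to CPU).
  def resolve(hiveCompressOutput: Boolean, parquetCodecName: String): Option[String] = {
    if (!hiveCompressOutput) {
      Some("NONE") // compression disabled at the Hive level: write uncompressed
    } else {
      parquetCodecName.toLowerCase(java.util.Locale.ROOT) match {
        case "none" | "uncompressed" => Some("NONE")
        case "snappy" => Some("SNAPPY")
        case "zstd" => Some("ZSTD")
        case _ => None // anything else is unsupported in this sketch
      }
    }
  }
}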
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
* Copyright (c) 2023-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -58,7 +58,7 @@ import org.apache.spark.sql.hive.HiveShim.{ShimFileSinkDesc => FileSinkDesc}
import org.apache.spark.sql.hive.client.HiveClientImpl
import org.apache.spark.sql.hive.client.hive._
import org.apache.spark.sql.hive.execution.InsertIntoHiveTable
import org.apache.spark.sql.hive.rapids.{GpuHiveTextFileFormat, GpuSaveAsHiveFile, RapidsHiveErrors}
import org.apache.spark.sql.hive.rapids.{GpuHiveFileFormat, GpuSaveAsHiveFile, RapidsHiveErrors}
import org.apache.spark.sql.vectorized.ColumnarBatch

final class GpuInsertIntoHiveTableMeta(cmd: InsertIntoHiveTable,
@@ -70,16 +70,17 @@ final class GpuInsertIntoHiveTableMeta(cmd: InsertIntoHiveTable,
private var fileFormat: Option[ColumnarFileFormat] = None

override def tagSelfForGpuInternal(): Unit = {
// Only Hive delimited text writes are currently supported.
// Check whether that is the format currently in play.
fileFormat = GpuHiveTextFileFormat.tagGpuSupport(this)
fileFormat = Some(GpuHiveFileFormat.tagGpuSupport(this))
}

override def convertToGpu(): GpuDataWritingCommand = {
val format = fileFormat.getOrElse(
throw new IllegalStateException("fileFormat missing, tagSelfForGpu not called?"))

GpuInsertIntoHiveTable(
table = wrapped.table,
partition = wrapped.partition,
fileFormat = this.fileFormat.get,
fileFormat = format,
query = wrapped.query,
overwrite = wrapped.overwrite,
ifPartitionNotExists = wrapped.ifPartitionNotExists,
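The convertToGpu change above replaces a bare fileFormat.get with an explicit failure when tagging never ran. A minimal standalone sketch of that fail-fast pattern (class, field, and method names are illustrative):

final class MetaSketch {
  private var fileFormat: Option[String] = None

  def tagSelf(): Unit = { fileFormat = Some("parquet") }

  // Fails with a descriptive error instead of Option.get's generic
  // NoSuchElementException when tagSelf() was never called.
  def convert(): String =
    fileFormat.getOrElse(
      throw new IllegalStateException("fileFormat missing, tagSelf not called?"))
}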
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
* Copyright (c) 2023-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -45,7 +45,7 @@ import org.apache.spark.sql.execution.command.CommandUtils
import org.apache.spark.sql.hive.HiveExternalCatalog
import org.apache.spark.sql.hive.client.HiveClientImpl
import org.apache.spark.sql.hive.execution.InsertIntoHiveTable
import org.apache.spark.sql.hive.rapids.{GpuHiveTextFileFormat, GpuSaveAsHiveFile, RapidsHiveErrors}
import org.apache.spark.sql.hive.rapids.{GpuHiveFileFormat, GpuSaveAsHiveFile, RapidsHiveErrors}
import org.apache.spark.sql.vectorized.ColumnarBatch

final class GpuInsertIntoHiveTableMeta(cmd: InsertIntoHiveTable,
@@ -57,16 +57,17 @@ final class GpuInsertIntoHiveTableMeta(cmd: InsertIntoHiveTable,
private var fileFormat: Option[ColumnarFileFormat] = None

override def tagSelfForGpuInternal(): Unit = {
// Only Hive delimited text writes are currently supported.
// Check whether that is the format currently in play.
fileFormat = GpuHiveTextFileFormat.tagGpuSupport(this)
fileFormat = Some(GpuHiveFileFormat.tagGpuSupport(this))
}

override def convertToGpu(): GpuDataWritingCommand = {
val format = fileFormat.getOrElse(
throw new IllegalStateException("fileFormat missing, tagSelfForGpu not called?"))

GpuInsertIntoHiveTable(
table = wrapped.table,
partition = wrapped.partition,
fileFormat = this.fileFormat.get,
fileFormat = format,
query = wrapped.query,
overwrite = wrapped.overwrite,
ifPartitionNotExists = wrapped.ifPartitionNotExists,
