add ParquetDirectCodecFactory
Signed-off-by: sperlingxx <lovedreamf@gmail.com>
sperlingxx committed May 9, 2024
1 parent c6baa5a commit bec2ec4
Showing 10 changed files with 365 additions and 46 deletions.
@@ -1140,7 +1140,8 @@ case class GpuParquetMultiFilePartitionReaderFactory(
       rapidsConf.parquetScanHostBatchSizeBytes,
       3,
       rapidsConf.parquetScanEnableDictLateMat,
-      rapidsConf.parquetScanHostAsync)
+      rapidsConf.parquetScanHostAsync,
+      rapidsConf.parquetScanUnsafeDecompression)

   // We can't use the coalescing files reader when InputFileName, InputFileBlockStart,
   // or InputFileBlockLength because we are combining all the files into a single buffer
@@ -2623,7 +2624,7 @@ class MultiFileCloudParquetPartitionReader(
       hostBuffer, 0, dataSize, metrics,
       dateRebaseMode, timestampRebaseMode, hasInt96Timestamps,
       clippedSchema, readDataSchema,
-      slotAcquired, opts.enableDictLateMat, opts.async)
+      slotAcquired, opts.enableDictLateMat, opts.async, opts.unsafeDecompression)

     val batchIter = HostParquetIterator(asyncReader, opts, colTypes, metrics)

26 changes: 17 additions & 9 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala
@@ -1545,28 +1545,34 @@ val GPU_COREDUMP_PIPE_PATTERN = conf("spark.rapids.gpu.coreDump.pipePattern")
     .stringConf
     .createWithDefault("GPU_ONLY")

-  val PARQUET_HOST_SCAN_PARALLELISM = conf("spark.rapids.sql.parquet.scan.hostParallelism")
+  val PARQUET_CPU_SCAN_PARALLELISM = conf("spark.rapids.sql.parquet.scan.hostParallelism")
     .doc("The max concurrent capacity for host parquet scan(decode) tasks")
     .internal()
     .integerConf
     .createWithDefault(0)

-  val PARQUET_HOST_SCAN_BATCH_SIZE_BYTES = conf("spark.rapids.sql.parquet.scan.hostBatchSizeBytes")
+  val PARQUET_CPU_SCAN_BATCH_SIZE_BYTES = conf("spark.rapids.sql.parquet.scan.hostBatchSizeBytes")
     .doc("Similar to spark.rapids.sql.batchSizeBytes, but it is only for decode tasks run on CPUs")
     .internal()
     .integerConf
     .createWithDefault(1024 * 1024 * 128)

-  val PARQUET_HOST_SCAN_ASYNC = conf("spark.rapids.sql.parquet.scan.async")
+  val PARQUET_CPU_SCAN_ASYNC = conf("spark.rapids.sql.parquet.scan.async")
     .doc("Whether run host parquet decode tasks asynchronously or not")
     .internal()
     .booleanConf
-    .createWithDefault(false)
+    .createWithDefault(true)

-  val PARQUET_SCAN_DICT_LATE_MAT = conf("spark.rapids.sql.parquet.scan.enableDictLateMat")
+  val PARQUET_CPU_SCAN_DICT_LATE_MAT = conf("spark.rapids.sql.parquet.scan.enableDictLateMat")
     .doc("Whether pushing down binary dicts onto GPU and materializing via GPU or not")
     .internal()
     .booleanConf
     .createWithDefault(true)
+
+  val PARQUET_CPU_SCAN_UNSAFE_DECOMP = conf("spark.rapids.sql.parquet.scan.unsafeDecompress")
+    .doc("Whether using UnsafeDecompressor instead of the default one of parquet-hadoop or not")
+    .internal()
+    .booleanConf
+    .createWithDefault(false)

   val ORC_DEBUG_DUMP_PREFIX = conf("spark.rapids.sql.orc.debug.dumpPrefix")
@@ -2630,13 +2636,15 @@ class RapidsConf(conf: Map[String, String]) extends Logging {

   lazy val parquetScanHybridMode: String = get(PARQUET_HYBRID_SCAN_MODE)

-  lazy val parquetScanHostParallelism: Int = get(PARQUET_HOST_SCAN_PARALLELISM)
+  lazy val parquetScanHostParallelism: Int = get(PARQUET_CPU_SCAN_PARALLELISM)
+
+  lazy val parquetScanHostBatchSizeBytes: Int = get(PARQUET_CPU_SCAN_BATCH_SIZE_BYTES)

-  lazy val parquetScanHostBatchSizeBytes: Int = get(PARQUET_HOST_SCAN_BATCH_SIZE_BYTES)
+  lazy val parquetScanHostAsync: Boolean = get(PARQUET_CPU_SCAN_ASYNC)

-  lazy val parquetScanHostAsync: Boolean = get(PARQUET_HOST_SCAN_ASYNC)
+  lazy val parquetScanEnableDictLateMat: Boolean = get(PARQUET_CPU_SCAN_DICT_LATE_MAT)

-  lazy val parquetScanEnableDictLateMat: Boolean = get(PARQUET_SCAN_DICT_LATE_MAT)
+  lazy val parquetScanUnsafeDecompression: Boolean = get(PARQUET_CPU_SCAN_UNSAFE_DECOMP)

   lazy val orcDebugDumpPrefix: Option[String] = get(ORC_DEBUG_DUMP_PREFIX)

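The new knob is declared .internal(), so it is not part of the documented config surface. A minimal sketch of how it could be toggled for a local experiment, assuming a plain session-level override reaches this internal conf (the key comes verbatim from the hunk above; everything else is illustrative):

    // Hypothetical experiment setup: flip the new internal conf so AsyncParquetReader
    // picks ParquetDirectCodecFactory instead of ParquetHeapCodecFactory.
    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder()
      .appName("unsafe-decompress-experiment")  // illustrative name
      .config("spark.rapids.sql.parquet.scan.unsafeDecompress", "true")
      .getOrCreate()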
@@ -60,19 +60,19 @@ case class AsyncBatchResult(data: Array[HostColumnVector],

 class AsyncParquetReaderError(ex: Throwable) extends RuntimeException(ex)

-class AsyncParquetReader(
-    conf: Configuration,
-    tgtBatchSize: Int,
-    fileBuffer: HostMemoryBuffer,
-    offset: Long,
-    len: Long,
-    metrics: Map[String, GpuMetric],
-    dateRebaseMode: DateTimeRebaseMode,
-    timestampRebaseMode: DateTimeRebaseMode,
-    clippedSchema: MessageType,
-    slotAcquired: Boolean,
-    enableDictLateMat: Boolean,
-    asynchronous: Boolean)
+class AsyncParquetReader(conf: Configuration,
+    tgtBatchSize: Int,
+    fileBuffer: HostMemoryBuffer,
+    offset: Long,
+    len: Long,
+    metrics: Map[String, GpuMetric],
+    dateRebaseMode: DateTimeRebaseMode,
+    timestampRebaseMode: DateTimeRebaseMode,
+    clippedSchema: MessageType,
+    slotAcquired: Boolean,
+    enableDictLateMat: Boolean,
+    asynchronous: Boolean,
+    directBuffering: Boolean)
   extends Iterator[AsyncBatchResult] with AutoCloseable with Logging {

   private type Element = Either[AsyncBatchResult, Throwable]
@@ -94,13 +94,24 @@ class AsyncParquetReader(
   private var currentGroup: PageReadStore = _

   private val pageReader: ParquetFileReader = {
-    val options = HadoopReadOptions.builder(conf)
-      .withRange(offset, offset + len)
-      .withCodecFactory(new ParquetCodecFactory(conf, 0))
-      .withAllocator(new DirectByteBufferAllocator)
-      .build()
-    val bufferFile = new HMBInputFile(fileBuffer, length = Some(offset + len))
-    val reader = new ParquetFileReader(bufferFile, options)
+    val reader = if (!directBuffering) {
+      val options = HadoopReadOptions.builder(conf)
+        .withRange(offset, offset + len)
+        .withAllocator(new DirectByteBufferAllocator)
+        .withCodecFactory(new ParquetHeapCodecFactory(conf, 0))
+        .build()
+      val bufferFile = new HMBInputFile(fileBuffer, length = Some(offset + len))
+      new ParquetFileReader(bufferFile, options)
+    } else {
+      val options = HadoopReadOptions.builder(conf)
+        .withRange(offset, offset + len)
+        .withAllocator(new DirectByteBufferAllocator)
+        .withCodecFactory(new ParquetDirectCodecFactory(conf, 0))
+        .build()
+      val bufferFile = new HMBInputFile(fileBuffer, length = Some(offset + len))
+      new ParquetFileReader(bufferFile, options)
+    }
+
     // The fileSchema here has already been clipped
     reader.setRequestedSchema(clippedSchema)
     reader
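The two branches above differ only in the codec factory passed to the read options. A behaviorally equivalent sketch (not the committed code) that selects just the factory and builds the options once, assuming both factories satisfy whatever type withCodecFactory expects here:

    // Sketch: same selection as the if/else above, with the shared builder factored out.
    val codecFactory =
      if (directBuffering) new ParquetDirectCodecFactory(conf, 0)
      else new ParquetHeapCodecFactory(conf, 0)

    val options = HadoopReadOptions.builder(conf)
      .withRange(offset, offset + len)
      .withAllocator(new DirectByteBufferAllocator)
      .withCodecFactory(codecFactory)
      .build()
    val bufferFile = new HMBInputFile(fileBuffer, length = Some(offset + len))
    val reader = new ParquetFileReader(bufferFile, options)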
@@ -578,13 +589,14 @@ object AsyncParquetReader {
       readDataSchema: StructType,
       slotAcquired: Boolean,
       enableDictLateMat: Boolean,
-      asynchronous: Boolean): AsyncParquetReader = {
+      asynchronous: Boolean,
+      directBuffering: Boolean): AsyncParquetReader = {
     new AsyncParquetReader(conf,
       tgtBatchSize, fileBuffer, offset, len,
       metrics,
       dateRebaseMode, timestampRebaseMode,
       clippedSchema,
-      slotAcquired, enableDictLateMat, asynchronous)
+      slotAcquired, enableDictLateMat, asynchronous, directBuffering)
   }
 }

@@ -735,4 +747,5 @@ case class HybridParquetOpts(mode: String,
     batchSizeBytes: Long,
     pollInterval: Int,
     enableDictLateMat: Boolean,
-    async: Boolean)
+    async: Boolean,
+    unsafeDecompression: Boolean)
@@ -71,7 +71,8 @@ public void advance(long sizeInByte) {
   public static ByteBufferIsConsumer create(ByteBufferInputStream bis) {
     List<ByteBuffer> buffers = bis.remainingBuffers();
     if (buffers.isEmpty()) {
-      throw new IllegalArgumentException("Got empty ByteBufferInputStream");
+      System.err.println("Got empty ByteBufferInputStream");
+      return new EmptyBufferIsConsumer();
     }
     if (buffers.size() > 1) {
       System.err.printf("create a MultiByteBuffersConsumer with %d buffers\n", buffers.size());
@@ -82,5 +83,4 @@ public static ByteBufferIsConsumer create(ByteBufferInputStream bis) {
     }
     return new DirectByteBufferIsConsumer(buffers.iterator());
   }
-
 }
@@ -21,7 +21,6 @@
 import java.nio.ByteBuffer;
 import java.util.Iterator;

-import org.apache.parquet.bytes.ByteBufferInputStream;
 import org.apache.parquet.io.api.Binary;
 import org.apache.spark.sql.execution.vectorized.rapids.RapidsWritableColumnVector;
 import org.apache.spark.sql.execution.vectorized.rapids.UnsafeMemoryUtils;
@@ -47,6 +46,7 @@ public DirectByteBufferIsConsumer(Iterator<ByteBuffer> bufferIterator) {
   }

   protected void pointToNextBuffer() {
+    // Close current buffer manually before pointing to next (since it is DirectBuffer)
     super.pointToNextBuffer();
     assert current.isDirect();
     try {
@@ -109,6 +109,9 @@ public void readBinaries(int total, WritableColumnVector c, int rowId) {
     for (int i = 0; i < total; ++i) {
       if (!current.hasRemaining()) pointToNextBuffer();

+      if (current.remaining() < 4) {
+        throw new AssertionError("DirectBuffer remaining < 4 bytes(" + current.remaining() + ")");
+      }
       int curLength = current.getInt();
       int prevOffset = charVector.getElementsAppended();

@@ -235,5 +238,4 @@ public Binary getBinary(int len) {

     return Binary.fromConstantByteArray(target);
   }
-
 }
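The new check in readBinaries guards the 4-byte length prefix that precedes each PLAIN-encoded BYTE_ARRAY value: if the current direct buffer held fewer than 4 bytes, getInt would fail with a far less descriptive BufferUnderflowException. A standalone sketch of the same pattern using a plain JDK ByteBuffer (not the plugin classes), for illustration:

    import java.nio.{ByteBuffer, ByteOrder}
    import scala.collection.mutable.ArrayBuffer

    // Decode length-prefixed byte arrays (4-byte little-endian length + payload),
    // failing fast if the buffer ends inside a length prefix, mirroring the
    // `remaining() < 4` assertion added above.
    def readLengthPrefixed(buf: ByteBuffer): Seq[Array[Byte]] = {
      buf.order(ByteOrder.LITTLE_ENDIAN)   // Parquet PLAIN lengths are little-endian
      val out = ArrayBuffer.empty[Array[Byte]]
      while (buf.hasRemaining) {
        if (buf.remaining() < 4) {
          throw new AssertionError(s"buffer remaining < 4 bytes (${buf.remaining()})")
        }
        val len = buf.getInt()
        val bytes = new Array[Byte](len)
        buf.get(bytes)                     // throws BufferUnderflowException if truncated
        out += bytes
      }
      out.toSeq
    }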
@@ -0,0 +1,97 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.datasources.parquet.rapids;

import com.google.common.collect.Iterators;
import org.apache.parquet.io.api.Binary;
import org.apache.spark.sql.execution.vectorized.rapids.WritableColumnVector;

public class EmptyBufferIsConsumer extends ByteBufferIsConsumer {

public EmptyBufferIsConsumer() {
super(Iterators.emptyIterator());
}

@Override
public void readInts(int total, WritableColumnVector c, int rowId) {
throw new AssertionError("We should NOT perform any reading from EmptyBufferIsConsumer");
}

@Override
public void readLongs(int total, WritableColumnVector c, int rowId) {
throw new AssertionError("We should NOT perform any reading from EmptyBufferIsConsumer");
}

@Override
public void readFloats(int total, WritableColumnVector c, int rowId) {
throw new AssertionError("We should NOT perform any reading from EmptyBufferIsConsumer");
}

@Override
public void readDoubles(int total, WritableColumnVector c, int rowId) {
throw new AssertionError("We should NOT perform any reading from EmptyBufferIsConsumer");
}

@Override
public void readUIntsAsLongs(int total, WritableColumnVector c, int rowId) {
throw new AssertionError("We should NOT perform any reading from EmptyBufferIsConsumer");
}

@Override
public void readIntsAsShorts(int total, WritableColumnVector c, int rowId) {
throw new AssertionError("We should NOT perform any reading from EmptyBufferIsConsumer");
}

@Override
public void readIntsAsBytes(int total, WritableColumnVector c, int rowId) {
throw new AssertionError("We should NOT perform any reading from EmptyBufferIsConsumer");
}

@Override
public void readBinaries(int total, WritableColumnVector c, int rowId) {
throw new AssertionError("We should NOT perform any reading from EmptyBufferIsConsumer");
}

@Override
public byte getByte() {
throw new AssertionError("We should NOT perform any reading from EmptyBufferIsConsumer");
}

@Override
public int getInt() {
throw new AssertionError("We should NOT perform any reading from EmptyBufferIsConsumer");
}

@Override
public long getLong() {
throw new AssertionError("We should NOT perform any reading from EmptyBufferIsConsumer");
}

@Override
public float getFloat() {
throw new AssertionError("We should NOT perform any reading from EmptyBufferIsConsumer");
}

@Override
public double getDouble() {
throw new AssertionError("We should NOT perform any reading from EmptyBufferIsConsumer");
}

@Override
public Binary getBinary(int len) {
throw new AssertionError("We should NOT perform any reading from EmptyBufferIsConsumer");
}
}
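With EmptyBufferIsConsumer in place, ByteBufferIsConsumer.create no longer throws on an input stream with no remaining buffers; it returns a consumer that only fails if something actually tries to read. A small illustrative sketch, assuming the classes live under the packages shown in this commit and that parquet's ByteBufferInputStream.wrap() with no arguments yields a stream with no remaining buffers:

    import org.apache.parquet.bytes.ByteBufferInputStream
    import org.apache.spark.sql.execution.datasources.parquet.rapids.ByteBufferIsConsumer

    // Assumed: wrap() with no buffers produces an empty stream.
    val emptyBis = ByteBufferInputStream.wrap()

    // Previously this threw IllegalArgumentException; now it returns an
    // EmptyBufferIsConsumer, and any accidental read (getInt, readBinaries, ...)
    // still fails fast with an AssertionError.
    val consumer = ByteBufferIsConsumer.create(emptyBis)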
