HybridScan supports ZSTD decomp with DirectBuffer #1

Merged · 1 commit · May 11, 2024
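For context on how this change is switched on: the new internal config key spark.rapids.sql.parquet.scan.unsafeDecompress (added in the RapidsConf diff below) gates the DirectBuffer-based ZSTD decompression path. A minimal usage sketch follows, assuming the RAPIDS Accelerator plugin class com.nvidia.spark.SQLPlugin is on the classpath and that the internal key can be set like any other Spark conf; the app name and path are placeholders, and other hybrid-scan settings are omitted:

import org.apache.spark.sql.SparkSession

// Sketch only: the unsafeDecompress key is introduced by this PR and is marked
// .internal() in RapidsConf, so treat this as illustrative rather than a supported API.
val spark = SparkSession.builder()
  .appName("hybrid-scan-zstd-directbuffer")
  .config("spark.plugins", "com.nvidia.spark.SQLPlugin")
  .config("spark.rapids.sql.parquet.scan.unsafeDecompress", "true")
  .getOrCreate()

// With the flag on, the hybrid (CPU) parquet scan picks the DirectBuffer codec factory
// for page decompression (see the AsyncParquetReader diff below).
spark.read.parquet("/path/to/zstd_data.parquet").show()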
@@ -1140,7 +1140,8 @@ case class GpuParquetMultiFilePartitionReaderFactory(
       rapidsConf.parquetScanHostBatchSizeBytes,
       3,
       rapidsConf.parquetScanEnableDictLateMat,
-      rapidsConf.parquetScanHostAsync)
+      rapidsConf.parquetScanHostAsync,
+      rapidsConf.parquetScanUnsafeDecompression)
 
   // We can't use the coalescing files reader when InputFileName, InputFileBlockStart,
   // or InputFileBlockLength because we are combining all the files into a single buffer
@@ -2623,7 +2624,7 @@ class MultiFileCloudParquetPartitionReader(
         hostBuffer, 0, dataSize, metrics,
         dateRebaseMode, timestampRebaseMode, hasInt96Timestamps,
         clippedSchema, readDataSchema,
-        slotAcquired, opts.enableDictLateMat, opts.async)
+        slotAcquired, opts.enableDictLateMat, opts.async, opts.unsafeDecompression)
 
       val batchIter = HostParquetIterator(asyncReader, opts, colTypes, metrics)
 
26 changes: 17 additions & 9 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala
@@ -1545,28 +1545,34 @@ val GPU_COREDUMP_PIPE_PATTERN = conf("spark.rapids.gpu.coreDump.pipePattern")
     .stringConf
     .createWithDefault("GPU_ONLY")
 
-  val PARQUET_HOST_SCAN_PARALLELISM = conf("spark.rapids.sql.parquet.scan.hostParallelism")
+  val PARQUET_CPU_SCAN_PARALLELISM = conf("spark.rapids.sql.parquet.scan.hostParallelism")
     .doc("The max concurrent capacity for host parquet scan(decode) tasks")
     .internal()
     .integerConf
     .createWithDefault(0)
 
-  val PARQUET_HOST_SCAN_BATCH_SIZE_BYTES = conf("spark.rapids.sql.parquet.scan.hostBatchSizeBytes")
+  val PARQUET_CPU_SCAN_BATCH_SIZE_BYTES = conf("spark.rapids.sql.parquet.scan.hostBatchSizeBytes")
     .doc("Similar to spark.rapids.sql.batchSizeBytes, but it is only for decode tasks run on CPUs")
     .internal()
     .integerConf
     .createWithDefault(1024 * 1024 * 128)
 
-  val PARQUET_HOST_SCAN_ASYNC = conf("spark.rapids.sql.parquet.scan.async")
+  val PARQUET_CPU_SCAN_ASYNC = conf("spark.rapids.sql.parquet.scan.async")
     .doc("Whether run host parquet decode tasks asynchronously or not")
     .internal()
     .booleanConf
-    .createWithDefault(false)
+    .createWithDefault(true)
 
-  val PARQUET_SCAN_DICT_LATE_MAT = conf("spark.rapids.sql.parquet.scan.enableDictLateMat")
+  val PARQUET_CPU_SCAN_DICT_LATE_MAT = conf("spark.rapids.sql.parquet.scan.enableDictLateMat")
    .doc("Whether pushing down binary dicts onto GPU and materializing via GPU or not")
     .internal()
     .booleanConf
     .createWithDefault(true)
 
+  val PARQUET_CPU_SCAN_UNSAFE_DECOMP = conf("spark.rapids.sql.parquet.scan.unsafeDecompress")
+    .doc("Whether using UnsafeDecompressor instead of the default one of parquet-hadoop or not")
+    .internal()
+    .booleanConf
+    .createWithDefault(false)
+
   val ORC_DEBUG_DUMP_PREFIX = conf("spark.rapids.sql.orc.debug.dumpPrefix")
@@ -2630,13 +2636,15 @@ class RapidsConf(conf: Map[String, String]) extends Logging {
 
   lazy val parquetScanHybridMode: String = get(PARQUET_HYBRID_SCAN_MODE)
 
-  lazy val parquetScanHostParallelism: Int = get(PARQUET_HOST_SCAN_PARALLELISM)
+  lazy val parquetScanHostParallelism: Int = get(PARQUET_CPU_SCAN_PARALLELISM)
 
-  lazy val parquetScanHostBatchSizeBytes: Int = get(PARQUET_HOST_SCAN_BATCH_SIZE_BYTES)
+  lazy val parquetScanHostBatchSizeBytes: Int = get(PARQUET_CPU_SCAN_BATCH_SIZE_BYTES)
 
-  lazy val parquetScanHostAsync: Boolean = get(PARQUET_HOST_SCAN_ASYNC)
+  lazy val parquetScanHostAsync: Boolean = get(PARQUET_CPU_SCAN_ASYNC)
 
-  lazy val parquetScanEnableDictLateMat: Boolean = get(PARQUET_SCAN_DICT_LATE_MAT)
+  lazy val parquetScanEnableDictLateMat: Boolean = get(PARQUET_CPU_SCAN_DICT_LATE_MAT)
+
+  lazy val parquetScanUnsafeDecompression: Boolean = get(PARQUET_CPU_SCAN_UNSAFE_DECOMP)
 
   lazy val orcDebugDumpPrefix: Option[String] = get(ORC_DEBUG_DUMP_PREFIX)
 
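Two behavior changes sit inside the rename-heavy RapidsConf hunks above: spark.rapids.sql.parquet.scan.async now defaults to true, and the new spark.rapids.sql.parquet.scan.unsafeDecompress key defaults to false. The user-facing key strings are unchanged by the PARQUET_HOST_* to PARQUET_CPU_* renames. A hedged sketch, assuming these internal keys can be set at runtime like ordinary Spark SQL confs and that spark is an active SparkSession:

// Pin the pre-PR behavior explicitly (synchronous host decode).
spark.conf.set("spark.rapids.sql.parquet.scan.async", "false")

// Or opt in to the new DirectBuffer decompression path added by this PR.
spark.conf.set("spark.rapids.sql.parquet.scan.unsafeDecompress", "true")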
@@ -60,19 +60,19 @@ case class AsyncBatchResult(data: Array[HostColumnVector],
 
 class AsyncParquetReaderError(ex: Throwable) extends RuntimeException(ex)
 
-class AsyncParquetReader(
-    conf: Configuration,
-    tgtBatchSize: Int,
-    fileBuffer: HostMemoryBuffer,
-    offset: Long,
-    len: Long,
-    metrics: Map[String, GpuMetric],
-    dateRebaseMode: DateTimeRebaseMode,
-    timestampRebaseMode: DateTimeRebaseMode,
-    clippedSchema: MessageType,
-    slotAcquired: Boolean,
-    enableDictLateMat: Boolean,
-    asynchronous: Boolean)
+class AsyncParquetReader(conf: Configuration,
+    tgtBatchSize: Int,
+    fileBuffer: HostMemoryBuffer,
+    offset: Long,
+    len: Long,
+    metrics: Map[String, GpuMetric],
+    dateRebaseMode: DateTimeRebaseMode,
+    timestampRebaseMode: DateTimeRebaseMode,
+    clippedSchema: MessageType,
+    slotAcquired: Boolean,
+    enableDictLateMat: Boolean,
+    asynchronous: Boolean,
+    directBuffering: Boolean)
   extends Iterator[AsyncBatchResult] with AutoCloseable with Logging {
 
   private type Element = Either[AsyncBatchResult, Throwable]
@@ -94,13 +94,24 @@ class AsyncParquetReader(
   private var currentGroup: PageReadStore = _
 
   private val pageReader: ParquetFileReader = {
-    val options = HadoopReadOptions.builder(conf)
-      .withRange(offset, offset + len)
-      .withCodecFactory(new ParquetCodecFactory(conf, 0))
-      .withAllocator(new DirectByteBufferAllocator)
-      .build()
-    val bufferFile = new HMBInputFile(fileBuffer, length = Some(offset + len))
-    val reader = new ParquetFileReader(bufferFile, options)
+    val reader = if (!directBuffering) {
+      val options = HadoopReadOptions.builder(conf)
+        .withRange(offset, offset + len)
+        .withAllocator(new DirectByteBufferAllocator)
+        .withCodecFactory(new ParquetHeapCodecFactory(conf, 0))
+        .build()
+      val bufferFile = new HMBInputFile(fileBuffer, length = Some(offset + len))
+      new ParquetFileReader(bufferFile, options)
+    } else {
+      val options = HadoopReadOptions.builder(conf)
+        .withRange(offset, offset + len)
+        .withAllocator(new DirectByteBufferAllocator)
+        .withCodecFactory(new ParquetDirectCodecFactory(conf, 0))
+        .build()
+      val bufferFile = new HMBInputFile(fileBuffer, length = Some(offset + len))
+      new ParquetFileReader(bufferFile, options)
+    }
+
     // The fileSchema here has already been clipped
     reader.setRequestedSchema(clippedSchema)
     reader
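The two branches above differ only in which codec factory they hand to the read options (ParquetHeapCodecFactory versus ParquetDirectCodecFactory). A more compact formulation is sketched below; it is not the merged code, it reuses the identifiers from the diff, and it assumes both factories share the compression-codec-factory type that parquet-hadoop's withCodecFactory expects:

// Sketch only: factor the codec-factory choice out of the two otherwise-identical
// builder chains; everything else (range, allocator, input file) stays the same.
private def openPageReader(directBuffering: Boolean): ParquetFileReader = {
  val codecFactory =
    if (directBuffering) new ParquetDirectCodecFactory(conf, 0)
    else new ParquetHeapCodecFactory(conf, 0)
  val options = HadoopReadOptions.builder(conf)
    .withRange(offset, offset + len)
    .withAllocator(new DirectByteBufferAllocator)
    .withCodecFactory(codecFactory)
    .build()
  val bufferFile = new HMBInputFile(fileBuffer, length = Some(offset + len))
  new ParquetFileReader(bufferFile, options)
}

The helper relies on the surrounding class's conf, offset, len, and fileBuffer fields, so it is a drop-in sketch for AsyncParquetReader rather than a standalone program.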
@@ -578,13 +589,14 @@ object AsyncParquetReader {
       readDataSchema: StructType,
       slotAcquired: Boolean,
       enableDictLateMat: Boolean,
-      asynchronous: Boolean): AsyncParquetReader = {
+      asynchronous: Boolean,
+      directBuffering: Boolean): AsyncParquetReader = {
     new AsyncParquetReader(conf,
       tgtBatchSize, fileBuffer, offset, len,
       metrics,
       dateRebaseMode, timestampRebaseMode,
       clippedSchema,
-      slotAcquired, enableDictLateMat, asynchronous)
+      slotAcquired, enableDictLateMat, asynchronous, directBuffering)
   }
 }
 
@@ -735,4 +747,5 @@ case class HybridParquetOpts(mode: String,
     batchSizeBytes: Long,
     pollInterval: Int,
     enableDictLateMat: Boolean,
-    async: Boolean)
+    async: Boolean,
+    unsafeDecompression: Boolean)
@@ -71,7 +71,8 @@ public void advance(long sizeInByte) {
   public static ByteBufferIsConsumer create(ByteBufferInputStream bis) {
     List<ByteBuffer> buffers = bis.remainingBuffers();
     if (buffers.isEmpty()) {
-      throw new IllegalArgumentException("Got empty ByteBufferInputStream");
+      System.err.println("Got empty ByteBufferInputStream");
+      return new EmptyBufferIsConsumer();
     }
     if (buffers.size() > 1) {
       System.err.printf("create a MultiByteBuffersConsumer with %d buffers\n", buffers.size());
@@ -82,5 +83,4 @@ public static ByteBufferIsConsumer create(ByteBufferInputStream bis) {
     }
     return new DirectByteBufferIsConsumer(buffers.iterator());
   }
-
 }
@@ -21,7 +21,6 @@
 import java.nio.ByteBuffer;
 import java.util.Iterator;
 
-import org.apache.parquet.bytes.ByteBufferInputStream;
 import org.apache.parquet.io.api.Binary;
 import org.apache.spark.sql.execution.vectorized.rapids.RapidsWritableColumnVector;
 import org.apache.spark.sql.execution.vectorized.rapids.UnsafeMemoryUtils;
@@ -47,6 +46,7 @@ public DirectByteBufferIsConsumer(Iterator<ByteBuffer> bufferIterator) {
   }
 
   protected void pointToNextBuffer() {
+    // Close current buffer manually before pointing to next (since it is DirectBuffer)
     super.pointToNextBuffer();
     assert current.isDirect();
     try {
@@ -109,6 +109,9 @@ public void readBinaries(int total, WritableColumnVector c, int rowId) {
     for (int i = 0; i < total; ++i) {
       if (!current.hasRemaining()) pointToNextBuffer();
 
+      if (current.remaining() < 4) {
+        throw new AssertionError("DirectBuffer remaining < 4 bytes(" + current.remaining() + ")");
+      }
       int curLength = current.getInt();
       int prevOffset = charVector.getElementsAppended();
 
@@ -235,5 +238,4 @@ public Binary getBinary(int len) {
 
     return Binary.fromConstantByteArray(target);
   }
-
 }
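The remaining() < 4 guard added to readBinaries above exists because PLAIN-encoded BYTE_ARRAY values in Parquet are laid out as a 4-byte little-endian length followed by the payload, so current.getInt() needs at least 4 readable bytes in the active buffer. A small standalone illustration of that layout (not PR code; plain JDK ByteBuffer API):

import java.nio.{ByteBuffer, ByteOrder}

object PlainBinaryLayoutDemo {
  def main(args: Array[String]): Unit = {
    // Write one PLAIN BYTE_ARRAY value: 4-byte little-endian length, then the bytes.
    val buf = ByteBuffer.allocateDirect(9).order(ByteOrder.LITTLE_ENDIAN)
    buf.putInt(5)
    buf.put("hello".getBytes("UTF-8"))
    buf.flip()

    // Read it back: length prefix first, then the payload.
    val len = buf.getInt()              // requires at least 4 remaining bytes
    val value = new Array[Byte](len)
    buf.get(value)
    println(new String(value, "UTF-8")) // prints "hello"
  }
}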
@@ -0,0 +1,97 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.datasources.parquet.rapids;

import com.google.common.collect.Iterators;
import org.apache.parquet.io.api.Binary;
import org.apache.spark.sql.execution.vectorized.rapids.WritableColumnVector;

public class EmptyBufferIsConsumer extends ByteBufferIsConsumer {

public EmptyBufferIsConsumer() {
super(Iterators.emptyIterator());
}

@Override
public void readInts(int total, WritableColumnVector c, int rowId) {
throw new AssertionError("We should NOT perform any reading from EmptyBufferIsConsumer");
}

@Override
public void readLongs(int total, WritableColumnVector c, int rowId) {
throw new AssertionError("We should NOT perform any reading from EmptyBufferIsConsumer");
}

@Override
public void readFloats(int total, WritableColumnVector c, int rowId) {
throw new AssertionError("We should NOT perform any reading from EmptyBufferIsConsumer");
}

@Override
public void readDoubles(int total, WritableColumnVector c, int rowId) {
throw new AssertionError("We should NOT perform any reading from EmptyBufferIsConsumer");
}

@Override
public void readUIntsAsLongs(int total, WritableColumnVector c, int rowId) {
throw new AssertionError("We should NOT perform any reading from EmptyBufferIsConsumer");
}

@Override
public void readIntsAsShorts(int total, WritableColumnVector c, int rowId) {
throw new AssertionError("We should NOT perform any reading from EmptyBufferIsConsumer");
}

@Override
public void readIntsAsBytes(int total, WritableColumnVector c, int rowId) {
throw new AssertionError("We should NOT perform any reading from EmptyBufferIsConsumer");
}

@Override
public void readBinaries(int total, WritableColumnVector c, int rowId) {
throw new AssertionError("We should NOT perform any reading from EmptyBufferIsConsumer");
}

@Override
public byte getByte() {
throw new AssertionError("We should NOT perform any reading from EmptyBufferIsConsumer");
}

@Override
public int getInt() {
throw new AssertionError("We should NOT perform any reading from EmptyBufferIsConsumer");
}

@Override
public long getLong() {
throw new AssertionError("We should NOT perform any reading from EmptyBufferIsConsumer");
}

@Override
public float getFloat() {
throw new AssertionError("We should NOT perform any reading from EmptyBufferIsConsumer");
}

@Override
public double getDouble() {
throw new AssertionError("We should NOT perform any reading from EmptyBufferIsConsumer");
}

@Override
public Binary getBinary(int len) {
throw new AssertionError("We should NOT perform any reading from EmptyBufferIsConsumer");
}
}
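A quick behavioral check of the new EmptyBufferIsConsumer null object (a sketch, assuming the class is visible on the test classpath under the package declared above; constructing it is expected to be safe since create() now returns one for empty streams, while any actual read should fail fast):

import org.apache.spark.sql.execution.datasources.parquet.rapids.EmptyBufferIsConsumer

object EmptyBufferIsConsumerCheck {
  def main(args: Array[String]): Unit = {
    val consumer = new EmptyBufferIsConsumer()
    try {
      consumer.getInt()
      println("unexpected: read from the empty consumer succeeded")
    } catch {
      case _: AssertionError =>
        println("read from the empty consumer was rejected, as intended")
    }
  }
}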