From a72a5dfe525818ef6f941bf36b8b94fe5798d5b8 Mon Sep 17 00:00:00 2001 From: sperlingxx Date: Thu, 9 May 2024 18:43:50 +0800 Subject: [PATCH] add ParquetDirectCodecFactory --- .../nvidia/spark/rapids/GpuParquetScan.scala | 5 +- .../com/nvidia/spark/rapids/RapidsConf.scala | 26 ++- .../parquet/rapids/AsyncParquetReader.scala | 59 ++++--- .../parquet/rapids/ByteBufferIsConsumer.java | 4 +- .../rapids/DirectByteBufferIsConsumer.java | 6 +- .../parquet/rapids/EmptyBufferIsConsumer.java | 97 ++++++++++ .../rapids/ParquetDirectCodecFactory.java | 165 ++++++++++++++++++ ...tory.java => ParquetHeapCodecFactory.java} | 4 +- .../rapids/VectorizedPlainValuesReader.java | 2 - .../vectorized/rapids/UnsafeMemoryUtils.java | 43 ++++- 10 files changed, 365 insertions(+), 46 deletions(-) create mode 100644 sql-plugin/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/rapids/EmptyBufferIsConsumer.java create mode 100644 sql-plugin/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/rapids/ParquetDirectCodecFactory.java rename sql-plugin/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/rapids/{ParquetCodecFactory.java => ParquetHeapCodecFactory.java} (96%) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetScan.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetScan.scala index 1d2c4685c472..249fbbf338a4 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetScan.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetScan.scala @@ -1140,7 +1140,8 @@ case class GpuParquetMultiFilePartitionReaderFactory( rapidsConf.parquetScanHostBatchSizeBytes, 3, rapidsConf.parquetScanEnableDictLateMat, - rapidsConf.parquetScanHostAsync) + rapidsConf.parquetScanHostAsync, + rapidsConf.parquetScanUnsafeDecompression) // We can't use the coalescing files reader when InputFileName, InputFileBlockStart, // or InputFileBlockLength because we are combining all the files into a single buffer @@ -2623,7 +2624,7 @@ class MultiFileCloudParquetPartitionReader( hostBuffer, 0, dataSize, metrics, dateRebaseMode, timestampRebaseMode, hasInt96Timestamps, clippedSchema, readDataSchema, - slotAcquired, opts.enableDictLateMat, opts.async) + slotAcquired, opts.enableDictLateMat, opts.async, opts.unsafeDecompression) val batchIter = HostParquetIterator(asyncReader, opts, colTypes, metrics) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala index 1ab23e02239f..92e0c6e57b9f 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala @@ -1545,28 +1545,34 @@ val GPU_COREDUMP_PIPE_PATTERN = conf("spark.rapids.gpu.coreDump.pipePattern") .stringConf .createWithDefault("GPU_ONLY") - val PARQUET_HOST_SCAN_PARALLELISM = conf("spark.rapids.sql.parquet.scan.hostParallelism") + val PARQUET_CPU_SCAN_PARALLELISM = conf("spark.rapids.sql.parquet.scan.hostParallelism") .doc("The max concurrent capacity for host parquet scan(decode) tasks") .internal() .integerConf .createWithDefault(0) - val PARQUET_HOST_SCAN_BATCH_SIZE_BYTES = conf("spark.rapids.sql.parquet.scan.hostBatchSizeBytes") + val PARQUET_CPU_SCAN_BATCH_SIZE_BYTES = conf("spark.rapids.sql.parquet.scan.hostBatchSizeBytes") .doc("Similar to spark.rapids.sql.batchSizeBytes, but it is only for decode tasks run on CPUs") .internal() .integerConf .createWithDefault(1024 * 1024 * 128) - val 
PARQUET_HOST_SCAN_ASYNC = conf("spark.rapids.sql.parquet.scan.async") + val PARQUET_CPU_SCAN_ASYNC = conf("spark.rapids.sql.parquet.scan.async") .doc("Whether run host parquet decode tasks asynchronously or not") .internal() .booleanConf - .createWithDefault(false) + .createWithDefault(true) - val PARQUET_SCAN_DICT_LATE_MAT = conf("spark.rapids.sql.parquet.scan.enableDictLateMat") + val PARQUET_CPU_SCAN_DICT_LATE_MAT = conf("spark.rapids.sql.parquet.scan.enableDictLateMat") .doc("Whether pushing down binary dicts onto GPU and materializing via GPU or not") .internal() .booleanConf + .createWithDefault(true) + + val PARQUET_CPU_SCAN_UNSAFE_DECOMP = conf("spark.rapids.sql.parquet.scan.unsafeDecompress") + .doc("Whether using UnsafeDecompressor instead of the default one of parquet-hadoop or not") + .internal() + .booleanConf .createWithDefault(false) val ORC_DEBUG_DUMP_PREFIX = conf("spark.rapids.sql.orc.debug.dumpPrefix") @@ -2630,13 +2636,15 @@ class RapidsConf(conf: Map[String, String]) extends Logging { lazy val parquetScanHybridMode: String = get(PARQUET_HYBRID_SCAN_MODE) - lazy val parquetScanHostParallelism: Int = get(PARQUET_HOST_SCAN_PARALLELISM) + lazy val parquetScanHostParallelism: Int = get(PARQUET_CPU_SCAN_PARALLELISM) + + lazy val parquetScanHostBatchSizeBytes: Int = get(PARQUET_CPU_SCAN_BATCH_SIZE_BYTES) - lazy val parquetScanHostBatchSizeBytes: Int = get(PARQUET_HOST_SCAN_BATCH_SIZE_BYTES) + lazy val parquetScanHostAsync: Boolean = get(PARQUET_CPU_SCAN_ASYNC) - lazy val parquetScanHostAsync: Boolean = get(PARQUET_HOST_SCAN_ASYNC) + lazy val parquetScanEnableDictLateMat: Boolean = get(PARQUET_CPU_SCAN_DICT_LATE_MAT) - lazy val parquetScanEnableDictLateMat: Boolean = get(PARQUET_SCAN_DICT_LATE_MAT) + lazy val parquetScanUnsafeDecompression: Boolean = get(PARQUET_CPU_SCAN_UNSAFE_DECOMP) lazy val orcDebugDumpPrefix: Option[String] = get(ORC_DEBUG_DUMP_PREFIX) diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/rapids/AsyncParquetReader.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/rapids/AsyncParquetReader.scala index 76a5567af782..fe9a2333b2d0 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/rapids/AsyncParquetReader.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/rapids/AsyncParquetReader.scala @@ -60,19 +60,19 @@ case class AsyncBatchResult(data: Array[HostColumnVector], class AsyncParquetReaderError(ex: Throwable) extends RuntimeException(ex) -class AsyncParquetReader( - conf: Configuration, - tgtBatchSize: Int, - fileBuffer: HostMemoryBuffer, - offset: Long, - len: Long, - metrics: Map[String, GpuMetric], - dateRebaseMode: DateTimeRebaseMode, - timestampRebaseMode: DateTimeRebaseMode, - clippedSchema: MessageType, - slotAcquired: Boolean, - enableDictLateMat: Boolean, - asynchronous: Boolean) +class AsyncParquetReader(conf: Configuration, + tgtBatchSize: Int, + fileBuffer: HostMemoryBuffer, + offset: Long, + len: Long, + metrics: Map[String, GpuMetric], + dateRebaseMode: DateTimeRebaseMode, + timestampRebaseMode: DateTimeRebaseMode, + clippedSchema: MessageType, + slotAcquired: Boolean, + enableDictLateMat: Boolean, + asynchronous: Boolean, + directBuffering: Boolean) extends Iterator[AsyncBatchResult] with AutoCloseable with Logging { private type Element = Either[AsyncBatchResult, Throwable] @@ -94,13 +94,24 @@ class AsyncParquetReader( private var currentGroup: PageReadStore = _ private val pageReader: 
ParquetFileReader = { - val options = HadoopReadOptions.builder(conf) - .withRange(offset, offset + len) - .withCodecFactory(new ParquetCodecFactory(conf, 0)) - .withAllocator(new DirectByteBufferAllocator) - .build() - val bufferFile = new HMBInputFile(fileBuffer, length = Some(offset + len)) - val reader = new ParquetFileReader(bufferFile, options) + val reader = if (!directBuffering) { + val options = HadoopReadOptions.builder(conf) + .withRange(offset, offset + len) + .withAllocator(new DirectByteBufferAllocator) + .withCodecFactory(new ParquetHeapCodecFactory(conf, 0)) + .build() + val bufferFile = new HMBInputFile(fileBuffer, length = Some(offset + len)) + new ParquetFileReader(bufferFile, options) + } else { + val options = HadoopReadOptions.builder(conf) + .withRange(offset, offset + len) + .withAllocator(new DirectByteBufferAllocator) + .withCodecFactory(new ParquetDirectCodecFactory(conf, 0)) + .build() + val bufferFile = new HMBInputFile(fileBuffer, length = Some(offset + len)) + new ParquetFileReader(bufferFile, options) + } + // The fileSchema here has already been clipped reader.setRequestedSchema(clippedSchema) reader @@ -578,13 +589,14 @@ object AsyncParquetReader { readDataSchema: StructType, slotAcquired: Boolean, enableDictLateMat: Boolean, - asynchronous: Boolean): AsyncParquetReader = { + asynchronous: Boolean, + directBuffering: Boolean): AsyncParquetReader = { new AsyncParquetReader(conf, tgtBatchSize, fileBuffer, offset, len, metrics, dateRebaseMode, timestampRebaseMode, clippedSchema, - slotAcquired, enableDictLateMat, asynchronous) + slotAcquired, enableDictLateMat, asynchronous, directBuffering) } } @@ -735,4 +747,5 @@ case class HybridParquetOpts(mode: String, batchSizeBytes: Long, pollInterval: Int, enableDictLateMat: Boolean, - async: Boolean) + async: Boolean, + unsafeDecompression: Boolean) diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/rapids/ByteBufferIsConsumer.java b/sql-plugin/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/rapids/ByteBufferIsConsumer.java index 91f42a868572..39716cf5222e 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/rapids/ByteBufferIsConsumer.java +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/rapids/ByteBufferIsConsumer.java @@ -71,7 +71,8 @@ public void advance(long sizeInByte) { public static ByteBufferIsConsumer create(ByteBufferInputStream bis) { List buffers = bis.remainingBuffers(); if (buffers.isEmpty()) { - throw new IllegalArgumentException("Got empty ByteBufferInputStream"); + System.err.println("Got empty ByteBufferInputStream"); + return new EmptyBufferIsConsumer(); } if (buffers.size() > 1) { System.err.printf("create a MultiByteBuffersConsumer with %d buffers\n", buffers.size()); @@ -82,5 +83,4 @@ public static ByteBufferIsConsumer create(ByteBufferInputStream bis) { } return new DirectByteBufferIsConsumer(buffers.iterator()); } - } diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/rapids/DirectByteBufferIsConsumer.java b/sql-plugin/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/rapids/DirectByteBufferIsConsumer.java index d97faabe056c..88d7ce42a780 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/rapids/DirectByteBufferIsConsumer.java +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/rapids/DirectByteBufferIsConsumer.java @@ -21,7 +21,6 @@ import 
java.nio.ByteBuffer; import java.util.Iterator; -import org.apache.parquet.bytes.ByteBufferInputStream; import org.apache.parquet.io.api.Binary; import org.apache.spark.sql.execution.vectorized.rapids.RapidsWritableColumnVector; import org.apache.spark.sql.execution.vectorized.rapids.UnsafeMemoryUtils; @@ -47,6 +46,7 @@ public DirectByteBufferIsConsumer(Iterator bufferIterator) { } protected void pointToNextBuffer() { + // Close current buffer manually before pointing to next (since it is DirectBuffer) super.pointToNextBuffer(); assert current.isDirect(); try { @@ -109,6 +109,9 @@ public void readBinaries(int total, WritableColumnVector c, int rowId) { for (int i = 0; i < total; ++i) { if (!current.hasRemaining()) pointToNextBuffer(); + if (current.remaining() < 4) { + throw new AssertionError("DirectBuffer remaining < 4 bytes(" + current.remaining() + ")"); + } int curLength = current.getInt(); int prevOffset = charVector.getElementsAppended(); @@ -235,5 +238,4 @@ public Binary getBinary(int len) { return Binary.fromConstantByteArray(target); } - } diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/rapids/EmptyBufferIsConsumer.java b/sql-plugin/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/rapids/EmptyBufferIsConsumer.java new file mode 100644 index 000000000000..e3c3113921a8 --- /dev/null +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/rapids/EmptyBufferIsConsumer.java @@ -0,0 +1,97 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.datasources.parquet.rapids; + +import com.google.common.collect.Iterators; +import org.apache.parquet.io.api.Binary; +import org.apache.spark.sql.execution.vectorized.rapids.WritableColumnVector; + +public class EmptyBufferIsConsumer extends ByteBufferIsConsumer { + + public EmptyBufferIsConsumer() { + super(Iterators.emptyIterator()); + } + + @Override + public void readInts(int total, WritableColumnVector c, int rowId) { + throw new AssertionError("We should NOT perform any reading from EmptyBufferIsConsumer"); + } + + @Override + public void readLongs(int total, WritableColumnVector c, int rowId) { + throw new AssertionError("We should NOT perform any reading from EmptyBufferIsConsumer"); + } + + @Override + public void readFloats(int total, WritableColumnVector c, int rowId) { + throw new AssertionError("We should NOT perform any reading from EmptyBufferIsConsumer"); + } + + @Override + public void readDoubles(int total, WritableColumnVector c, int rowId) { + throw new AssertionError("We should NOT perform any reading from EmptyBufferIsConsumer"); + } + + @Override + public void readUIntsAsLongs(int total, WritableColumnVector c, int rowId) { + throw new AssertionError("We should NOT perform any reading from EmptyBufferIsConsumer"); + } + + @Override + public void readIntsAsShorts(int total, WritableColumnVector c, int rowId) { + throw new AssertionError("We should NOT perform any reading from EmptyBufferIsConsumer"); + } + + @Override + public void readIntsAsBytes(int total, WritableColumnVector c, int rowId) { + throw new AssertionError("We should NOT perform any reading from EmptyBufferIsConsumer"); + } + + @Override + public void readBinaries(int total, WritableColumnVector c, int rowId) { + throw new AssertionError("We should NOT perform any reading from EmptyBufferIsConsumer"); + } + + @Override + public byte getByte() { + throw new AssertionError("We should NOT perform any reading from EmptyBufferIsConsumer"); + } + + @Override + public int getInt() { + throw new AssertionError("We should NOT perform any reading from EmptyBufferIsConsumer");} + + @Override + public long getLong() { + throw new AssertionError("We should NOT perform any reading from EmptyBufferIsConsumer"); + } + + @Override + public float getFloat() { + throw new AssertionError("We should NOT perform any reading from EmptyBufferIsConsumer"); + } + + @Override + public double getDouble() { + throw new AssertionError("We should NOT perform any reading from EmptyBufferIsConsumer"); + } + + @Override + public Binary getBinary(int len) { + throw new AssertionError("We should NOT perform any reading from EmptyBufferIsConsumer"); + } +} diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/rapids/ParquetDirectCodecFactory.java b/sql-plugin/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/rapids/ParquetDirectCodecFactory.java new file mode 100644 index 000000000000..6cd4a830a712 --- /dev/null +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/rapids/ParquetDirectCodecFactory.java @@ -0,0 +1,165 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.parquet.rapids; + +import java.io.IOException; +import java.io.InputStream; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.List; + +import com.github.luben.zstd.ZstdDirectBufferDecompressingStream; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.io.compress.CodecPool; +import org.apache.hadoop.io.compress.CompressionCodec; +import org.apache.hadoop.io.compress.Decompressor; +import org.apache.parquet.bytes.ByteBufferInputStream; +import org.apache.parquet.bytes.BytesInput; +import org.apache.parquet.hadoop.CodecFactory; +import org.apache.parquet.hadoop.codec.ZstandardCodec; +import org.apache.parquet.hadoop.metadata.CompressionCodecName; + +public class ParquetDirectCodecFactory extends CodecFactory { + + public ParquetDirectCodecFactory(Configuration configuration, int pageSize) { + super(configuration, pageSize); + } + + @SuppressWarnings("deprecation") + class DirectBytesDecompressor extends BytesDecompressor { + + private final CompressionCodec codec; + private final Decompressor decompressor; + + DirectBytesDecompressor(CompressionCodecName codecName) { + this.codec = getCodec(codecName); + if (codec != null) { + decompressor = CodecPool.getDecompressor(codec); + } else { + decompressor = null; + } + } + + @Override + public BytesInput decompress(BytesInput bytes, int uncompressedSize) throws IOException { + if (codec == null) { + return bytes; + } + if (decompressor != null) { + decompressor.reset(); + } + if (!(codec instanceof ZstandardCodec)) { + InputStream is = codec.createInputStream(bytes.toInputStream(), decompressor); + return BytesInput.from(is, uncompressedSize); + } + + return decompressZstd(bytes); + } + + @Override + public void decompress( + ByteBuffer input, int compressedSize, ByteBuffer output, int uncompressedSize) + throws IOException { + ByteBuffer decompressed = + decompress(BytesInput.from(input), uncompressedSize).toByteBuffer(); + output.put(decompressed); + } + + @Override + public void release() { + if (decompressor != null) { + CodecPool.returnDecompressor(decompressor); + } + } + + /** + * Perform zstd decompression with direct memory as input and output + *

+ * 1. Exhaust the input data; + * 2. complete all the decompression work eagerly; + * 3. return uncompressed buffers; + */ + private BytesInput decompressZstd(BytesInput bytes) throws IOException { + List inputBuffers; + try (ByteBufferInputStream is = bytes.toInputStream()) { + inputBuffers = is.remainingBuffers(); + } + if (inputBuffers.isEmpty()) { + throw new IllegalArgumentException("Got empty ByteBufferInputStream"); + } + // Compute the total size of compressed data + int totalUncompressedSize = 0; + ByteBuffer tmpBlk = null; + + List decompressedBlocks = new ArrayList<>(); + ZstdDirectBufferDecompressingStream zis = null; + try { + for (ByteBuffer inputBuffer: inputBuffers) { + zis = new ZstdDirectBufferDecompressingStream(inputBuffer); + zis.setFinalize(false); + while (zis.hasRemaining()) { + if (tmpBlk == null || tmpBlk.remaining() < 64) { + if (tmpBlk != null) { + tmpBlk.flip(); + totalUncompressedSize += tmpBlk.remaining(); + decompressedBlocks.add(tmpBlk); + } + tmpBlk = ByteBuffer.allocateDirect(RECOMMENDED_BATCH_SIZE); + } + zis.read(tmpBlk); + } + zis.close(); + zis = null; + } + // append the tailing block if exists + if (tmpBlk != null) { + tmpBlk.flip(); + totalUncompressedSize += tmpBlk.remaining(); + decompressedBlocks.add(tmpBlk); + } + // merge all blocks into a continuous fused buffer + ByteBuffer concatBuffer = ByteBuffer.allocate(totalUncompressedSize); + for (ByteBuffer blk : decompressedBlocks) { + concatBuffer.put(blk); + } + concatBuffer.flip(); + + return BytesInput.from(concatBuffer); + + } catch (Throwable ex) { + if (zis != null) { + zis.close(); + } + for (ByteBuffer buf: inputBuffers) { + buf.clear(); + } + for (ByteBuffer buf: decompressedBlocks) { + buf.clear(); + } + throw new RuntimeException(ex); + } + } + } + + @Override + @SuppressWarnings("deprecation") + protected BytesDecompressor createDecompressor(CompressionCodecName codecName) { + return new ParquetDirectCodecFactory.DirectBytesDecompressor(codecName); + } + + private static final int RECOMMENDED_BATCH_SIZE = 128 * 1024; +} diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/rapids/ParquetCodecFactory.java b/sql-plugin/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/rapids/ParquetHeapCodecFactory.java similarity index 96% rename from sql-plugin/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/rapids/ParquetCodecFactory.java rename to sql-plugin/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/rapids/ParquetHeapCodecFactory.java index 9bd1d981f383..3776de226112 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/rapids/ParquetCodecFactory.java +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/rapids/ParquetHeapCodecFactory.java @@ -34,9 +34,9 @@ * workaround for memory issues encountered when reading from zstd-compressed files. 
For * details, see PARQUET-2160 */ -public class ParquetCodecFactory extends CodecFactory { +public class ParquetHeapCodecFactory extends CodecFactory { - public ParquetCodecFactory(Configuration configuration, int pageSize) { + public ParquetHeapCodecFactory(Configuration configuration, int pageSize) { super(configuration, pageSize); } diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/rapids/VectorizedPlainValuesReader.java b/sql-plugin/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/rapids/VectorizedPlainValuesReader.java index 2a69e0037925..61d87ff0696b 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/rapids/VectorizedPlainValuesReader.java +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/rapids/VectorizedPlainValuesReader.java @@ -16,8 +16,6 @@ package org.apache.spark.sql.execution.datasources.parquet.rapids; -import java.io.IOException; - import org.apache.commons.lang3.NotImplementedException; import org.apache.parquet.bytes.ByteBufferInputStream; import org.apache.parquet.column.values.ValuesReader; diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/execution/vectorized/rapids/UnsafeMemoryUtils.java b/sql-plugin/src/main/scala/org/apache/spark/sql/execution/vectorized/rapids/UnsafeMemoryUtils.java index c655af14268e..cbfcf2cc5939 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/execution/vectorized/rapids/UnsafeMemoryUtils.java +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/execution/vectorized/rapids/UnsafeMemoryUtils.java @@ -16,22 +16,36 @@ package org.apache.spark.sql.execution.vectorized.rapids; +import java.lang.reflect.Field; import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; +import java.nio.Buffer; +import java.nio.ByteBuffer; public class UnsafeMemoryUtils { - static Method copyMemoryDirectly; + static Method copyDirectMemory; + static Method freeDirectMemory; + static Field getDirectAddress; static { try { Class clz = Class.forName("ai.rapids.cudf.UnsafeMemoryAccessor"); - copyMemoryDirectly = clz.getMethod("copyMemory", + copyDirectMemory = clz.getMethod("copyMemory", Object.class, long.class, Object.class, long.class, long.class); - copyMemoryDirectly.setAccessible(true); + copyDirectMemory.setAccessible(true); + freeDirectMemory = clz.getMethod("free", long.class); + freeDirectMemory.setAccessible(true); } catch (ClassNotFoundException | NoSuchMethodException e) { throw new RuntimeException(e); } + + try { + getDirectAddress = Buffer.class.getDeclaredField("address"); + getDirectAddress.setAccessible(true); + } catch (NoSuchFieldException e) { + throw new RuntimeException(e); + } } /** @@ -41,7 +55,7 @@ public class UnsafeMemoryUtils { public static void copyMemory(Object src, long srcOffset, Object dst, long dstOffset, long length) { try { - copyMemoryDirectly.invoke(null, + copyDirectMemory.invoke(null, src, srcOffset, dst, @@ -52,4 +66,25 @@ public static void copyMemory(Object src, long srcOffset, Object dst, long dstOf } } + /** + * The reflection of `ai.rapids.cudf.UnsafeMemoryAccessor.free` + */ + public static void freeMemory(long address) { + try { + freeDirectMemory.invoke(null, address); + } catch (IllegalAccessException | InvocationTargetException e) { + throw new RuntimeException(e); + } + } + + public static void freeDirectByteBuffer(ByteBuffer bb) { + assert bb.isDirect(); + try { + long address = (long) getDirectAddress.get(bb); + freeMemory(address); + } catch 
(IllegalAccessException e) { + throw new RuntimeException(e); + } + } + }
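
The new direct-buffer decompression path stays off by default; it is gated behind the internal conf added in RapidsConf above. Below is a minimal sketch of enabling it from an application, assuming a standard SparkSession setup. The config keys and defaults are taken from this patch; the example class name, app name, and input path are placeholders.

import org.apache.spark.sql.SparkSession;

public class EnableUnsafeDecompress {
  public static void main(String[] args) {
    SparkSession spark = SparkSession.builder()
        .appName("parquet-direct-decompress-demo")
        // Internal conf introduced by this patch; default is false. The related
        // keys ...scan.async and ...scan.enableDictLateMat now default to true.
        .config("spark.rapids.sql.parquet.scan.unsafeDecompress", "true")
        .getOrCreate();

    // Scans that take the hybrid CPU Parquet decode path should now build their
    // ParquetFileReader with ParquetDirectCodecFactory instead of the heap one.
    spark.read().parquet("/path/to/zstd_compressed.parquet").show();
    spark.stop();
  }
}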
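
DirectBytesDecompressor.decompressZstd bypasses the Hadoop Decompressor/InputStream route and drives zstd-jni's ZstdDirectBufferDecompressingStream over off-heap buffers, draining each input buffer eagerly into fixed-size direct blocks and concatenating the results. The following self-contained round-trip sketch shows that pattern, assuming zstd-jni (com.github.luben) is on the classpath, as it already is for Parquet's ZstandardCodec; the payload and block size are illustrative.

import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;

import com.github.luben.zstd.Zstd;
import com.github.luben.zstd.ZstdDirectBufferDecompressingStream;

public class ZstdDirectRoundTrip {
  public static void main(String[] args) throws Exception {
    // Stand-in for a zstd-compressed Parquet page that already lives off-heap.
    byte[] raw = "some page bytes to round-trip".getBytes(StandardCharsets.UTF_8);
    byte[] compressed = Zstd.compress(raw);
    ByteBuffer src = ByteBuffer.allocateDirect(compressed.length);
    src.put(compressed);
    src.flip();

    // Drain the input eagerly into a direct output block, the same pattern
    // decompressZstd uses (the real code chains multiple blocks as needed).
    ByteBuffer block = ByteBuffer.allocateDirect(128 * 1024);
    ZstdDirectBufferDecompressingStream zis = new ZstdDirectBufferDecompressingStream(src);
    try {
      while (zis.hasRemaining()) {
        zis.read(block);  // appends decompressed bytes at block.position()
      }
    } finally {
      zis.close();
    }
    block.flip();

    byte[] out = new byte[block.remaining()];
    block.get(out);
    if (!new String(out, StandardCharsets.UTF_8).equals("some page bytes to round-trip")) {
      throw new IllegalStateException("zstd direct round trip mismatch");
    }
  }
}

The real implementation additionally tracks the total uncompressed size across blocks and merges them into one contiguous buffer before handing the BytesInput back to the page reader; the sketch keeps a single block for brevity.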
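
Since the direct factory is only exercised when the new conf is on, a quick way to sanity-check it is to compare its output against the existing heap factory for the same zstd frame, going through CodecFactory.getDecompressor and the decompress(BytesInput, int) override this patch adds. This is a hedged test sketch, not part of the patch: the class name and payload are made up, and it assumes both factories are visible on the test classpath.

import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;

import com.github.luben.zstd.Zstd;
import org.apache.hadoop.conf.Configuration;
import org.apache.parquet.bytes.BytesInput;
import org.apache.parquet.hadoop.metadata.CompressionCodecName;
import org.apache.spark.sql.execution.datasources.parquet.rapids.ParquetDirectCodecFactory;
import org.apache.spark.sql.execution.datasources.parquet.rapids.ParquetHeapCodecFactory;

public class ZstdCodecFactoryParity {
  public static void main(String[] args) throws Exception {
    byte[] raw = "payload for a zstd parity check".getBytes(StandardCharsets.UTF_8);
    byte[] compressed = Zstd.compress(raw);

    // The direct decompressor consumes its input as direct buffers, so stage the
    // compressed bytes off-heap; the heap factory is fine with an on-heap array.
    ByteBuffer directInput = ByteBuffer.allocateDirect(compressed.length);
    directInput.put(compressed);
    directInput.flip();

    Configuration conf = new Configuration();
    byte[] viaDirect = new ParquetDirectCodecFactory(conf, 0)
        .getDecompressor(CompressionCodecName.ZSTD)
        .decompress(BytesInput.from(directInput), raw.length)
        .toByteArray();
    byte[] viaHeap = new ParquetHeapCodecFactory(conf, 0)
        .getDecompressor(CompressionCodecName.ZSTD)
        .decompress(BytesInput.from(compressed), raw.length)
        .toByteArray();

    // Both factories should reproduce the original bytes.
    if (!Arrays.equals(viaDirect, viaHeap) || !Arrays.equals(viaDirect, raw)) {
      throw new IllegalStateException("direct and heap decompressors disagree");
    }
  }
}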