From e0b7bbaf749e6a496ff2f888b2707a4e7f1b8317 Mon Sep 17 00:00:00 2001 From: sperlingxx Date: Fri, 26 Apr 2024 18:38:52 +0900 Subject: [PATCH] enable DirectByteBufferAllocator for ParquetFileReader --- .../parquet/rapids/AsyncParquetReader.scala | 2 + .../parquet/rapids/ByteBufferIsConsumer.java | 86 +++++++ .../rapids/DirectByteBufferIsConsumer.java | 239 ++++++++++++++++++ ...mer.java => HeapByteBufferIsConsumer.java} | 51 ++-- .../rapids/ParquetVectorUpdaterFactory.java | 2 +- .../rapids/VectorizedPlainValuesReader.java | 6 +- .../rapids/RapidsWritableColumnVector.java | 41 ++- .../vectorized/rapids/UnsafeMemoryUtils.java | 55 ++++ 8 files changed, 435 insertions(+), 47 deletions(-) create mode 100644 sql-plugin/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/rapids/ByteBufferIsConsumer.java create mode 100644 sql-plugin/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/rapids/DirectByteBufferIsConsumer.java rename sql-plugin/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/rapids/{MultiByteBuffersConsumer.java => HeapByteBufferIsConsumer.java} (85%) create mode 100644 sql-plugin/src/main/scala/org/apache/spark/sql/execution/vectorized/rapids/UnsafeMemoryUtils.java diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/rapids/AsyncParquetReader.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/rapids/AsyncParquetReader.scala index f99b75a86ef..76a5567af78 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/rapids/AsyncParquetReader.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/rapids/AsyncParquetReader.scala @@ -32,6 +32,7 @@ import com.nvidia.spark.rapids.RmmRapidsRetryIterator.{AutoCloseableAttemptSplit import org.apache.hadoop.conf.Configuration import org.apache.parquet.{HadoopReadOptions, VersionParser} import org.apache.parquet.VersionParser.ParsedVersion +import org.apache.parquet.bytes.DirectByteBufferAllocator import org.apache.parquet.column.page.PageReadStore import org.apache.parquet.hadoop.ParquetFileReader import org.apache.parquet.schema.MessageType @@ -96,6 +97,7 @@ class AsyncParquetReader( val options = HadoopReadOptions.builder(conf) .withRange(offset, offset + len) .withCodecFactory(new ParquetCodecFactory(conf, 0)) + .withAllocator(new DirectByteBufferAllocator) .build() val bufferFile = new HMBInputFile(fileBuffer, length = Some(offset + len)) val reader = new ParquetFileReader(bufferFile, options) diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/rapids/ByteBufferIsConsumer.java b/sql-plugin/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/rapids/ByteBufferIsConsumer.java new file mode 100644 index 00000000000..91f42a86857 --- /dev/null +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/rapids/ByteBufferIsConsumer.java @@ -0,0 +1,86 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.parquet.rapids;
+
+import org.apache.parquet.bytes.ByteBufferInputStream;
+import org.apache.parquet.io.api.Binary;
+import org.apache.spark.sql.execution.vectorized.rapids.WritableColumnVector;
+
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+import java.util.Iterator;
+import java.util.List;
+
+/**
+ * Consumes the remaining ByteBuffers of a page's ByteBufferInputStream and writes the decoded
+ * values into WritableColumnVectors. Concrete subclasses handle heap-backed and direct buffers.
+ */
+public abstract class ByteBufferIsConsumer {
+
+  protected Iterator<ByteBuffer> iterator;
+  protected ByteBuffer current = null;
+
+  public ByteBufferIsConsumer(Iterator<ByteBuffer> bufferIterator) {
+    iterator = bufferIterator;
+    if (iterator.hasNext()) {
+      pointToNextBuffer();
+    }
+  }
+
+  protected void pointToNextBuffer() {
+    current = iterator.next();
+    current.order(ByteOrder.LITTLE_ENDIAN);
+  }
+
+  public void advance(long sizeInByte) {
+    long remaining = sizeInByte;
+    while (remaining > 0) {
+      if (!current.hasRemaining()) pointToNextBuffer();
+
+      int batchSize = (remaining >= current.remaining()) ? current.remaining() : (int) remaining;
+      current.position(current.position() + batchSize);
+      remaining -= batchSize;
+    }
+  }
+
+  public abstract void readInts(int total, WritableColumnVector c, int rowId);
+  public abstract void readLongs(int total, WritableColumnVector c, int rowId);
+  public abstract void readFloats(int total, WritableColumnVector c, int rowId);
+  public abstract void readDoubles(int total, WritableColumnVector c, int rowId);
+  public abstract void readUIntsAsLongs(int total, WritableColumnVector c, int rowId);
+  public abstract void readIntsAsShorts(int total, WritableColumnVector c, int rowId);
+  public abstract void readIntsAsBytes(int total, WritableColumnVector c, int rowId);
+  public abstract void readBinaries(int total, WritableColumnVector c, int rowId);
+  public abstract byte getByte();
+  public abstract int getInt();
+  public abstract long getLong();
+  public abstract float getFloat();
+  public abstract double getDouble();
+  public abstract Binary getBinary(int len);
+
+  public static ByteBufferIsConsumer create(ByteBufferInputStream bis) {
+    List<ByteBuffer> buffers = bis.remainingBuffers();
+    if (buffers.isEmpty()) {
+      throw new IllegalArgumentException("Got empty ByteBufferInputStream");
+    }
+    if (buffers.size() > 1) {
+      System.err.printf("creating a ByteBufferIsConsumer over %d buffers\n", buffers.size());
+    }
+    // HeapByteBufferIsConsumer for HeapByteBuffer; DirectByteBufferIsConsumer for DirectByteBuffer
+    if (buffers.get(0).hasArray()) {
+      return new HeapByteBufferIsConsumer(buffers.iterator());
+    }
+    return new DirectByteBufferIsConsumer(buffers.iterator());
+  }
+
+}
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/rapids/DirectByteBufferIsConsumer.java b/sql-plugin/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/rapids/DirectByteBufferIsConsumer.java
new file mode 100644
index 00000000000..d97faabe056
--- /dev/null
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/rapids/DirectByteBufferIsConsumer.java
@@ -0,0 +1,239 @@
+/*
+ * Copyright (c) 2024, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.parquet.rapids; + +import java.lang.reflect.Field; +import java.nio.Buffer; +import java.nio.ByteBuffer; +import java.util.Iterator; + +import org.apache.parquet.bytes.ByteBufferInputStream; +import org.apache.parquet.io.api.Binary; +import org.apache.spark.sql.execution.vectorized.rapids.RapidsWritableColumnVector; +import org.apache.spark.sql.execution.vectorized.rapids.UnsafeMemoryUtils; +import org.apache.spark.sql.execution.vectorized.rapids.WritableColumnVector; + +public class DirectByteBufferIsConsumer extends ByteBufferIsConsumer { + + static Field addrField; + + static { + try { + addrField = Buffer.class.getDeclaredField("address"); + addrField.setAccessible(true); + } catch (NoSuchFieldException e) { + throw new RuntimeException(e); + } + } + + private long address; + + public DirectByteBufferIsConsumer(Iterator bufferIterator) { + super(bufferIterator); + } + + protected void pointToNextBuffer() { + super.pointToNextBuffer(); + assert current.isDirect(); + try { + address = addrField.getLong(current); + } catch (IllegalAccessException e) { + throw new RuntimeException(e); + } + } + + private void readFixedLengthData(int total, WritableColumnVector c, int rowId, int bitWidth) { + assert c instanceof RapidsWritableColumnVector : "Only supports RapidsWritableColumnVector"; + RapidsWritableColumnVector cv = (RapidsWritableColumnVector) c; + + int remaining = total << bitWidth; + int tgtOffset = rowId; + + while (remaining > 0) { + if (!current.hasRemaining()) pointToNextBuffer(); + + int size = Math.min(remaining, current.remaining()); + // TODO: handle buffer tails separately, when there are multiple buffers + if (size >> 2 << 2 != size) { + throw new RuntimeException("Will support the special handling of buffer tails, when there are multiple buffers"); + } + long srcOffset = address + current.position(); + int sizeInRow = size >> bitWidth; + cv.putFixedLengthElementsUnsafely(tgtOffset, srcOffset, size, bitWidth); + current.position(current.position() + size); + tgtOffset += sizeInRow; + remaining -= size; + } + } + + @Override + public void readInts(int total, WritableColumnVector c, int rowId) { + readFixedLengthData(total, c, rowId, 2); + } + + @Override + public void readLongs(int total, WritableColumnVector c, int rowId) { + readFixedLengthData(total, c, rowId, 3); + } + + @Override + public void readFloats(int total, WritableColumnVector c, int rowId) { + readFixedLengthData(total, c, rowId, 2); + } + + @Override + public void readDoubles(int total, WritableColumnVector c, int rowId) { + readFixedLengthData(total, c, rowId, 3); + } + + @Override + public void readBinaries(int total, WritableColumnVector c, int rowId) { + assert c instanceof RapidsWritableColumnVector : "Only supports RapidsWritableColumnVector"; + RapidsWritableColumnVector cv = (RapidsWritableColumnVector) c; + RapidsWritableColumnVector charVector = (RapidsWritableColumnVector) cv.arrayData(); + + for (int i = 0; i < total; ++i) { + if (!current.hasRemaining()) pointToNextBuffer(); + + int curLength = current.getInt(); + int 
prevOffset = charVector.getElementsAppended();
+
+      if (curLength > 0) {
+        charVector.reserve(prevOffset + curLength);
+
+        int remainLen = curLength;
+        int charOffset = prevOffset;
+        while (remainLen > 0) {
+          if (!current.hasRemaining()) pointToNextBuffer();
+
+          int size = Math.min(remainLen, current.remaining());
+          int bufPos = current.position();
+          charVector.putFixedLengthElementsUnsafely(charOffset, this.address + bufPos, size, 0);
+          charOffset += size;
+          current.position(bufPos + size);
+          remainLen -= size;
+        }
+        charVector.addElementsAppended(curLength);
+      }
+
+      cv.commitStringAppend(rowId + i, prevOffset, curLength);
+    }
+  }
+
+  @Override
+  public void readUIntsAsLongs(int total, WritableColumnVector c, int rowId) {
+    int remaining = total * 4;
+    int tgtOffset = rowId;
+
+    while (remaining > 0) {
+      if (!current.hasRemaining()) pointToNextBuffer();
+
+      int size = Math.min(remaining, current.remaining());
+      for (int i = 0; i < size >> 2; ++i) {
+        c.putLong(tgtOffset + i, Integer.toUnsignedLong(current.getInt()));
+      }
+      // advance the destination row index by the number of 4-byte values consumed from this buffer
+      tgtOffset += size >> 2;
+      remaining -= size;
+    }
+  }
+
+  @Override
+  public void readIntsAsShorts(int total, WritableColumnVector c, int rowId) {
+    int remaining = total * 4;
+    int tgtOffset = rowId;
+
+    while (remaining > 0) {
+      if (!current.hasRemaining()) pointToNextBuffer();
+
+      int size = Math.min(remaining, current.remaining());
+      for (int i = 0; i < size >> 2; ++i) {
+        c.putShort(tgtOffset + i, (short) current.getInt());
+      }
+      tgtOffset += size >> 2;
+      remaining -= size;
+    }
+  }
+
+  @Override
+  public void readIntsAsBytes(int total, WritableColumnVector c, int rowId) {
+    int remaining = total * 4;
+    int tgtOffset = rowId;
+
+    while (remaining > 0) {
+      if (!current.hasRemaining()) pointToNextBuffer();
+
+      int size = Math.min(remaining, current.remaining());
+      int pos = current.position();
+      for (int i = 0; i < size >> 2; ++i) {
+        c.putByte(tgtOffset + i, current.get(pos + (i << 2)));
+      }
+      tgtOffset += size >> 2;
+      current.position(pos + size);
+      remaining -= size;
+    }
+  }
+
+  @Override
+  public byte getByte() {
+    if (!current.hasRemaining()) pointToNextBuffer();
+    return current.get();
+  }
+
+  @Override
+  public int getInt() {
+    if (!current.hasRemaining()) pointToNextBuffer();
+    return current.getInt();
+  }
+
+  @Override
+  public long getLong() {
+    if (!current.hasRemaining()) pointToNextBuffer();
+    return current.getLong();
+  }
+
+  @Override
+  public float getFloat() {
+    if (!current.hasRemaining()) pointToNextBuffer();
+    return current.getFloat();
+  }
+
+  @Override
+  public double getDouble() {
+    if (!current.hasRemaining()) pointToNextBuffer();
+    return current.getDouble();
+  }
+
+  @Override
+  public Binary getBinary(int len) {
+    byte[] target = new byte[len];
+    int targetOffset = 0;
+
+    do {
+      if (!current.hasRemaining()) pointToNextBuffer();
+
+      int batchSize = Math.min(len - targetOffset, current.remaining());
+      UnsafeMemoryUtils.copyMemory(null,
+          this.address + current.position(), target, targetOffset, batchSize);
+      current.position(current.position() + batchSize);
+      targetOffset += batchSize;
+    }
+    while (targetOffset < len);
+
+    return Binary.fromConstantByteArray(target);
+  }
+
+}
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/rapids/MultiByteBuffersConsumer.java b/sql-plugin/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/rapids/HeapByteBufferIsConsumer.java
similarity index 85%
rename from 
sql-plugin/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/rapids/MultiByteBuffersConsumer.java rename to sql-plugin/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/rapids/HeapByteBufferIsConsumer.java index 32b48606ff2..feccbc44874 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/rapids/MultiByteBuffersConsumer.java +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/rapids/HeapByteBufferIsConsumer.java @@ -22,35 +22,21 @@ import org.apache.spark.sql.execution.vectorized.rapids.WritableColumnVector; import java.nio.ByteBuffer; -import java.nio.ByteOrder; import java.util.Iterator; -import java.util.List; -public class MultiByteBuffersConsumer -{ - - Iterator iterator; - ByteBuffer current = null; +public class HeapByteBufferIsConsumer extends ByteBufferIsConsumer { // Reference of underlying data structure of HeapByteBuffer, since we assume the BIS is // backed by the HeapByteBuffer byte[] hb; int arrayOffset; - public MultiByteBuffersConsumer(ByteBufferInputStream bis) { - List buffers = bis.remainingBuffers(); - if (buffers.size() > 1) { - System.err.printf("create a MultiByteBuffersConsumer with %d buffers\n", buffers.size()); - } - iterator = buffers.iterator(); - if (iterator.hasNext()) { - pointToNextBuffer(); - } + public HeapByteBufferIsConsumer(Iterator bufferIterator) { + super(bufferIterator); } - private void pointToNextBuffer() { - current = iterator.next(); - current.order(ByteOrder.LITTLE_ENDIAN); + protected void pointToNextBuffer() { + super.pointToNextBuffer(); assert current.hasArray(); hb = current.array(); arrayOffset = current.arrayOffset(); @@ -64,6 +50,10 @@ public void readInts(int total, WritableColumnVector c, int rowId) { if (!current.hasRemaining()) pointToNextBuffer(); int size = Math.min(remaining, current.remaining()); + // TODO: handle buffer tails separately, when there are multiple buffers + if (size >> 2 << 2 != size) { + throw new RuntimeException("Will support the special handling of buffer tails, when there are multiple buffers"); + } int srcOffset = this.arrayOffset + current.position(); int sizeInRow = size >> 2; c.putIntsLittleEndian(tgtOffset, sizeInRow, hb, srcOffset); @@ -81,6 +71,10 @@ public void readLongs(int total, WritableColumnVector c, int rowId) { if (!current.hasRemaining()) pointToNextBuffer(); int size = Math.min(remaining, current.remaining()); + // TODO: handle buffer tails separately, when there are multiple buffers + if (size >> 3 << 3 != size) { + throw new RuntimeException("Will support the special handling of buffer tails, when there are multiple buffers"); + } int srcOffset = this.arrayOffset + current.position(); int sizeInRow = size >> 3; c.putLongsLittleEndian(tgtOffset, sizeInRow, hb, srcOffset); @@ -98,6 +92,10 @@ public void readFloats(int total, WritableColumnVector c, int rowId) { if (!current.hasRemaining()) pointToNextBuffer(); int size = Math.min(remaining, current.remaining()); + // TODO: handle buffer tails separately, when there are multiple buffers + if (size >> 2 << 2 != size) { + throw new RuntimeException("Will support the special handling of buffer tails, when there are multiple buffers"); + } int srcOffset = this.arrayOffset + current.position(); int sizeInRow = size >> 2; c.putFloatsLittleEndian(tgtOffset, sizeInRow, hb, srcOffset); @@ -115,6 +113,10 @@ public void readDoubles(int total, WritableColumnVector c, int rowId) { if (!current.hasRemaining()) pointToNextBuffer(); int size = 
Math.min(remaining, current.remaining()); + // TODO: handle buffer tails separately, when there are multiple buffers + if (size >> 3 << 3 != size) { + throw new RuntimeException("Will support the special handling of buffer tails, when there are multiple buffers"); + } int srcOffset = this.arrayOffset + current.position(); int sizeInRow = size >> 3; c.putDoublesLittleEndian(tgtOffset, sizeInRow, hb, srcOffset); @@ -249,15 +251,4 @@ public Binary getBinary(int len) { return Binary.fromConstantByteArray(target); } - public void advance(long sizeInByte) { - long remaining = sizeInByte; - while (remaining > 0) { - if (!current.hasRemaining()) pointToNextBuffer(); - - int batchSize = (remaining >= current.remaining()) ? current.remaining() : (int) remaining; - current.position(current.position() + batchSize); - remaining -= batchSize; - } - } - } diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/rapids/ParquetVectorUpdaterFactory.java b/sql-plugin/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/rapids/ParquetVectorUpdaterFactory.java index ea0e63ee828..09d2257bd20 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/rapids/ParquetVectorUpdaterFactory.java +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/rapids/ParquetVectorUpdaterFactory.java @@ -799,7 +799,7 @@ public void decodeSingleDictionaryId( OffHeapBinaryDictionary offHeapDict = (OffHeapBinaryDictionary) dictionary; int[] dctOff = offHeapDict.getOffsets(); int id = dictionaryIds.getDictId(offset); - cv.putBytesUnsafely(offset, offHeapDict.getData(), dctOff[id], dctOff[id + 1] - dctOff[id]); + cv.copyStringFromOther(offset, offHeapDict.getData(), dctOff[id], dctOff[id + 1] - dctOff[id]); } } diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/rapids/VectorizedPlainValuesReader.java b/sql-plugin/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/rapids/VectorizedPlainValuesReader.java index c3423471700..2a69e003792 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/rapids/VectorizedPlainValuesReader.java +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/rapids/VectorizedPlainValuesReader.java @@ -31,7 +31,7 @@ */ public class VectorizedPlainValuesReader extends ValuesReader implements VectorizedValuesReader { - private MultiByteBuffersConsumer consumer = null; + private ByteBufferIsConsumer consumer = null; // Only used for booleans. 
private int bitOffset; @@ -41,8 +41,8 @@ public VectorizedPlainValuesReader() { } @Override - public void initFromPage(int valueCount, ByteBufferInputStream in) throws IOException { - consumer = new MultiByteBuffersConsumer(in); + public void initFromPage(int valueCount, ByteBufferInputStream in) { + consumer = ByteBufferIsConsumer.create(in); } @Override diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/execution/vectorized/rapids/RapidsWritableColumnVector.java b/sql-plugin/src/main/scala/org/apache/spark/sql/execution/vectorized/rapids/RapidsWritableColumnVector.java index 5223156f54e..e586bed1b65 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/execution/vectorized/rapids/RapidsWritableColumnVector.java +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/execution/vectorized/rapids/RapidsWritableColumnVector.java @@ -505,7 +505,34 @@ public int putByteArray(int rowId, byte[] value, int offset, int length) { return result; } - public void putBytesUnsafely(int rowId, HostMemoryBuffer buffer, int offset, int length) { + // Append fixed-width elements directly from UnsafeMemory (such as DirectByteBuffer) + public void putFixedLengthElementsUnsafely(int rowId, long srcOffset, long copySize, int bitWidth) { + assert copySize >> bitWidth << bitWidth == copySize : "copySize is not aligned to bitWidth"; + + rowGroupIndex += (int) (copySize >> bitWidth); + rowId += currentRowGroupOffset; + + UnsafeMemoryUtils.copyMemory( + null, + srcOffset, + null, + data.getAddress() + ((long) rowId << bitWidth), + copySize); + } + + public void commitStringAppend(int rowId, int prevOffset, int curLength) { + rowGroupIndex++; + rowId += currentRowGroupOffset; + prevOffset += rowGroupStringOffset; + + for (int i = lastCharRowId + 1; i < rowId; ++i) { + charOffsets.setInt((i + 1) * 4L, prevOffset); + } + charOffsets.setInt((rowId + 1) * 4L, prevOffset + curLength); + lastCharRowId = rowId; + } + + public void copyStringFromOther(int rowId, HostMemoryBuffer buffer, int offset, int length) { rowGroupIndex++; rowId += currentRowGroupOffset; @@ -542,18 +569,6 @@ public void putBytesUnsafely(int rowId, HostMemoryBuffer buffer, int offset, int lastCharRowId = rowId; } - public void commitStringAppend(int rowId, int prevOffset, int curLength) { - rowGroupIndex++; - rowId += currentRowGroupOffset; - prevOffset += rowGroupStringOffset; - - for (int i = lastCharRowId + 1; i < rowId; ++i) { - charOffsets.setInt((i + 1) * 4L, prevOffset); - } - charOffsets.setInt((rowId + 1) * 4L, prevOffset + curLength); - lastCharRowId = rowId; - } - @Override public void reserve(int requiredCapacity) { super.reserve(requiredCapacity + currentRowGroupOffset); diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/execution/vectorized/rapids/UnsafeMemoryUtils.java b/sql-plugin/src/main/scala/org/apache/spark/sql/execution/vectorized/rapids/UnsafeMemoryUtils.java new file mode 100644 index 00000000000..c655af14268 --- /dev/null +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/execution/vectorized/rapids/UnsafeMemoryUtils.java @@ -0,0 +1,55 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.vectorized.rapids;
+
+import java.lang.reflect.InvocationTargetException;
+import java.lang.reflect.Method;
+
+public class UnsafeMemoryUtils {
+
+  static Method copyMemoryDirectly;
+
+  static {
+    try {
+      Class<?> clz = Class.forName("ai.rapids.cudf.UnsafeMemoryAccessor");
+      copyMemoryDirectly = clz.getMethod("copyMemory",
+          Object.class, long.class, Object.class, long.class, long.class);
+      copyMemoryDirectly.setAccessible(true);
+    } catch (ClassNotFoundException | NoSuchMethodException e) {
+      throw new RuntimeException(e);
+    }
+  }
+
+  /**
+   * Invokes `ai.rapids.cudf.UnsafeMemoryAccessor.copyMemory` via reflection:
+   * copies `length` bytes from the source address/object to the destination address/object.
+   */
+  public static void copyMemory(Object src, long srcOffset, Object dst, long dstOffset,
+      long length) {
+    try {
+      copyMemoryDirectly.invoke(null,
+          src,
+          srcOffset,
+          dst,
+          dstOffset,
+          length);
+    } catch (IllegalAccessException | InvocationTargetException e) {
+      throw new RuntimeException(e);
+    }
+  }
+
+}
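
Reviewer note (illustration only, not part of the diff above): a minimal sketch of how the new
ByteBufferIsConsumer.create factory is expected to dispatch. With the default heap allocator the page
bytes arrive as heap buffers and keep the array-based path, while the DirectByteBufferAllocator enabled
in AsyncParquetReader yields direct buffers and takes the new unsafe-copy path. The sketch assumes the
consumer classes are on the classpath and uses ByteBufferInputStream.wrap from parquet-mr; the class
name ConsumerDispatchSketch is made up for the example.

    import java.nio.ByteBuffer;
    import org.apache.parquet.bytes.ByteBufferInputStream;
    import org.apache.spark.sql.execution.datasources.parquet.rapids.ByteBufferIsConsumer;

    public class ConsumerDispatchSketch {
      public static void main(String[] args) {
        // Heap-backed page bytes (default HeapByteBufferAllocator) -> HeapByteBufferIsConsumer
        ByteBufferInputStream heap = ByteBufferInputStream.wrap(ByteBuffer.allocate(16));
        System.out.println(ByteBufferIsConsumer.create(heap).getClass().getSimpleName());

        // Direct page bytes (DirectByteBufferAllocator, as wired into AsyncParquetReader)
        // -> DirectByteBufferIsConsumer
        ByteBufferInputStream direct = ByteBufferInputStream.wrap(ByteBuffer.allocateDirect(16));
        System.out.println(ByteBufferIsConsumer.create(direct).getClass().getSimpleName());
      }
    }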
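
Reviewer note (illustration only): in readFixedLengthData and
RapidsWritableColumnVector.putFixedLengthElementsUnsafely the bitWidth argument is the log2 of the
element size in bytes (2 for INT32/FLOAT, 3 for INT64/DOUBLE), not a width in bits. A tiny sketch of
that convention follows; the helper names are hypothetical and only mirror the shifts used in the patch.

    public final class FixedWidthMathSketch {
      // number of bytes to copy for `totalValues` elements of 2^bitWidth bytes each
      static long bytesToCopy(int totalValues, int bitWidth) {
        return (long) totalValues << bitWidth;
      }
      // byte offset of row `rowId` inside the column's fixed-width data buffer
      static long destinationOffset(int rowId, int bitWidth) {
        return (long) rowId << bitWidth;
      }
      public static void main(String[] args) {
        System.out.println(bytesToCopy(10, 2));        // 40 bytes for 10 INT32 values
        System.out.println(destinationOffset(5, 3));   // row 5 of an INT64 column starts 40 bytes in
      }
    }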
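
Reviewer note (deployment assumption, not configured by this patch): DirectByteBufferIsConsumer reads
the private java.nio.Buffer.address field via reflection. On recent JDKs (16 and later) strong
encapsulation denies that access by default, so running this path would likely require opening the
package to the plugin, for example:

    --add-opens=java.base/java.nio=ALL-UNNAMED

assuming the plugin classes load in the unnamed module; adjust the target module otherwise.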