enable DirectByteBufferAllocator for ParquetFileReader
sperlingxx authored and wjxiz1992 committed May 7, 2024
1 parent f103c4e commit e0b7bba
Showing 8 changed files with 435 additions and 47 deletions.
@@ -32,6 +32,7 @@ import com.nvidia.spark.rapids.RmmRapidsRetryIterator.{AutoCloseableAttemptSplit
import org.apache.hadoop.conf.Configuration
import org.apache.parquet.{HadoopReadOptions, VersionParser}
import org.apache.parquet.VersionParser.ParsedVersion
import org.apache.parquet.bytes.DirectByteBufferAllocator
import org.apache.parquet.column.page.PageReadStore
import org.apache.parquet.hadoop.ParquetFileReader
import org.apache.parquet.schema.MessageType
@@ -96,6 +97,7 @@ class AsyncParquetReader(
val options = HadoopReadOptions.builder(conf)
.withRange(offset, offset + len)
.withCodecFactory(new ParquetCodecFactory(conf, 0))
.withAllocator(new DirectByteBufferAllocator)
.build()
val bufferFile = new HMBInputFile(fileBuffer, length = Some(offset + len))
val reader = new ParquetFileReader(bufferFile, options)
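For context, a minimal standalone sketch of the same configuration against a plain Hadoop path: passing a DirectByteBufferAllocator through HadoopReadOptions makes ParquetFileReader hand back pages backed by off-heap (direct) ByteBuffers instead of heap arrays. The class name and file path below are illustrative only; the commit itself wraps an HMBInputFile over a host memory buffer as shown above.

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.parquet.HadoopReadOptions;
import org.apache.parquet.ParquetReadOptions;
import org.apache.parquet.bytes.DirectByteBufferAllocator;
import org.apache.parquet.column.page.PageReadStore;
import org.apache.parquet.hadoop.ParquetFileReader;
import org.apache.parquet.hadoop.util.HadoopInputFile;

public class DirectAllocatorReadSketch {
  public static void main(String[] args) throws Exception {
    Configuration conf = new Configuration();
    Path path = new Path("/tmp/example.parquet"); // hypothetical input file

    // Ask parquet-mr to allocate page buffers off-heap instead of on the Java heap.
    ParquetReadOptions options = HadoopReadOptions.builder(conf)
        .withAllocator(new DirectByteBufferAllocator())
        .build();

    try (ParquetFileReader reader =
             ParquetFileReader.open(HadoopInputFile.fromPath(path, conf), options)) {
      PageReadStore rowGroup;
      while ((rowGroup = reader.readNextRowGroup()) != null) {
        System.out.println("row group with " + rowGroup.getRowCount() + " rows");
      }
    }
  }
}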
@@ -0,0 +1,86 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.datasources.parquet.rapids;

import org.apache.parquet.bytes.ByteBufferInputStream;
import org.apache.parquet.io.api.Binary;
import org.apache.spark.sql.execution.vectorized.rapids.WritableColumnVector;

import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.Iterator;
import java.util.List;

public abstract class ByteBufferIsConsumer {

protected Iterator<ByteBuffer> iterator;
protected ByteBuffer current = null;

public ByteBufferIsConsumer(Iterator<ByteBuffer> bufferIterator) {
iterator = bufferIterator;
if (iterator.hasNext()) {
pointToNextBuffer();
}
}

protected void pointToNextBuffer() {
current = iterator.next();
current.order(ByteOrder.LITTLE_ENDIAN);
}

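// Skips sizeInByte bytes of input, moving to the next backing buffer whenever
// the current one is exhausted.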
public void advance(long sizeInByte) {
long remaining = sizeInByte;
while (remaining > 0) {
if (!current.hasRemaining()) pointToNextBuffer();

int batchSize = (remaining >= current.remaining()) ? current.remaining() : (int) remaining;
current.position(current.position() + batchSize);
remaining -= batchSize;
}
}

public abstract void readInts(int total, WritableColumnVector c, int rowId);
public abstract void readLongs(int total, WritableColumnVector c, int rowId);
public abstract void readFloats(int total, WritableColumnVector c, int rowId);
public abstract void readDoubles(int total, WritableColumnVector c, int rowId);
public abstract void readUIntsAsLongs(int total, WritableColumnVector c, int rowId);
public abstract void readIntsAsShorts(int total, WritableColumnVector c, int rowId);
public abstract void readIntsAsBytes(int total, WritableColumnVector c, int rowId);
public abstract void readBinaries(int total, WritableColumnVector c, int rowId);
public abstract byte getByte();
public abstract int getInt();
public abstract long getLong();
public abstract float getFloat();
public abstract double getDouble();
public abstract Binary getBinary(int len);

public static ByteBufferIsConsumer create(ByteBufferInputStream bis) {
List<ByteBuffer> buffers = bis.remainingBuffers();
if (buffers.isEmpty()) {
throw new IllegalArgumentException("Got empty ByteBufferInputStream");
}
if (buffers.size() > 1) {
System.err.printf("create a MultiByteBuffersConsumer with %d buffers\n", buffers.size());
}
// HeapByteBufferIsConsumer for HeapByteBuffer; DirectByteBufferIsConsumer for DirectByteBuffer
if (buffers.get(0).hasArray()) {
return new HeapByteBufferIsConsumer(buffers.iterator());
}
return new DirectByteBufferIsConsumer(buffers.iterator());
}

}
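A hypothetical, self-contained example of the factory in use (the class name and the hand-built page buffer are illustrative; in the reader the ByteBufferInputStream comes from a decompressed Parquet page):

import java.nio.ByteBuffer;
import java.nio.ByteOrder;

import org.apache.parquet.bytes.ByteBufferInputStream;
import org.apache.spark.sql.execution.datasources.parquet.rapids.ByteBufferIsConsumer;

public class ConsumerFactorySketch {
  public static void main(String[] args) {
    // A direct buffer stands in for a page allocated by DirectByteBufferAllocator;
    // hasArray() is false, so create() picks the DirectByteBufferIsConsumer.
    ByteBuffer page = ByteBuffer.allocateDirect(8).order(ByteOrder.LITTLE_ENDIAN);
    page.putLong(0, 42L);

    ByteBufferIsConsumer consumer =
        ByteBufferIsConsumer.create(ByteBufferInputStream.wrap(page));
    System.out.println(consumer.getLong()); // 42, read little-endian from off-heap memory
  }
}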
@@ -0,0 +1,239 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.datasources.parquet.rapids;

import java.lang.reflect.Field;
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.util.Iterator;

import org.apache.parquet.bytes.ByteBufferInputStream;
import org.apache.parquet.io.api.Binary;
import org.apache.spark.sql.execution.vectorized.rapids.RapidsWritableColumnVector;
import org.apache.spark.sql.execution.vectorized.rapids.UnsafeMemoryUtils;
import org.apache.spark.sql.execution.vectorized.rapids.WritableColumnVector;

public class DirectByteBufferIsConsumer extends ByteBufferIsConsumer {

static Field addrField;

static {
try {
addrField = Buffer.class.getDeclaredField("address");
addrField.setAccessible(true);
} catch (NoSuchFieldException e) {
throw new RuntimeException(e);
}
}

private long address;

public DirectByteBufferIsConsumer(Iterator<ByteBuffer> bufferIterator) {
super(bufferIterator);
}

protected void pointToNextBuffer() {
super.pointToNextBuffer();
assert current.isDirect();
try {
address = addrField.getLong(current);
} catch (IllegalAccessException e) {
throw new RuntimeException(e);
}
}

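// Copies `total` fixed-width elements into the column vector starting at rowId.
// bitWidth is log2 of the element byte width (2 = 4-byte values, 3 = 8-byte values),
// and the bytes are copied straight from off-heap memory via putFixedLengthElementsUnsafely.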
private void readFixedLengthData(int total, WritableColumnVector c, int rowId, int bitWidth) {
assert c instanceof RapidsWritableColumnVector : "Only supports RapidsWritableColumnVector";
RapidsWritableColumnVector cv = (RapidsWritableColumnVector) c;

int remaining = total << bitWidth;
int tgtOffset = rowId;

while (remaining > 0) {
if (!current.hasRemaining()) pointToNextBuffer();

int size = Math.min(remaining, current.remaining());
// TODO: handle non-4-byte-aligned buffer tails when a page spans multiple buffers
if ((size >> 2 << 2) != size) {
throw new RuntimeException("Buffer tails not aligned to 4 bytes are not supported yet when a page spans multiple buffers");
}
long srcOffset = address + current.position();
int sizeInRow = size >> bitWidth;
cv.putFixedLengthElementsUnsafely(tgtOffset, srcOffset, size, bitWidth);
current.position(current.position() + size);
tgtOffset += sizeInRow;
remaining -= size;
}
}

@Override
public void readInts(int total, WritableColumnVector c, int rowId) {
readFixedLengthData(total, c, rowId, 2);
}

@Override
public void readLongs(int total, WritableColumnVector c, int rowId) {
readFixedLengthData(total, c, rowId, 3);
}

@Override
public void readFloats(int total, WritableColumnVector c, int rowId) {
readFixedLengthData(total, c, rowId, 2);
}

@Override
public void readDoubles(int total, WritableColumnVector c, int rowId) {
readFixedLengthData(total, c, rowId, 3);
}

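// PLAIN-encoded BYTE_ARRAY values: each entry is a 4-byte little-endian length followed by
// that many bytes. The bytes are appended to the child char vector and the (offset, length)
// pair is committed for the corresponding row.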
@Override
public void readBinaries(int total, WritableColumnVector c, int rowId) {
assert c instanceof RapidsWritableColumnVector : "Only supports RapidsWritableColumnVector";
RapidsWritableColumnVector cv = (RapidsWritableColumnVector) c;
RapidsWritableColumnVector charVector = (RapidsWritableColumnVector) cv.arrayData();

for (int i = 0; i < total; ++i) {
if (!current.hasRemaining()) pointToNextBuffer();

int curLength = current.getInt();
int prevOffset = charVector.getElementsAppended();

if (curLength > 0) {
charVector.reserve(prevOffset + curLength);

int remainLen = curLength;
int charOffset = prevOffset;
while (remainLen > 0) {
if (!current.hasRemaining()) pointToNextBuffer();

int size = Math.min(remainLen, current.remaining());
int bufPos = current.position();
charVector.putFixedLengthElementsUnsafely(charOffset, this.address + bufPos, size, 0);
charOffset += size;
current.position(bufPos + size);
remainLen -= size;
}
charVector.addElementsAppended(curLength);
}

cv.commitStringAppend(rowId + i, prevOffset, curLength);
}
}

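// The *AsLongs/*AsShorts/*AsBytes variants read physical 4-byte INT32 values and widen or
// narrow them to the Spark column type (e.g. UINT32 stored as INT32 but read into a LongType column).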
@Override
public void readUIntsAsLongs(int total, WritableColumnVector c, int rowId) {
int remaining = total * 4;
int tgtOffset = rowId;

while (remaining > 0) {
if (!current.hasRemaining()) pointToNextBuffer();

int size = Math.min(remaining, current.remaining());
for (int i = 0; i < size >> 2; ++i) {
c.putLong(tgtOffset + i, Integer.toUnsignedLong(current.getInt()));
}
tgtOffset += size >> 2; // advance by the number of ints consumed (one long written per int)
remaining -= size;
}
}

@Override
public void readIntsAsShorts(int total, WritableColumnVector c, int rowId) {
int remaining = total * 4;
int tgtOffset = rowId;

while (remaining > 0) {
if (!current.hasRemaining()) pointToNextBuffer();

int size = Math.min(remaining, current.remaining());
for (int i = 0; i < size >> 2; ++i) {
c.putShort(tgtOffset + i, (short) current.getInt());
}
tgtOffset += size >> 2; // advance by the number of ints consumed (one short written per int)
remaining -= size;
}
}

@Override
public void readIntsAsBytes(int total, WritableColumnVector c, int rowId) {
int remaining = total * 4;
int tgtOffset = rowId;

while (remaining > 0) {
if (!current.hasRemaining()) pointToNextBuffer();

int size = Math.min(remaining, current.remaining());
int pos = current.position();
for (int i = 0; i < size >> 2; ++i) {
c.putByte(tgtOffset + i, current.get(pos + (i << 2)));
}
tgtOffset += size >> 2; // advance by the number of ints consumed (one byte written per int)
current.position(pos + size);
remaining -= size;
}
}

@Override
public byte getByte() {
if (!current.hasRemaining()) pointToNextBuffer();
return current.get();
}

@Override
public int getInt() {
if (!current.hasRemaining()) pointToNextBuffer();
return current.getInt();
}

@Override
public long getLong() {
if (!current.hasRemaining()) pointToNextBuffer();
return current.getLong();
}

@Override
public float getFloat() {
if (!current.hasRemaining()) pointToNextBuffer();
return current.getFloat();
}

@Override
public double getDouble() {
if (!current.hasRemaining()) pointToNextBuffer();
return current.getDouble();
}

@Override
public Binary getBinary(int len) {
byte[] target = new byte[len];
int targetOffset = 0;

do {
if (!current.hasRemaining()) pointToNextBuffer();

int batchSize = Math.min(len - targetOffset, current.remaining());
UnsafeMemoryUtils.copyMemory(null,
this.address + current.position(), target, targetOffset, batchSize);
current.position(current.position() + batchSize);
targetOffset += batchSize;
}
while (targetOffset < len);

return Binary.fromConstantByteArray(target);
}

}
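A note on the reflective address lookup in this class: the private address field on java.nio.Buffer holds the off-heap base address of a direct buffer, and deep reflection into java.nio is typically blocked on JDK 16+ unless the package is opened (for example with --add-opens=java.base/java.nio=ALL-UNNAMED, a deployment assumption rather than something this commit sets). A minimal standalone illustration of the same trick:

import java.lang.reflect.Field;
import java.nio.Buffer;
import java.nio.ByteBuffer;

public class DirectAddressSketch {
  public static void main(String[] args) throws Exception {
    ByteBuffer direct = ByteBuffer.allocateDirect(16);

    // Same lookup as DirectByteBufferIsConsumer: read Buffer.address reflectively.
    Field addrField = Buffer.class.getDeclaredField("address");
    addrField.setAccessible(true);
    long address = addrField.getLong(direct);

    System.out.println("off-heap base address: " + address);
  }
}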