Skip to content

Commit

Permalink
Merge pull request #357 from TileDB-Inc/teo-buffer-zero-copy
Browse files Browse the repository at this point in the history
Add zero-copy serialization APIs.
  • Loading branch information
ypatia authored Dec 12, 2024
2 parents 9ec3a94 + 0b8795f commit 9a66fdf
Show file tree
Hide file tree
Showing 9 changed files with 528 additions and 180 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/tiledb-go.yml
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ jobs:

# Tests TileDB-Go
- name: Test TileDB-Go
run: go test -v ./...
run: go test -gcflags=all=-d=checkptr=2 -v ./...

Macos_Test:
runs-on: macos-latest
Expand Down
86 changes: 85 additions & 1 deletion buffer.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@ package tiledb
import "C"
import (
"bytes"
"errors"
"fmt"
"io"
"math"
"runtime"
"unsafe"
Expand Down Expand Up @@ -94,6 +96,8 @@ func (b *Buffer) Type() (Datatype, error) {
}

// Serialize returns a copy of the bytes in the buffer.
//
// Deprecated: Use WriteTo or ReadAt instead for increased performance.
func (b *Buffer) Serialize(serializationType SerializationType) ([]byte, error) {
bs, err := b.dataCopy()
if err != nil {
Expand All @@ -103,14 +107,94 @@ func (b *Buffer) Serialize(serializationType SerializationType) ([]byte, error)
case TILEDB_CAPNP:
// The entire byte array contains Cap'nP data. Don't bother it.
case TILEDB_JSON:
// The data is a null-terminated string. Strip off the terminator.
// The data might be a null-terminated string. Strip off the terminator.
bs = bytes.TrimSuffix(bs, []byte{0})
default:
return nil, fmt.Errorf("unsupported serialization type: %v", serializationType)
}
return bs, nil
}

// ReadAt copies bytes from the buffer, starting at byte offset off, into p.
// It implements io.ReaderAt.
//
// It returns the number of bytes copied into p. A negative offset is rejected
// with an error. io.EOF is returned when off is at or past the end of the
// buffer (or the buffer holds no data), and also together with a successful
// read that reaches the last byte of the buffer.
func (b *Buffer) ReadAt(p []byte, off int64) (int, error) {
	if off < 0 {
		return 0, errors.New("offset cannot be negative")
	}

	var cbuffer unsafe.Pointer
	var csize C.uint64_t

	ret := C.tiledb_buffer_get_data(b.context.tiledbContext, b.tiledbBuffer, &cbuffer, &csize)
	if ret != C.TILEDB_OK {
		return 0, fmt.Errorf("error getting tiledb buffer data: %w", b.context.LastError())
	}
	// cbuffer is a raw pointer into the buffer's data and does not keep b
	// reachable on its own. Keep b alive until the copy below has finished so
	// a GC-triggered free cannot release the data underneath us.
	defer runtime.KeepAlive(b)

	// Compare as uint64 so large sizes are not truncated on 32-bit platforms.
	size := uint64(csize)
	offset := uint64(off)
	if cbuffer == nil || offset >= size {
		// Match ReaderAt behavior of os.File and fail with io.EOF if the offset is greater or equal to the size.
		return 0, io.EOF
	}

	// Only expose as many bytes as will actually be copied; this also keeps the
	// slice length within the range of int.
	n := len(p)
	if available := size - offset; available < uint64(n) {
		n = int(available)
	}

	// Construct a slice over the buffer's data without copying it first.
	copy(p[:n], unsafe.Slice((*byte)(unsafe.Add(cbuffer, off)), n))

	if offset+uint64(n) == size {
		// The read consumed the final byte of the buffer.
		return n, io.EOF
	}

	return n, nil
}

// WriteTo writes the contents of a Buffer to an io.Writer without copying the
// data into an intermediate slice. It implements io.WriterTo.
//
// It returns the number of bytes written. An empty buffer results in no call
// to w.Write. Any error from w is wrapped and returned together with the
// number of bytes written so far. A writer that accepts fewer bytes than it
// was given without reporting an error is treated as io.ErrShortWrite instead
// of being retried indefinitely.
func (b *Buffer) WriteTo(w io.Writer) (int64, error) {
	var cbuffer unsafe.Pointer
	var csize C.uint64_t

	ret := C.tiledb_buffer_get_data(b.context.tiledbContext, b.tiledbBuffer, &cbuffer, &csize)
	if ret != C.TILEDB_OK {
		return 0, fmt.Errorf("error getting tiledb buffer data: %w", b.context.LastError())
	}
	// cbuffer is a raw pointer into the buffer's data and does not keep b
	// reachable on its own. Keep b alive until all writes have finished.
	defer runtime.KeepAlive(b)

	if cbuffer == nil || csize == 0 {
		return 0, nil
	}

	total := uint64(csize)
	var written uint64

	// A slice cannot be longer than the maximum int, so buffers larger than
	// that (possible on 32-bit platforms) are written in several chunks.
	for written < total {
		// TODO: Use min on Go 1.21+
		chunk := total - written
		if chunk > math.MaxInt {
			chunk = math.MaxInt
		}

		// Construct a slice from the buffer's data without copying it.
		n, err := w.Write(unsafe.Slice((*byte)(unsafe.Add(cbuffer, written)), int(chunk)))
		if n > 0 {
			written += uint64(n)
		}

		// io.Writer requires an error on a short write; enforce it so that a
		// misbehaving writer cannot make this loop spin forever.
		if err == nil && n < int(chunk) {
			err = io.ErrShortWrite
		}
		if err != nil {
			return int64(written), fmt.Errorf("error writing buffer to writer: %w", err)
		}
	}

	return int64(total), nil
}

// Static assertions that Buffer implements io.WriterTo and io.ReaderAt.
var _ io.WriterTo = (*Buffer)(nil)
var _ io.ReaderAt = (*Buffer)(nil)

// SetBuffer sets the buffer to point at the given Go slice. The memory is now
// Go-managed.
func (b *Buffer) SetBuffer(buffer []byte) error {
Expand Down
33 changes: 33 additions & 0 deletions buffer_list.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ package tiledb
import "C"
import (
"fmt"
"io"
)

// BufferList is a list of TileDB Buffer objects.
Expand Down Expand Up @@ -44,6 +45,36 @@ func (b *BufferList) Context() *Context {
return b.context
}

// WriteTo writes the contents of a BufferList to an io.Writer, one buffer at a
// time and in list order. It implements io.WriterTo.
//
// It returns the total number of bytes written across all buffers. If
// retrieving or writing a buffer fails, the error is returned together with
// the number of bytes that were already written.
func (b *BufferList) WriteTo(w io.Writer) (int64, error) {
	nbuffs, err := b.NumBuffers()
	if err != nil {
		return 0, err
	}

	var written int64

	for i := uint(0); i < uint(nbuffs); i++ {
		buff, err := b.GetBuffer(i)
		if err != nil {
			// Earlier buffers may already have been written; report them.
			return written, err
		}

		n, err := buff.WriteTo(w)
		written += n

		// Release the buffer handle as soon as it has been written instead of
		// holding on to it for the rest of the loop.
		buff.Free()

		if err != nil {
			return written, err
		}
	}

	return written, nil
}

// Static assert that BufferList implements io.WriterTo.
var _ io.WriterTo = (*BufferList)(nil)

// NumBuffers returns number of buffers in the list.
func (b *BufferList) NumBuffers() (uint64, error) {
var numBuffers C.uint64_t
Expand Down Expand Up @@ -82,6 +113,8 @@ func (b *BufferList) TotalSize() (uint64, error) {
}

// Flatten copies and concatenates all buffers in the list into a new buffer.
//
// Deprecated: Use WriteTo instead for increased performance.
func (b *BufferList) Flatten() (*Buffer, error) {
buffer := Buffer{context: b.context}
freeOnGC(&buffer)
Expand Down
76 changes: 76 additions & 0 deletions buffer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package tiledb

import (
"fmt"
"io"
"runtime"
"testing"

Expand Down Expand Up @@ -95,3 +96,78 @@ func TestBufferSafety(t *testing.T) {
t.Log("post gc 2")
verify()
}

// byteCounter is a writer that discards everything it is given and only keeps
// a running total of the number of bytes it has received.
type byteCounter struct {
	BytesWritten int64
}

// Write adds len(x) to the running total and reports the whole slice as
// written. It never returns an error.
func (c *byteCounter) Write(x []byte) (int, error) {
	n := len(x)
	c.BytesWritten += int64(n)
	return n, nil
}

// TestWriteTo checks that Buffer.WriteTo reports the full buffer size and that
// the same number of bytes actually reaches the destination writer, for an
// empty buffer and for buffers of several sizes.
func TestWriteTo(t *testing.T) {
	context, err := NewContext(nil)
	require.NoError(t, err)
	buffer, err := NewBuffer(context)
	require.NoError(t, err)

	testSizes := [5]int{0, 16, 256, 65536, 268435456}
	for _, size := range testSizes {
		err := buffer.SetBuffer(make([]byte, size))
		require.NoError(t, err)

		counter := new(byteCounter)
		n, err := buffer.WriteTo(counter)
		require.NoError(t, err)
		assert.Equal(t, size, int(n))
		// The returned count alone does not prove that data was handed to the
		// writer; verify what the writer actually received as well.
		assert.Equal(t, int64(size), counter.BytesWritten)
	}
}

// TestReadAt exercises Buffer.ReadAt on an empty buffer and on buffers of
// several sizes: a read from the start, reads that end at or start past the
// end of the data (all expected to yield io.EOF), and a negative offset.
func TestReadAt(t *testing.T) {
	tdbCtx, err := NewContext(nil)
	require.NoError(t, err)
	buffer, err := NewBuffer(tdbCtx)
	require.NoError(t, err)

	// An empty buffer yields no bytes and io.EOF.
	require.NoError(t, buffer.SetBuffer([]byte{}))

	n, err := buffer.ReadAt(make([]byte, 10), 0)
	require.Equal(t, io.EOF, err)
	require.Equal(t, 0, n)

	for _, size := range []int{16, 256, 65536, 256 << 20} {
		require.NoError(t, buffer.SetBuffer(make([]byte, size)))

		dst := make([]byte, 10)
		end := int64(size)

		// A read that stops before the end of the data succeeds without EOF.
		n, err = buffer.ReadAt(dst, 0)
		require.NoError(t, err)
		require.Equal(t, 10, n)

		// Reads that reach the end of the data, or start at or beyond it,
		// report io.EOF along with however many bytes were available.
		eofCases := []struct {
			off   int64
			wantN int
		}{
			{off: end - 10, wantN: 10},
			{off: end - 5, wantN: 5},
			{off: end, wantN: 0},
			{off: end + 1, wantN: 0},
			{off: end + 100, wantN: 0},
		}
		for _, tc := range eofCases {
			n, err = buffer.ReadAt(dst, tc.off)
			require.Equal(t, io.EOF, err)
			require.Equal(t, tc.wantN, n)
		}

		// Negative offsets are rejected outright.
		_, err = buffer.ReadAt(dst, -1)
		require.EqualError(t, err, "offset cannot be negative")
	}
}
19 changes: 0 additions & 19 deletions enumeration_experimental.go
Original file line number Diff line number Diff line change
Expand Up @@ -496,25 +496,6 @@ func (ase *ArraySchemaEvolution) ApplyExtendedEnumeration(e *Enumeration) error
return nil
}

// DeserializeLoadEnumerationsRequest deserializes a LoadEnumerationsRequests. This is used by TileDB-Cloud.
//
// The serialized request is read from request and handled against array; the
// result is placed in a newly created Buffer, which is returned. An error is
// returned if the response buffer cannot be created or if the underlying
// TileDB call fails.
func DeserializeLoadEnumerationsRequest(array *Array, serializationType SerializationType, request *Buffer) (*Buffer, error) {
	response, err := NewBuffer(array.context)
	if err != nil {
		// Report the error NewBuffer actually returned rather than whatever the
		// context's last error happens to be.
		return nil, fmt.Errorf("error deserializing load enumerations request: %w", err)
	}

	ret := C.tiledb_handle_load_enumerations_request(array.context.tiledbContext, array.tiledbArray, C.tiledb_serialization_type_t(serializationType),
		request.tiledbBuffer, response.tiledbBuffer)
	if ret != C.TILEDB_OK {
		return nil, fmt.Errorf("error deserializing load enumerations request: %w", array.context.LastError())
	}

	// Keep the inputs reachable until the C call above has completed.
	runtime.KeepAlive(request)
	runtime.KeepAlive(array)

	return response, nil
}

// copyUnsafeSliceOfEnumerationValues copies the values returned by tiledb_enumeration_get_data to a slice
// in go managed memory. This is for safety because the returned data points to unsafe memory handled by core.
// The tiledb_enumeration_get_data returns the aggregated size (sth like len() * sizeOf) so this methods
Expand Down
63 changes: 0 additions & 63 deletions group.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,12 @@ package tiledb

import (
"encoding/json"
"errors"
"fmt"
"unsafe"
)

/*
#include <tiledb/tiledb_experimental.h>
#include <tiledb/tiledb_serialization.h>
#include <stdlib.h>
*/
import "C"
Expand All @@ -38,33 +34,6 @@ func NewGroup(tdbCtx *Context, uri string) (*Group, error) {
return &group, nil
}

// Deserialize deserializes the group from the given buffer.
//
// clientSide is forwarded to TileDB as an integer flag (1 when true, 0 when
// false). Note that the call modifies buffer: its contents are replaced with a
// copy of the original bytes followed by a single zero byte.
//
// An error is returned if the buffer's bytes cannot be retrieved, if the
// terminated copy cannot be set on the buffer, or if the underlying TileDB
// call fails.
func (g *Group) Deserialize(buffer *Buffer, serializationType SerializationType, clientSide bool) error {
	// The zero value already represents "not client side".
	var cClientSide C.int32_t
	if clientSide {
		cClientSide = 1
	}

	b, err := buffer.dataCopy()
	if err != nil {
		return fmt.Errorf("failed to retrieve bytes from buffer: %w", err)
	}

	// cstrings are null terminated. Go's are not, add it as a suffix
	if err := buffer.SetBuffer(append(b, 0)); err != nil {
		return fmt.Errorf("failed to add null terminator to buffer: %w", err)
	}

	ret := C.tiledb_deserialize_group(g.context.tiledbContext, buffer.tiledbBuffer, C.tiledb_serialization_type_t(serializationType), cClientSide, g.group)
	if ret != C.TILEDB_OK {
		return fmt.Errorf("Error deserializing group: %w", g.context.LastError())
	}

	return nil
}

// Create creates a new TileDB group.
func (g *Group) Create() error {
curi := C.CString(g.uri)
Expand Down Expand Up @@ -436,38 +405,6 @@ func (g *Group) Dump(recurse bool) (string, error) {
return C.GoString(cOutput), nil
}

// SerializeGroupMetadata gets and serializes the group metadata.
//
// The metadata of g is serialized by TileDB into a temporary Buffer using the
// given serialization type, and the result of that buffer's Serialize method
// is returned. An error is returned if the underlying TileDB call fails or if
// Serialize fails.
func SerializeGroupMetadata(g *Group, serializationType SerializationType) ([]byte, error) {
	buffer := Buffer{context: g.context}
	// Register the temporary buffer for release once it becomes unreachable.
	freeOnGC(&buffer)

	ret := C.tiledb_serialize_group_metadata(g.context.tiledbContext, g.group, C.tiledb_serialization_type_t(serializationType), &buffer.tiledbBuffer)
	if ret != C.TILEDB_OK {
		// Wrap with %w so callers can inspect the underlying error.
		return nil, fmt.Errorf("Error serializing group metadata: %w", g.context.LastError())
	}

	return buffer.Serialize(serializationType)
}

// DeserializeGroupMetadata deserializes group metadata from buffer into g.
//
// Note that the call modifies buffer: its contents are replaced with a copy of
// the original bytes followed by a single zero byte.
//
// An error is returned if the buffer's bytes cannot be retrieved, if the
// terminated copy cannot be set on the buffer, or if the underlying TileDB
// call fails.
func DeserializeGroupMetadata(g *Group, buffer *Buffer, serializationType SerializationType) error {
	b, err := buffer.dataCopy()
	if err != nil {
		return fmt.Errorf("failed to retrieve bytes from buffer: %w", err)
	}

	// cstrings are null terminated. Go's are not, add it as a suffix
	if err := buffer.SetBuffer(append(b, 0)); err != nil {
		return fmt.Errorf("failed to add null terminator to buffer: %w", err)
	}

	ret := C.tiledb_deserialize_group_metadata(g.context.tiledbContext, g.group, C.tiledb_serialization_type_t(serializationType), buffer.tiledbBuffer)
	if ret != C.TILEDB_OK {
		return fmt.Errorf("Error deserializing group metadata: %w", g.context.LastError())
	}

	return nil
}

// GetIsRelativeURIByName returns whether a named member of the group has a uri relative to the group
func (g *Group) GetIsRelativeURIByName(name string) (bool, error) {
cName := C.CString(name)
Expand Down
Loading

0 comments on commit 9a66fdf

Please sign in to comment.