Skip to content

Commit

Permalink
Fix buffer protocol implementation
Browse files Browse the repository at this point in the history
According to the buffer protocol, `ndim` is a _required_ field [1], and
should always be set correctly. Additionally, `shape` should be set if
flags includes `PyBUF_ND` or higher [2]. The current implementation only
set those fields if flags was `PyBUF_STRIDES`.

[1] https://docs.python.org/3/c-api/buffer.html#request-independent-fields
[2] https://docs.python.org/3/c-api/buffer.html#shape-strides-suboffsets
  • Loading branch information
QuLogic committed Oct 11, 2024
1 parent af67e87 commit 49a6f22
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 3 deletions.
7 changes: 4 additions & 3 deletions include/pybind11/detail/class.h
Original file line number Diff line number Diff line change
Expand Up @@ -602,9 +602,9 @@ extern "C" inline int pybind11_getbuffer(PyObject *obj, Py_buffer *view, int fla
return -1;
}
view->obj = obj;
view->ndim = 1;
view->internal = info;
view->buf = info->ptr;
view->ndim = (int) info->ndim;
view->itemsize = info->itemsize;
view->len = view->itemsize;
for (auto s : info->shape) {
Expand All @@ -614,10 +614,11 @@ extern "C" inline int pybind11_getbuffer(PyObject *obj, Py_buffer *view, int fla
if ((flags & PyBUF_FORMAT) == PyBUF_FORMAT) {
view->format = const_cast<char *>(info->format.c_str());
}
if ((flags & PyBUF_ND) == PyBUF_ND) {
view->shape = info->shape.data();
}
if ((flags & PyBUF_STRIDES) == PyBUF_STRIDES) {
view->ndim = (int) info->ndim;
view->strides = info->strides.data();
view->shape = info->shape.data();
}
Py_INCREF(view->obj);
return 0;
Expand Down
54 changes: 54 additions & 0 deletions tests/test_buffers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -268,4 +268,58 @@ TEST_SUBMODULE(buffers, m) {
});

m.def("get_buffer_info", [](const py::buffer &buffer) { return buffer.request(); });

// Expose Py_buffer for testing.
py::class_<Py_buffer>(m, "Py_buffer")
.def_readonly("len", &Py_buffer::len)
.def_readonly("readonly", &Py_buffer::readonly)
.def_readonly("itemsize", &Py_buffer::itemsize)
.def_readonly("format", &Py_buffer::format)
.def_readonly("ndim", &Py_buffer::ndim)
.def_property_readonly("shape",
[](const Py_buffer &buffer) -> py::object {
if (buffer.shape == nullptr) {
return py::none();
}
py::list l;
for (auto i = 0; i < buffer.ndim; i++) {
l.append(buffer.shape[i]);
}
return l;
})
.def_property_readonly("strides",
[](const Py_buffer &buffer) -> py::object {
if (buffer.strides == nullptr) {
return py::none();
}
py::list l;
for (auto i = 0; i < buffer.ndim; i++) {
l.append(buffer.strides[i]);
}
return l;
})
.def_property_readonly("suboffsets", [](const Py_buffer &buffer) -> py::object {
if (buffer.suboffsets == nullptr) {
return py::none();
}
py::list l;
for (auto i = 0; i < buffer.ndim; i++) {
l.append(buffer.suboffsets[i]);
}
return l;
});
m.attr("PyBUF_SIMPLE") = PyBUF_SIMPLE;
m.attr("PyBUF_ND") = PyBUF_ND;
m.attr("PyBUF_STRIDES") = PyBUF_STRIDES;
m.attr("PyBUF_INDIRECT") = PyBUF_INDIRECT;

m.def("get_py_buffer", [](const py::object &object, int flags) {
Py_buffer buffer;
memset(&buffer, 0, sizeof(Py_buffer));
if (PyObject_GetBuffer(object.ptr(), &buffer, flags) == -1) {
throw py::error_already_set();
}
// TODO: This leaks...
return buffer;
});
}
37 changes: 37 additions & 0 deletions tests/test_buffers.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,3 +239,40 @@ def test_buffer_exception():
memoryview(m.BrokenMatrix(1, 1))
assert isinstance(excinfo.value.__cause__, RuntimeError)
assert "for context" in str(excinfo.value.__cause__)


def test_to_pybuffer():
mat = m.Matrix(5, 4)

info = m.get_py_buffer(mat, m.PyBUF_SIMPLE)
assert info.itemsize == ctypes.sizeof(ctypes.c_float)
assert info.len == mat.rows() * mat.cols() * info.itemsize
assert info.ndim == 2
assert info.shape is None
assert info.strides is None
assert info.suboffsets is None
assert not info.readonly
info = m.get_py_buffer(mat, m.PyBUF_ND)
assert info.itemsize == ctypes.sizeof(ctypes.c_float)
assert info.len == mat.rows() * mat.cols() * info.itemsize
assert info.ndim == 2
assert info.shape == [5, 4]
assert info.strides is None
assert info.suboffsets is None
assert not info.readonly
info = m.get_py_buffer(mat, m.PyBUF_STRIDES)
assert info.itemsize == ctypes.sizeof(ctypes.c_float)
assert info.len == mat.rows() * mat.cols() * info.itemsize
assert info.ndim == 2
assert info.shape == [5, 4]
assert info.strides == [4 * info.itemsize, info.itemsize]
assert info.suboffsets is None
assert not info.readonly
info = m.get_py_buffer(mat, m.PyBUF_INDIRECT)
assert info.itemsize == ctypes.sizeof(ctypes.c_float)
assert info.len == mat.rows() * mat.cols() * info.itemsize
assert info.ndim == 2
assert info.shape == [5, 4]
assert info.strides == [4 * info.itemsize, info.itemsize]
assert info.suboffsets is None # Should be filled in here, but we don't use it.
assert not info.readonly

0 comments on commit 49a6f22

Please sign in to comment.