Skip to content

Commit

Permalink
feat(typing): allow annotate methods with pos_only when only have t…
Browse files Browse the repository at this point in the history
…he `self` argument (#5403)

* feat: allow annotate methods with `pos_only` when only have the `self` argument

* chore(typing): make arguments for auto-generated dunder methods positional-only

* docs: add more comments to improve readability

* style: fix nit suggestions

* Add test_self_only_pos_only() in tests/test_methods_and_attributes

* test: add docstring tests for generated dunder methods

* test: remove failed tests

* fix(test): run `gc.collect()` three times for refcount tests

---------

Co-authored-by: Ralf W. Grosse-Kunstleve <rgrossekunst@nvidia.com>
  • Loading branch information
XuehaiPan and rwgk authored Nov 11, 2024
1 parent 6d98d4d commit 7f94f24
Show file tree
Hide file tree
Showing 8 changed files with 155 additions and 25 deletions.
2 changes: 1 addition & 1 deletion include/pybind11/detail/init.h
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,7 @@ struct pickle_factory<Get, Set, RetState(Self), NewInstance(ArgState)> {

template <typename Class, typename... Extra>
void execute(Class &cl, const Extra &...extra) && {
cl.def("__getstate__", std::move(get));
cl.def("__getstate__", std::move(get), pos_only());

#if defined(PYBIND11_CPP14)
cl.def(
Expand Down
59 changes: 42 additions & 17 deletions include/pybind11/pybind11.h
Original file line number Diff line number Diff line change
Expand Up @@ -301,9 +301,20 @@ class cpp_function : public function {
constexpr bool has_kw_only_args = any_of<std::is_same<kw_only, Extra>...>::value,
has_pos_only_args = any_of<std::is_same<pos_only, Extra>...>::value,
has_arg_annotations = any_of<is_keyword<Extra>...>::value;
constexpr bool has_is_method = any_of<std::is_same<is_method, Extra>...>::value;
// The implicit `self` argument is not present and not counted in method definitions.
constexpr bool has_args = cast_in::args_pos >= 0;
constexpr bool is_method_with_self_arg_only = has_is_method && !has_args;
static_assert(has_arg_annotations || !has_kw_only_args,
"py::kw_only requires the use of argument annotations");
static_assert(has_arg_annotations || !has_pos_only_args,
static_assert(((/* Need `py::arg("arg_name")` annotation in function/method. */
has_arg_annotations)
|| (/* Allow methods with no arguments `def method(self, /): ...`.
* A method has at least one argument `self`. There can be no
* `py::arg` annotation. E.g. `class.def("method", py::pos_only())`.
*/
is_method_with_self_arg_only))
|| !has_pos_only_args,
"py::pos_only requires the use of argument annotations (for docstrings "
"and aligning the annotations to the argument)");

Expand Down Expand Up @@ -2022,17 +2033,20 @@ struct enum_base {
.format(std::move(type_name), enum_name(arg), int_(arg));
},
name("__repr__"),
is_method(m_base));
is_method(m_base),
pos_only());

m_base.attr("name") = property(cpp_function(&enum_name, name("name"), is_method(m_base)));
m_base.attr("name")
= property(cpp_function(&enum_name, name("name"), is_method(m_base), pos_only()));

m_base.attr("__str__") = cpp_function(
[](handle arg) -> str {
object type_name = type::handle_of(arg).attr("__name__");
return pybind11::str("{}.{}").format(std::move(type_name), enum_name(arg));
},
name("__str__"),
is_method(m_base));
is_method(m_base),
pos_only());

if (options::show_enum_members_docstring()) {
m_base.attr("__doc__") = static_property(
Expand Down Expand Up @@ -2087,7 +2101,8 @@ struct enum_base {
}, \
name(op), \
is_method(m_base), \
arg("other"))
arg("other"), \
pos_only())

#define PYBIND11_ENUM_OP_CONV(op, expr) \
m_base.attr(op) = cpp_function( \
Expand All @@ -2097,7 +2112,8 @@ struct enum_base {
}, \
name(op), \
is_method(m_base), \
arg("other"))
arg("other"), \
pos_only())

#define PYBIND11_ENUM_OP_CONV_LHS(op, expr) \
m_base.attr(op) = cpp_function( \
Expand All @@ -2107,7 +2123,8 @@ struct enum_base {
}, \
name(op), \
is_method(m_base), \
arg("other"))
arg("other"), \
pos_only())

if (is_convertible) {
PYBIND11_ENUM_OP_CONV_LHS("__eq__", !b.is_none() && a.equal(b));
Expand All @@ -2127,7 +2144,8 @@ struct enum_base {
m_base.attr("__invert__")
= cpp_function([](const object &arg) { return ~(int_(arg)); },
name("__invert__"),
is_method(m_base));
is_method(m_base),
pos_only());
}
} else {
PYBIND11_ENUM_OP_STRICT("__eq__", int_(a).equal(int_(b)), return false);
Expand All @@ -2147,11 +2165,15 @@ struct enum_base {
#undef PYBIND11_ENUM_OP_CONV
#undef PYBIND11_ENUM_OP_STRICT

m_base.attr("__getstate__") = cpp_function(
[](const object &arg) { return int_(arg); }, name("__getstate__"), is_method(m_base));
m_base.attr("__getstate__") = cpp_function([](const object &arg) { return int_(arg); },
name("__getstate__"),
is_method(m_base),
pos_only());

m_base.attr("__hash__") = cpp_function(
[](const object &arg) { return int_(arg); }, name("__hash__"), is_method(m_base));
m_base.attr("__hash__") = cpp_function([](const object &arg) { return int_(arg); },
name("__hash__"),
is_method(m_base),
pos_only());
}

PYBIND11_NOINLINE void value(char const *name_, object value, const char *doc = nullptr) {
Expand Down Expand Up @@ -2243,9 +2265,9 @@ class enum_ : public class_<Type> {
m_base.init(is_arithmetic, is_convertible);

def(init([](Scalar i) { return static_cast<Type>(i); }), arg("value"));
def_property_readonly("value", [](Type value) { return (Scalar) value; });
def("__int__", [](Type value) { return (Scalar) value; });
def("__index__", [](Type value) { return (Scalar) value; });
def_property_readonly("value", [](Type value) { return (Scalar) value; }, pos_only());
def("__int__", [](Type value) { return (Scalar) value; }, pos_only());
def("__index__", [](Type value) { return (Scalar) value; }, pos_only());
attr("__setstate__") = cpp_function(
[](detail::value_and_holder &v_h, Scalar arg) {
detail::initimpl::setstate<Base>(
Expand All @@ -2254,7 +2276,8 @@ class enum_ : public class_<Type> {
detail::is_new_style_constructor(),
pybind11::name("__setstate__"),
is_method(*this),
arg("state"));
arg("state"),
pos_only());
}

/// Export enumeration entries into the parent scope
Expand Down Expand Up @@ -2440,7 +2463,8 @@ iterator make_iterator_impl(Iterator first, Sentinel last, Extra &&...extra) {

if (!detail::get_type_info(typeid(state), false)) {
class_<state>(handle(), "iterator", pybind11::module_local())
.def("__iter__", [](state &s) -> state & { return s; })
.def(
"__iter__", [](state &s) -> state & { return s; }, pos_only())
.def(
"__next__",
[](state &s) -> ValueType {
Expand All @@ -2457,6 +2481,7 @@ iterator make_iterator_impl(Iterator first, Sentinel last, Extra &&...extra) {
// NOLINTNEXTLINE(readability-const-return-type) // PR #3263
},
std::forward<Extra>(extra)...,
pos_only(),
Policy);
}

Expand Down
3 changes: 2 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,10 +198,11 @@ def pytest_assertrepr_compare(op, left, right): # noqa: ARG001


def gc_collect():
"""Run the garbage collector twice (needed when running
"""Run the garbage collector three times (needed when running
reference counting tests with PyPy)"""
gc.collect()
gc.collect()
gc.collect()


def pytest_configure():
Expand Down
60 changes: 60 additions & 0 deletions tests/test_enum.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# ruff: noqa: SIM201 SIM300 SIM202
from __future__ import annotations

import re

import pytest

import env # noqa: F401
Expand Down Expand Up @@ -271,3 +273,61 @@ def test_docstring_signatures():
def test_str_signature():
for enum_type in [m.ScopedEnum, m.UnscopedEnum]:
assert enum_type.__str__.__doc__.startswith("__str__")


def test_generated_dunder_methods_pos_only():
for enum_type in [m.ScopedEnum, m.UnscopedEnum]:
for binary_op in [
"__eq__",
"__ne__",
"__ge__",
"__gt__",
"__lt__",
"__le__",
"__and__",
"__rand__",
# "__or__", # fail with some compilers (__doc__ = "Return self|value.")
# "__ror__", # fail with some compilers (__doc__ = "Return value|self.")
"__xor__",
"__rxor__",
"__rxor__",
]:
method = getattr(enum_type, binary_op, None)
if method is not None:
assert (
re.match(
rf"^{binary_op}\(self: [\w\.]+, other: [\w\.]+, /\)",
method.__doc__,
)
is not None
)
for unary_op in [
"__int__",
"__index__",
"__hash__",
"__str__",
"__repr__",
]:
method = getattr(enum_type, unary_op, None)
if method is not None:
assert (
re.match(
rf"^{unary_op}\(self: [\w\.]+, /\)",
method.__doc__,
)
is not None
)
assert (
re.match(
r"^__getstate__\(self: [\w\.]+, /\)",
enum_type.__getstate__.__doc__,
)
is not None
)
assert (
re.match(
r"^__setstate__\(self: [\w\.]+, state: [\w\.]+, /\)",
enum_type.__setstate__.__doc__,
)
is not None
)
2 changes: 1 addition & 1 deletion tests/test_methods_and_attributes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ TEST_SUBMODULE(methods_and_attributes, m) {
static_cast<py::str (ExampleMandA::*)(int, int)>(
&ExampleMandA::overloaded));
})
.def("__str__", &ExampleMandA::toString)
.def("__str__", &ExampleMandA::toString, py::pos_only())
.def_readwrite("value", &ExampleMandA::value);

// test_copy_method
Expand Down
7 changes: 7 additions & 0 deletions tests/test_methods_and_attributes.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,13 @@
)


def test_self_only_pos_only():
assert (
m.ExampleMandA.__str__.__doc__
== "__str__(self: pybind11_tests.methods_and_attributes.ExampleMandA, /) -> str\n"
)


def test_methods_and_attributes():
instance1 = m.ExampleMandA()
instance2 = m.ExampleMandA(32)
Expand Down
17 changes: 17 additions & 0 deletions tests/test_pickling.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,3 +93,20 @@ def test_roundtrip_simple_cpp_derived():
# Issue #3062: pickleable base C++ classes can incur object slicing
# if derived typeid is not registered with pybind11
assert not m.check_dynamic_cast_SimpleCppDerived(p2)


def test_new_style_pickle_getstate_pos_only():
assert (
re.match(
r"^__getstate__\(self: [\w\.]+, /\)", m.PickleableNew.__getstate__.__doc__
)
is not None
)
if hasattr(m, "PickleableWithDictNew"):
assert (
re.match(
r"^__getstate__\(self: [\w\.]+, /\)",
m.PickleableWithDictNew.__getstate__.__doc__,
)
is not None
)
30 changes: 25 additions & 5 deletions tests/test_sequences_and_iterators.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

import re

import pytest
from pytest import approx # noqa: PT013

Expand Down Expand Up @@ -253,16 +255,12 @@ def bad_next_call():

def test_iterator_passthrough():
"""#181: iterator passthrough did not compile"""
from pybind11_tests.sequences_and_iterators import iterator_passthrough

values = [3, 5, 7, 9, 11, 13, 15]
assert list(iterator_passthrough(iter(values))) == values
assert list(m.iterator_passthrough(iter(values))) == values


def test_iterator_rvp():
"""#388: Can't make iterators via make_iterator() with different r/v policies"""
import pybind11_tests.sequences_and_iterators as m

assert list(m.make_iterator_1()) == [1, 2, 3]
assert list(m.make_iterator_2()) == [1, 2, 3]
assert not isinstance(m.make_iterator_1(), type(m.make_iterator_2()))
Expand All @@ -274,3 +272,25 @@ def test_carray_iterator():
arr_h = m.CArrayHolder(*args_gt)
args = list(arr_h)
assert args_gt == args


def test_generated_dunder_methods_pos_only():
string_map = m.StringMap({"hi": "bye", "black": "white"})
for it in (
m.make_iterator_1(),
m.make_iterator_2(),
m.iterator_passthrough(iter([3, 5, 7])),
iter(m.Sequence(5)),
iter(string_map),
string_map.items(),
string_map.values(),
iter(m.CArrayHolder(*[float(i) for i in range(3)])),
):
assert (
re.match(r"^__iter__\(self: [\w\.]+, /\)", type(it).__iter__.__doc__)
is not None
)
assert (
re.match(r"^__next__\(self: [\w\.]+, /\)", type(it).__next__.__doc__)
is not None
)

0 comments on commit 7f94f24

Please sign in to comment.