Skip to content

Commit

Permalink
Update Numpy type hints
Browse files Browse the repository at this point in the history
  • Loading branch information
InvincibleRMC committed Feb 6, 2025
1 parent 2a68bed commit b2eb337
Show file tree
Hide file tree
Showing 7 changed files with 62 additions and 59 deletions.
22 changes: 12 additions & 10 deletions include/pybind11/eigen/matrix.h
Original file line number Diff line number Diff line change
Expand Up @@ -224,21 +224,23 @@ struct EigenProps {
static constexpr bool show_f_contiguous
= !show_c_contiguous && show_order && requires_col_major;

static constexpr auto descriptor
= const_name("numpy.typing.NDArray[") + npy_format_descriptor<Scalar>::name
+ const_name("[") + const_name<fixed_rows>(const_name<(size_t) rows>(), const_name("m"))
+ const_name(", ") + const_name<fixed_cols>(const_name<(size_t) cols>(), const_name("n"))
+ const_name("]") +
static constexpr auto descriptor
= const_name("typing.Annotated[")
+ io_name("numpy.typing.ArrayLike, ", "numpy.typing.NDArray[")
+ npy_format_descriptor<Scalar>::name + io_name("", "]") + const_name(", \"[")
+ const_name<fixed_rows>(const_name<(size_t) rows>(), const_name("m")) + const_name(", ")
+ const_name<fixed_cols>(const_name<(size_t) cols>(), const_name("n"))
+ const_name("]\"")
// For a reference type (e.g. Ref<MatrixXd>) we have other constraints that might need to
// be satisfied: writeable=True (for a mutable reference), and, depending on the map's
// stride options, possibly f_contiguous or c_contiguous. We include them in the
// descriptor output to provide some hint as to why a TypeError is occurring (otherwise
// it can be confusing to see that a function accepts a
// 'numpy.typing.NDArray[float64[3,2]]' and an error message that you *gave* a
// numpy.ndarray of the right type and dimensions.
const_name<show_writeable>(", flags.writeable", "")
+ const_name<show_c_contiguous>(", flags.c_contiguous", "")
+ const_name<show_f_contiguous>(", flags.f_contiguous", "") + const_name("]");
// 'typing.Annotated[numpy.typing.NDArray[numpy.float64], "[3,2]"]' and an error message
// that you *gave* a numpy.ndarray of the right type and dimensions.
+ const_name<show_writeable>(", \"flags.writeable\"", "")
+ const_name<show_c_contiguous>(", \"flags.c_contiguous\"", "")
+ const_name<show_f_contiguous>(", \"flags.f_contiguous\"", "") + const_name("]");
};

// Casts an Eigen type to numpy array. If given a base, the numpy array references the src data,
Expand Down
13 changes: 8 additions & 5 deletions include/pybind11/eigen/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -124,13 +124,16 @@ struct eigen_tensor_helper<
template <typename Type, bool ShowDetails, bool NeedsWriteable = false>
struct get_tensor_descriptor {
static constexpr auto details
= const_name<NeedsWriteable>(", flags.writeable", "") + const_name
= const_name<NeedsWriteable>(", \"flags.writeable\"", "") + const_name
< static_cast<int>(Type::Layout)
== static_cast<int>(Eigen::RowMajor) > (", flags.c_contiguous", ", flags.f_contiguous");
== static_cast<int>(Eigen::RowMajor)
> (", \"flags.c_contiguous\"", ", \"flags.f_contiguous\"");
static constexpr auto value
= const_name("numpy.typing.NDArray[") + npy_format_descriptor<typename Type::Scalar>::name
+ const_name("[") + eigen_tensor_helper<remove_cv_t<Type>>::dimensions_descriptor
+ const_name("]") + const_name<ShowDetails>(details, const_name("")) + const_name("]");
= const_name("typing.Annotated[")
+ io_name("numpy.typing.ArrayLike, ", "numpy.typing.NDArray[")
+ npy_format_descriptor<typename Type::Scalar>::name + io_name("", "]")
+ const_name(", \"[") + eigen_tensor_helper<remove_cv_t<Type>>::dimensions_descriptor
+ const_name("]\"") + const_name<ShowDetails>(details, const_name("")) + const_name("]");
};

// When EIGEN_AVOID_STL_ARRAY is defined, Eigen::DSizes<T, 0> does not have the begin() member
Expand Down
6 changes: 2 additions & 4 deletions include/pybind11/numpy.h
Original file line number Diff line number Diff line change
Expand Up @@ -1444,9 +1444,7 @@ struct pyobject_caster<array_t<T, ExtraFlags>> {
static handle cast(const handle &src, return_value_policy /* policy */, handle /* parent */) {
return src.inc_ref();
}
PYBIND11_TYPE_CASTER(type,
io_name("numpy.typing.ArrayLike",
"numpy.typing.NDArray[" + npy_format_descriptor<T>::name + "]"));
PYBIND11_TYPE_CASTER(type, handle_type_name<type>::name);
};

template <typename T>
Expand Down Expand Up @@ -2184,7 +2182,7 @@ vectorize_helper<Func, Return, Args...> vectorize_extractor(const Func &f, Retur
template <typename T, int Flags>
struct handle_type_name<array_t<T, Flags>> {
static constexpr auto name
= const_name("numpy.typing.NDArray[") + npy_format_descriptor<T>::name + const_name("]");
= io_name("typing.Annotated[numpy.typing.ArrayLike, ", "numpy.typing.NDArray[") + npy_format_descriptor<T>::name + const_name("]");
};

PYBIND11_NAMESPACE_END(detail)
Expand Down
38 changes: 19 additions & 19 deletions tests/test_eigen_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,19 +95,19 @@ def test_mutator_descriptors():
with pytest.raises(TypeError) as excinfo:
m.fixed_mutator_r(zc)
assert (
"(arg0: numpy.typing.NDArray[numpy.float32[5, 6],"
" flags.writeable, flags.c_contiguous]) -> None" in str(excinfo.value)
"(arg0: typing.Annotated[numpy.typing.ArrayLike, numpy.float32, \"[5, 6]\","
" \"flags.writeable\", \"flags.c_contiguous\"]) -> None" in str(excinfo.value)
)
with pytest.raises(TypeError) as excinfo:
m.fixed_mutator_c(zr)
assert (
"(arg0: numpy.typing.NDArray[numpy.float32[5, 6],"
" flags.writeable, flags.f_contiguous]) -> None" in str(excinfo.value)
"(arg0: typing.Annotated[numpy.typing.ArrayLike, numpy.float32, \"[5, 6]\","
" \"flags.writeable\", \"flags.f_contiguous\"]) -> None" in str(excinfo.value)
)
with pytest.raises(TypeError) as excinfo:
m.fixed_mutator_a(np.array([[1, 2], [3, 4]], dtype="float32"))
assert (
"(arg0: numpy.typing.NDArray[numpy.float32[5, 6], flags.writeable]) -> None"
"(arg0: typing.Annotated[numpy.typing.ArrayLike, numpy.float32, \"[5, 6]\", \"flags.writeable\"]) -> None"
in str(excinfo.value)
)
zr.flags.writeable = False
Expand Down Expand Up @@ -202,7 +202,7 @@ def test_negative_stride_from_python(msg):
msg(excinfo.value)
== """
double_threer(): incompatible function arguments. The following argument types are supported:
1. (arg0: numpy.typing.NDArray[numpy.float32[1, 3], flags.writeable]) -> None
1. (arg0: typing.Annotated[numpy.typing.ArrayLike, numpy.float32, \"[1, 3]\", \"flags.writeable\"]) -> None
Invoked with: """
+ repr(np.array([5.0, 4.0, 3.0], dtype="float32"))
Expand All @@ -214,7 +214,7 @@ def test_negative_stride_from_python(msg):
msg(excinfo.value)
== """
double_threec(): incompatible function arguments. The following argument types are supported:
1. (arg0: numpy.typing.NDArray[numpy.float32[3, 1], flags.writeable]) -> None
1. (arg0: typing.Annotated[numpy.typing.ArrayLike, numpy.float32, \"[3, 1]\", \"flags.writeable\"]) -> None
Invoked with: """
+ repr(np.array([7.0, 4.0, 1.0], dtype="float32"))
Expand Down Expand Up @@ -635,37 +635,37 @@ def test_nocopy_wrapper():
with pytest.raises(TypeError) as excinfo:
m.get_elem_nocopy(int_matrix_colmajor)
assert "get_elem_nocopy(): incompatible function arguments." in str(excinfo.value)
assert ", flags.f_contiguous" in str(excinfo.value)
assert ", \"flags.f_contiguous\"" in str(excinfo.value)
assert m.get_elem_nocopy(dbl_matrix_colmajor) == 8
with pytest.raises(TypeError) as excinfo:
m.get_elem_nocopy(int_matrix_rowmajor)
assert "get_elem_nocopy(): incompatible function arguments." in str(excinfo.value)
assert ", flags.f_contiguous" in str(excinfo.value)
assert ", \"flags.f_contiguous\"" in str(excinfo.value)
with pytest.raises(TypeError) as excinfo:
m.get_elem_nocopy(dbl_matrix_rowmajor)
assert "get_elem_nocopy(): incompatible function arguments." in str(excinfo.value)
assert ", flags.f_contiguous" in str(excinfo.value)
assert ", \"flags.f_contiguous\"" in str(excinfo.value)

# For the row-major test, we take a long matrix in row-major, so only the third is allowed:
with pytest.raises(TypeError) as excinfo:
m.get_elem_rm_nocopy(int_matrix_colmajor)
assert "get_elem_rm_nocopy(): incompatible function arguments." in str(
excinfo.value
)
assert ", flags.c_contiguous" in str(excinfo.value)
assert ", \"flags.c_contiguous\"" in str(excinfo.value)
with pytest.raises(TypeError) as excinfo:
m.get_elem_rm_nocopy(dbl_matrix_colmajor)
assert "get_elem_rm_nocopy(): incompatible function arguments." in str(
excinfo.value
)
assert ", flags.c_contiguous" in str(excinfo.value)
assert ", \"flags.c_contiguous\"" in str(excinfo.value)
assert m.get_elem_rm_nocopy(int_matrix_rowmajor) == 8
with pytest.raises(TypeError) as excinfo:
m.get_elem_rm_nocopy(dbl_matrix_rowmajor)
assert "get_elem_rm_nocopy(): incompatible function arguments." in str(
excinfo.value
)
assert ", flags.c_contiguous" in str(excinfo.value)
assert ", \"flags.c_contiguous\"" in str(excinfo.value)


def test_eigen_ref_life_support():
Expand Down Expand Up @@ -701,25 +701,25 @@ def test_dense_signature(doc):
assert (
doc(m.double_col)
== """
double_col(arg0: numpy.typing.NDArray[numpy.float32[m, 1]]) -> numpy.typing.NDArray[numpy.float32[m, 1]]
double_col(arg0: typing.Annotated[numpy.typing.ArrayLike, numpy.float32, \"[m, 1]\"]) -> typing.Annotated[numpy.typing.NDArray[numpy.float32], \"[m, 1]\"]
"""
)
assert (
doc(m.double_row)
== """
double_row(arg0: numpy.typing.NDArray[numpy.float32[1, n]]) -> numpy.typing.NDArray[numpy.float32[1, n]]
double_row(arg0: typing.Annotated[numpy.typing.ArrayLike, numpy.float32, \"[1, n]\"]) -> typing.Annotated[numpy.typing.NDArray[numpy.float32], \"[1, n]\"]
"""
)
assert doc(m.double_complex) == (
"""
double_complex(arg0: numpy.typing.NDArray[numpy.complex64[m, 1]])"""
""" -> numpy.typing.NDArray[numpy.complex64[m, 1]]
double_complex(arg0: typing.Annotated[numpy.typing.ArrayLike, numpy.complex64, \"[m, 1]\"])"""
""" -> typing.Annotated[numpy.typing.NDArray[numpy.complex64], \"[m, 1]\"]
"""
)
assert doc(m.double_mat_rm) == (
"""
double_mat_rm(arg0: numpy.typing.NDArray[numpy.float32[m, n]])"""
""" -> numpy.typing.NDArray[numpy.float32[m, n]]
double_mat_rm(arg0: typing.Annotated[numpy.typing.ArrayLike, numpy.float32, \"[m, n]\"])"""
""" -> typing.Annotated[numpy.typing.NDArray[numpy.float32], \"[m, n]\"]
"""
)

Expand Down
16 changes: 8 additions & 8 deletions tests/test_eigen_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,23 +272,23 @@ def test_round_trip_references_actually_refer(m):
def test_doc_string(m, doc):
assert (
doc(m.copy_tensor)
== "copy_tensor() -> numpy.typing.NDArray[numpy.float64[?, ?, ?]]"
== "copy_tensor() -> typing.Annotated[numpy.typing.NDArray[numpy.float64], \"[?, ?, ?]\"]"
)
assert (
doc(m.copy_fixed_tensor)
== "copy_fixed_tensor() -> numpy.typing.NDArray[numpy.float64[3, 5, 2]]"
== "copy_fixed_tensor() -> typing.Annotated[numpy.typing.NDArray[numpy.float64], \"[3, 5, 2]\"]"
)
assert (
doc(m.reference_const_tensor)
== "reference_const_tensor() -> numpy.typing.NDArray[numpy.float64[?, ?, ?]]"
== "reference_const_tensor() -> typing.Annotated[numpy.typing.NDArray[numpy.float64], \"[?, ?, ?]\"]"
)

order_flag = f"flags.{m.needed_options.lower()}_contiguous"
order_flag = f"\"flags.{m.needed_options.lower()}_contiguous\""
assert doc(m.round_trip_view_tensor) == (
f"round_trip_view_tensor(arg0: numpy.typing.NDArray[numpy.float64[?, ?, ?], flags.writeable, {order_flag}])"
f" -> numpy.typing.NDArray[numpy.float64[?, ?, ?], flags.writeable, {order_flag}]"
f"round_trip_view_tensor(arg0: typing.Annotated[numpy.typing.ArrayLike, numpy.float64, \"[?, ?, ?]\", \"flags.writeable\", {order_flag}])"
f" -> typing.Annotated[numpy.typing.NDArray[numpy.float64], \"[?, ?, ?]\", \"flags.writeable\", {order_flag}]"
)
assert doc(m.round_trip_const_view_tensor) == (
f"round_trip_const_view_tensor(arg0: numpy.typing.NDArray[numpy.float64[?, ?, ?], {order_flag}])"
" -> numpy.typing.NDArray[numpy.float64[?, ?, ?]]"
f"round_trip_const_view_tensor(arg0: typing.Annotated[numpy.typing.ArrayLike, numpy.float64, \"[?, ?, ?]\", {order_flag}])"
" -> typing.Annotated[numpy.typing.NDArray[numpy.float64], \"[?, ?, ?]\"]"
)
16 changes: 8 additions & 8 deletions tests/test_numpy_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,13 +321,13 @@ def test_overload_resolution(msg):
msg(excinfo.value)
== """
overloaded(): incompatible function arguments. The following argument types are supported:
1. (arg0: numpy.typing.NDArray[numpy.float64]) -> str
2. (arg0: numpy.typing.NDArray[numpy.float32]) -> str
3. (arg0: numpy.typing.NDArray[numpy.int32]) -> str
4. (arg0: numpy.typing.NDArray[numpy.uint16]) -> str
5. (arg0: numpy.typing.NDArray[numpy.int64]) -> str
6. (arg0: numpy.typing.NDArray[numpy.complex128]) -> str
7. (arg0: numpy.typing.NDArray[numpy.complex64]) -> str
1. (arg0: typing.Annotated[numpy.typing.ArrayLike, numpy.float64]) -> str
2. (arg0: typing.Annotated[numpy.typing.ArrayLike, numpy.float32]) -> str
3. (arg0: typing.Annotated[numpy.typing.ArrayLike, numpy.int32]) -> str
4. (arg0: typing.Annotated[numpy.typing.ArrayLike, numpy.uint16]) -> str
5. (arg0: typing.Annotated[numpy.typing.ArrayLike, numpy.int64]) -> str
6. (arg0: typing.Annotated[numpy.typing.ArrayLike, numpy.complex128]) -> str
7. (arg0: typing.Annotated[numpy.typing.ArrayLike, numpy.complex64]) -> str
Invoked with: 'not an array'
"""
Expand Down Expand Up @@ -528,7 +528,7 @@ def test_index_using_ellipsis():
],
)
def test_format_descriptors_for_floating_point_types(test_func):
assert "numpy.typing.NDArray[numpy.float" in test_func.__doc__
assert "numpy.typing.ArrayLike, numpy.float" in test_func.__doc__


@pytest.mark.parametrize("forcecast", [False, True])
Expand Down
10 changes: 5 additions & 5 deletions tests/test_numpy_vectorize.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def test_docs(doc):
assert (
doc(m.vectorized_func)
== """
vectorized_func(arg0: numpy.typing.NDArray[numpy.int32], arg1: numpy.typing.NDArray[numpy.float32], arg2: numpy.typing.NDArray[numpy.float64]) -> object
vectorized_func(arg0: typing.Annotated[numpy.typing.ArrayLike, numpy.int32], arg1: typing.Annotated[numpy.typing.ArrayLike, numpy.float32], arg2: typing.Annotated[numpy.typing.ArrayLike, numpy.float64]) -> object
"""
)

Expand Down Expand Up @@ -212,12 +212,12 @@ def test_passthrough_arguments(doc):
+ ", ".join(
[
"arg0: float",
"arg1: numpy.typing.NDArray[numpy.float64]",
"arg2: numpy.typing.NDArray[numpy.float64]",
"arg3: numpy.typing.NDArray[numpy.int32]",
"arg1: typing.Annotated[numpy.typing.ArrayLike, numpy.float64]",
"arg2: typing.Annotated[numpy.typing.ArrayLike, numpy.float64]",
"arg3: typing.Annotated[numpy.typing.ArrayLike, numpy.int32]",
"arg4: int",
"arg5: m.numpy_vectorize.NonPODClass",
"arg6: numpy.typing.NDArray[numpy.float64]",
"arg6: typing.Annotated[numpy.typing.ArrayLike, numpy.float64]",
]
)
+ ") -> object"
Expand Down

0 comments on commit b2eb337

Please sign in to comment.