Skip to content

Commit

Permalink
Specialized function dispatchers for very simple functions (#944)
Browse files Browse the repository at this point in the history
Nanobind previously distinguished between a "complex" and a "simple"
function dispatcher. This PR adds variants of the simple dispatcher that
further specializes to 0 and 1-argument functions without overloads. One
common class of functions that benefits are property getters.

The speedup is pretty small (~2%), but we will take it :-).
  • Loading branch information
wjakob authored Feb 21, 2025
1 parent c4a10ea commit 7aa69a2
Showing 1 changed file with 116 additions and 2 deletions.
118 changes: 116 additions & 2 deletions src/nb_func.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@ NAMESPACE_BEGIN(detail)
// Forward/external declarations
extern Buffer buf;

static PyObject *nb_func_vectorcall_simple_0(PyObject *, PyObject *const *,
size_t, PyObject *) noexcept;
static PyObject *nb_func_vectorcall_simple_1(PyObject *, PyObject *const *,
size_t, PyObject *) noexcept;
static PyObject *nb_func_vectorcall_simple(PyObject *, PyObject *const *,
size_t, PyObject *) noexcept;
static PyObject *nb_func_vectorcall_complex(PyObject *, PyObject *const *,
Expand Down Expand Up @@ -335,8 +339,20 @@ PyObject *nb_func_new(const void *in_) noexcept {

func->max_nargs = max_nargs;
func->complex_call = complex_call;
func->vectorcall = complex_call ? nb_func_vectorcall_complex
: nb_func_vectorcall_simple;


PyObject* (*vectorcall)(PyObject *, PyObject * const*, size_t, PyObject *);
if (complex_call) {
vectorcall = nb_func_vectorcall_complex;
} else {
if (f->nargs == 0 && !prev_overloads)
vectorcall = nb_func_vectorcall_simple_0;
else if (f->nargs == 1 && !prev_overloads)
vectorcall = nb_func_vectorcall_simple_1;
else
vectorcall = nb_func_vectorcall_simple;
}
func->vectorcall = vectorcall;

#if !defined(NB_FREE_THREADED)
// Register the function
Expand Down Expand Up @@ -954,6 +970,104 @@ static PyObject *nb_func_vectorcall_simple(PyObject *self,
return result;
}

/// Simplified nb_func_vectorcall variant for non-overloaded functions with 0 args
static PyObject *nb_func_vectorcall_simple_0(PyObject *self,
PyObject *const *args_in,
size_t nargsf,
PyObject *kwargs_in) noexcept {
func_data *fr = nb_func_data(self);
const size_t nargs_in = (size_t) NB_VECTORCALL_NARGS(nargsf);

// Handler routine that will be invoked in case of an error condition
PyObject *(*error_handler)(PyObject *, PyObject *const *, size_t,
PyObject *) noexcept = nullptr;

PyObject *result = nullptr;

if (kwargs_in == nullptr && nargs_in == 0) {
try {
result = fr->impl((void *) fr->capture, (PyObject **) args_in,
nullptr, (rv_policy) (fr->flags & 0b111), nullptr);
if (result == NB_NEXT_OVERLOAD)
error_handler = nb_func_error_overload;
else if (!result)
error_handler = nb_func_error_noconvert;
} catch (builtin_exception &e) {
if (!set_builtin_exception_status(e))
error_handler = nb_func_error_overload;
} catch (python_error &e) {
e.restore();
} catch (...) {
nb_func_convert_cpp_exception();
}
} else {
error_handler = nb_func_error_overload;
}

if (NB_UNLIKELY(error_handler))
result = error_handler(self, args_in, nargs_in, kwargs_in);

return result;
}

/// Simplified nb_func_vectorcall variant for non-overloaded functions with 1 arg
static PyObject *nb_func_vectorcall_simple_1(PyObject *self,
PyObject *const *args_in,
size_t nargsf,
PyObject *kwargs_in) noexcept {
func_data *fr = nb_func_data(self);
const size_t nargs_in = (size_t) NB_VECTORCALL_NARGS(nargsf);
bool is_constructor = fr->flags & (uint32_t) func_flags::is_constructor;

// Handler routine that will be invoked in case of an error condition
PyObject *(*error_handler)(PyObject *, PyObject *const *, size_t,
PyObject *) noexcept = nullptr;

PyObject *result = nullptr;

if (kwargs_in == nullptr && nargs_in == 1 && args_in[0] != Py_None) {
PyObject *arg = args_in[0];
cleanup_list cleanup(arg);
uint8_t args_flags[1] = {
(uint8_t) (is_constructor ? (1 | (uint8_t) cast_flags::construct) : 1)
};

try {
result = fr->impl((void *) fr->capture, (PyObject **) args_in,
args_flags, (rv_policy) (fr->flags & 0b111), &cleanup);
if (result == NB_NEXT_OVERLOAD) {
error_handler = nb_func_error_overload;
} else if (!result) {
error_handler = nb_func_error_noconvert;
} else if (is_constructor) {
nb_inst *arg_nb = (nb_inst *) arg;
arg_nb->destruct = true;
arg_nb->state = nb_inst::state_ready;
if (NB_UNLIKELY(arg_nb->intrusive))
nb_type_data(Py_TYPE(arg))
->set_self_py(inst_ptr(arg_nb), arg);
}
} catch (builtin_exception &e) {
if (!set_builtin_exception_status(e))
error_handler = nb_func_error_overload;
} catch (python_error &e) {
e.restore();
} catch (...) {
nb_func_convert_cpp_exception();
}

if (NB_UNLIKELY(cleanup.used()))
cleanup.release();
} else {
error_handler = nb_func_error_overload;
}

if (NB_UNLIKELY(error_handler))
result = error_handler(self, args_in, nargs_in, kwargs_in);

return result;
}

static PyObject *nb_bound_method_vectorcall(PyObject *self,
PyObject *const *args_in,
size_t nargsf,
Expand Down

0 comments on commit 7aa69a2

Please sign in to comment.