Skip to content

Commit

Permalink
Major refactor to pybind
Browse files Browse the repository at this point in the history
  • Loading branch information
MAminSFV committed Sep 12, 2024
1 parent 6311fe2 commit 11c6835
Show file tree
Hide file tree
Showing 7 changed files with 129 additions and 80 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ jobs:
run: |
python3 -m venv .venv
source .venv/bin/activate
python3 -m pip install nanobind
python3 -m pip install pybind11
python3 -m pip install pytest
python3 -m pip install --verbose .
Expand Down
51 changes: 21 additions & 30 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,15 @@
# Inspired and adapted from https://github.com/wjakob/nanobind_example/blob/master/CMakeLists.txt
cmake_minimum_required(VERSION 3.16...3.26)
project(drake_extension CXX)
# Inspired and adapted from:
# https://github.com/wjakob/nanobind_example/blob/master/CMakeLists.txt
# and
# https://github.com/pybind/scikit_build_example/blob/master/CMakeLists.txt
cmake_minimum_required(VERSION 3.16...3.27)

# Scikit-build-core sets these values for you, or you can just hard-code the
# name and version.
project(
${SKBUILD_PROJECT_NAME}
VERSION ${SKBUILD_PROJECT_VERSION}
LANGUAGES CXX)


if (NOT SKBUILD)
Expand Down Expand Up @@ -33,41 +42,23 @@ find_package(drake CONFIG REQUIRED)

# Try to import all Python components potentially needed by nanobind
find_package(Python 3.12
REQUIRED COMPONENTS Interpreter Development.Module
OPTIONAL_COMPONENTS Development.SABIModule)

# Import nanobind through CMake's find_package mechanism
find_package(nanobind CONFIG REQUIRED)
REQUIRED COMPONENTS Interpreter Development.Module)
find_package(pybind11 CONFIG REQUIRED)

set(CMAKE_CXX_EXTENSIONS OFF)
set(CMAKE_CXX_STANDARD 20)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_POSITION_INDEPENDENT_CODE ON)

add_library(custom_system STATIC src/custom_system.cpp)
target_link_libraries(custom_system PUBLIC drake::drake)
set_target_properties(custom_system PROPERTIES CXX_VISIBILITY_PRESET default)

nanobind_add_module(
# Name of the extension
drake_extension_ext
# Target the stable ABI for Python 3.12+, which reduces
# the number of binary wheels that must be built. This
# does nothing on older Python versions
STABLE_ABI
# Build libnanobind statically and merge it into the
# extension (which itself remains a shared library)
#
# If your project builds multiple extensions, you can
# replace this flag by NB_SHARED to conserve space by
# reusing a shared libnanobind across libraries
NB_STATIC

# Source code goes here
src/drake_extension_ext.cc)

target_link_libraries(drake_extension_ext PUBLIC custom_system)
# Add a library using FindPython's tooling (pybind11 also provides a helper like
# this)
python_add_library(drake_extension_ext MODULE src/drake_extension_ext.cpp WITH_SOABI)
target_link_libraries(drake_extension_ext PUBLIC pybind11::headers drake::drake)
set_target_properties(drake_extension_ext PROPERTIES CXX_VISIBILITY_PRESET default)

# This is passing in the version as a define just as an example
target_compile_definitions(_core PRIVATE VERSION_INFO=${PROJECT_VERSION})

# Install directive for scikit-build-core
install(TARGETS drake_extension_ext LIBRARY DESTINATION drake_extension)
73 changes: 60 additions & 13 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
[build-system]
requires = ["scikit-build-core >=0.4.3", "nanobind >=1.3.2"]
requires = ["scikit-build-core>=0.3.3", "pybind11"]
build-backend = "scikit_build_core.build"

[project]
name = "drake_extension"
version = "0.0.1"
description = "An example minimal project that compiles bindings using nanobind and scikit-build"
description = "An example minimal project that compiles bindings using Pybind and scikit-build"
readme = "README.md"
requires-python = ">=3.12"
authors = [
Expand All @@ -18,24 +18,71 @@ classifiers = [
# "drake",
#]

[project.optional-dependencies]
test = ["pytest"]

[project.urls]
Homepage = "https://github.com/MAminSFV/drake_extension"


[tool.scikit-build]
# Protect the configuration against future changes in scikit-build-core
minimum-version = "0.4"
wheel.expand-macos-universal-tags = true

# Setuptools-style build caching in a local directory
build-dir = "build/{wheel_tag}"

# Build stable ABI wheels for CPython 3.12+
wheel.py-api = "cp312"
[tool.pytest.ini_options]
minversion = "6.0"
addopts = ["-ra", "--showlocals", "--strict-markers", "--strict-config"]
xfail_strict = true
log_cli_level = "INFO"
filterwarnings = [
"error",
"ignore::pytest.PytestCacheWarning",
]
testpaths = ["tests"]

[tool.cibuildwheel]
# Necessary to see build output from the actual compilation
build-verbosity = 1

# Run pytest to ensure that the package was correctly built
[tool.cibuildwheel]
build-frontend = "build[uv]"
test-command = "pytest {project}/tests"
test-requires = "pytest"
test-extras = ["test"]

[tool.cibuildwheel.pyodide]
environment.CFLAGS = "-fexceptions"
environment.LDFLAGS = "-fexceptions"
build-frontend = {name = "build", args = ["--exports", "whole_archive"]}

[tool.ruff]
src = ["src"]

[tool.ruff.lint]
extend-select = [
"B", # flake8-bugbear
"I", # isort
"ARG", # flake8-unused-arguments
"C4", # flake8-comprehensions
"EM", # flake8-errmsg
"ICN", # flake8-import-conventions
"G", # flake8-logging-format
"PGH", # pygrep-hooks
"PIE", # flake8-pie
"PL", # pylint
"PT", # flake8-pytest-style
"PTH", # flake8-use-pathlib
"RET", # flake8-return
"RUF", # Ruff-specific
"SIM", # flake8-simplify
"T20", # flake8-print
"UP", # pyupgrade
"YTT", # flake8-2020
"EXE", # flake8-executable
"NPY", # NumPy specific rules
"PD", # pandas-vet
]
ignore = [
"PLR09", # Too many X
"PLR2004", # Magic comparison
]
isort.required-imports = ["from __future__ import annotations"]

[tool.ruff.lint.per-file-ignores]
"tests/**" = ["T20"]
27 changes: 0 additions & 27 deletions src/custom_system.cpp

This file was deleted.

2 changes: 1 addition & 1 deletion src/drake_extension/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
__all__ = ["SimpleAdder"]

from .custom_system import SimpleAdder
from .drake_extension import SimpleAdder
51 changes: 44 additions & 7 deletions src/drake_extension_ext.cpp
Original file line number Diff line number Diff line change
@@ -1,14 +1,51 @@
#include <nanobind/nanobind.h>
/**
* @file
* Provides an example of creating a simple Drake C++ system and binding it in
* pybind11, to be used with pydrake.
*/
#include <pybind11/pybind11.h>
#include <drake/systems/framework/leaf_system.h>

namespace nb = nanobind;
using namespace nb::literals;
namespace py = pybind11;

NB_MODULE(drake_extension_ext, m) {
m.doc() = "Example module interfacing with Drake C++";
using drake::systems::BasicVector;
using drake::systems::Context;
using drake::systems::LeafSystem;
using drake::systems::kVectorValued;

namespace drake_extension {

/// Adds a constant to an input.
template <typename T>
class SimpleAdder : public LeafSystem<T> {
public:
explicit SimpleAdder(T add)
: add_(add) {
this->DeclareInputPort("in", kVectorValued, 1);
this->DeclareVectorOutputPort(
"out", BasicVector<T>(1), &SimpleAdder::CalcOutput);
}

private:
void CalcOutput(const Context<T>& context, BasicVector<T>* output) const {
auto u = this->get_input_port(0).Eval(context);
auto&& y = output->get_mutable_value();
y.array() = u.array() + add_;
}

const T add_{};
};


PYBIND11_MODULE(drake_extension, m) {
m.doc() = "Example module interfacing with pydrake and Drake C++";

py::module::import("pydrake.systems.framework");

using T = double;

nb::class_<SimpleAdder<T>, LeafSystem<T>>(m, "SimpleAdder")
.def(nb::init<T>(), nb::arg("add"));
py::class_<SimpleAdder<T>, LeafSystem<T>>(m, "SimpleAdder")
.def(py::init<T>(), py::arg("add"));
}

} // namespace drake_extension
3 changes: 2 additions & 1 deletion tests/test_custom_system.py → tests/test_simple_adder.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import numpy as np
from drake_extension import SimpleAdder

from pydrake.systems.analysis import Simulator
Expand All @@ -24,4 +25,4 @@ def test_custom_system():

x = logger.FindLog(simulator.get_context()).data()
print("Output values: {}".format(x))
#assert np.allclose(x, 110.)
assert np.allclose(x, 110.)

0 comments on commit 11c6835

Please sign in to comment.