Skip to content

Commit

Permalink
[py trajectories] Adjust Clone and MakeDerivative for unique_ptr (#22399
Browse files Browse the repository at this point in the history
)

Because we allow implementations of Trajectory as Python subclasses,
we cannot assume that the deleter associated with a call to Clone is
`delete MyClass`, so we now adjust our PyTrajectory override logic to
wrap its return value in a WrappedTrajectory. The only reason the
unique_ptr used to work is Drake's custom fork or pybind11 with evil
hacks, which will be going away soon.

We also now warn Python subclasses to implement the canonical spelling
of the __deepcopy__ method, instead of overriding the public Clone
method. (Overriding Clone was already documented as deprecated in a
prior commit; this just adds the warning.)
  • Loading branch information
jwnimmer-tri authored Jan 8, 2025
1 parent 2fd6673 commit c13c2cc
Show file tree
Hide file tree
Showing 7 changed files with 347 additions and 16 deletions.
19 changes: 18 additions & 1 deletion bindings/pydrake/_trajectories_extra.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from pydrake.common import _MangledName
from pydrake.common import (
_MangledName,
pretty_class_name as _pretty_class_name,
)


def __getattr__(name):
Expand All @@ -7,3 +10,17 @@ def __getattr__(name):
"""
return _MangledName.module_getattr(
module_name=__name__, module_globals=globals(), name=name)


def _wrapped_trajectory_repr(wrapped_trajectory):
cls = type(wrapped_trajectory)
return f"{_pretty_class_name(cls)}({wrapped_trajectory.unwrap()!r})"


def _add_repr_functions():
for param in _WrappedTrajectory_.param_list:
cls = _WrappedTrajectory_[param]
setattr(cls, "__repr__", _wrapped_trajectory_repr)


_add_repr_functions()
40 changes: 33 additions & 7 deletions bindings/pydrake/test/trajectories_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@
PiecewiseQuaternionSlerp_,
StackedTrajectory_,
Trajectory,
Trajectory_
Trajectory_,
_WrappedTrajectory_,
)
from pydrake.symbolic import Variable, Expression

Expand All @@ -37,10 +38,12 @@ class CustomTrajectory(Trajectory):
def __init__(self):
Trajectory.__init__(self)

# TODO(jwnimmer-tri) Should be __deepcopy__ not Clone.
def Clone(self):
def __deepcopy__(self, memo):
return CustomTrajectory()

def __repr__(self):
return "CustomTrajectory()"

def do_value(self, t):
return np.array([[t + 1.0, t + 2.0]])

Expand Down Expand Up @@ -129,6 +132,7 @@ def test_custom_trajectory(self):
self.assertEqual(trajectory.start_time(), 3.0)
self.assertEqual(trajectory.end_time(), 4.0)
self.assertTrue(trajectory.has_derivative())
self.assertEqual(repr(trajectory), "CustomTrajectory()")
numpy_compare.assert_float_equal(trajectory.value(t=1.5),
np.array([[2.5, 3.5]]))
numpy_compare.assert_float_equal(
Expand All @@ -137,12 +141,18 @@ def test_custom_trajectory(self):
numpy_compare.assert_float_equal(
trajectory.EvalDerivative(t=2.3, derivative_order=2),
np.zeros((1, 2)))

clone = trajectory.Clone()
numpy_compare.assert_float_equal(clone.value(t=1.5),
np.array([[2.5, 3.5]]))
self.assertEqual(repr(clone), "_WrappedTrajectory(CustomTrajectory())")

deriv = trajectory.MakeDerivative(derivative_order=1)
numpy_compare.assert_float_equal(
deriv.value(t=2.3), np.ones((1, 2)))
self.assertIn(
"_WrappedTrajectory(<pydrake.trajectories.DerivativeTrajectory",
repr(deriv))

def test_legacy_custom_trajectory(self):
trajectory = LegacyCustomTrajectory()
Expand All @@ -164,10 +174,10 @@ def test_legacy_custom_trajectory(self):
# that we trigger the deprecation warnings.
stacked = StackedTrajectory_[float]()
with catch_drake_warnings():
# The C++ code calls rows() and cols() -- both of which cause
# deprecation warnings with the legacy overrides -- but the total
# number of calls varies between Debug and Release, so we don't
# check the exact tally here.
# The C++ code calls rows() and cols() and Clone() -- all of which
# cause deprecation warnings with the legacy overrides -- but the
# total number of calls varies between Debug and Release, so we
# don't check the exact tally here.
stacked.Append(trajectory)
with catch_drake_warnings(expected_count=1):
stacked.value(t=1.5)
Expand Down Expand Up @@ -798,3 +808,19 @@ def test_stacked_trajectory(self, T):
dut.Clone()
copy.copy(dut)
copy.deepcopy(dut)

@numpy_compare.check_all_types
def test_wrapped_trajectory(self, T):
breaks = [0, 1, 2]
samples = [[[0]], [[1]], [[2]]]
zoh = PiecewisePolynomial_[T].ZeroOrderHold(breaks, samples)
dut = _WrappedTrajectory_[T](trajectory=zoh)
self.assertEqual(dut.rows(), 1)
self.assertEqual(dut.cols(), 1)
if T is float:
self.assertIn("_WrappedTrajectory(", repr(dut))
else:
self.assertIn("_WrappedTrajectory_[", repr(dut))
self.assertIn("PiecewisePolynomial", repr(dut))
clone = dut.Clone()
self.assertIn("PiecewisePolynomial", repr(clone))
88 changes: 80 additions & 8 deletions bindings/pydrake/trajectories_py.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include "drake/bindings/pydrake/pydrake_pybind.h"
#include "drake/common/nice_type_name.h"
#include "drake/common/polynomial.h"
#include "drake/common/scope_exit.h"
#include "drake/common/trajectories/bezier_curve.h"
#include "drake/common/trajectories/bspline_trajectory.h"
#include "drake/common/trajectories/composite_trajectory.h"
Expand All @@ -19,6 +20,7 @@
#include "drake/common/trajectories/piecewise_quaternion.h"
#include "drake/common/trajectories/stacked_trajectory.h"
#include "drake/common/trajectories/trajectory.h"
#include "drake/common/trajectories/wrapped_trajectory.h"

namespace drake {
namespace pydrake {
Expand Down Expand Up @@ -184,13 +186,58 @@ struct Impl {
"2025-08-01");
}

// Utility function that takes a Python object which is-a Trajectory and
// wraps it in a unique_ptr that manages object lifetime when returned back
// to C++.
static std::unique_ptr<Trajectory<T>> WrapPyTrajectory(py::object py_traj) {
DRAKE_THROW_UNLESS(!py_traj.is_none());
// Convert py_traj to a shared_ptr<Trajectory<T>> whose C++ lifetime keeps
// the python object alive.
Trajectory<T>* cpp_traj = py::cast<Trajectory<T>*>(py_traj);
DRAKE_THROW_UNLESS(cpp_traj != nullptr);
std::shared_ptr<Trajectory<T>> shared_cpp_traj(
/* stored pointer = */ cpp_traj,
/* deleter = */ [captured_py_traj = std::move(py_traj)]( // BR
void*) mutable {
py::gil_scoped_acquire deleter_guard;
captured_py_traj = py::none();
});
// Wrap the shared_ptr inside a WrappedTrajectory and return that via
// unique_ptr to meet our required return signature.
return std::make_unique<trajectories::internal::WrappedTrajectory<T>>(
std::move(shared_cpp_traj));
}

// Trampoline virtual methods.

std::unique_ptr<Trajectory<T>> DoClone() const final {
// TODO(jwnimmer-tri) Rewrite cloning to use __deepcopy__ in lieu of
// Clone (or DoClone).
PYBIND11_OVERLOAD_PURE(
std::unique_ptr<Trajectory<T>>, Trajectory<T>, Clone);
py::gil_scoped_acquire guard;
// Trajectory subclasses in Python must implement cloning by defining
// either a __deepcopy__ (preferred) or Clone (legacy) method. We'll try
// Clone first so it has priority, but if it doesn't exist we'll fall back
// to __deepcopy__ and just let the "no such method deepcopy" error
// message propagate if both were missing. Because the
// PYBIND11_OVERLOAD_INT macro embeds a conditional `return ...;`
// statement, we must wrap it in lambda so that we can post-process the
// return value in case it does return.
bool used_legacy_clone = true;
auto make_python_deepcopy = [&]() -> py::object {
PYBIND11_OVERLOAD_INT(py::object, Trajectory<T>, "Clone");
used_legacy_clone = false;
auto deepcopy = py::module_::import("copy").attr("deepcopy");
return deepcopy(this);
};
py::object copied = make_python_deepcopy();
if (used_legacy_clone) {
WarnDeprecated(
fmt::format(
"Support for overriding {}.Clone as a virtual function is "
"deprecated. Subclasses should implement __deepcopy__, "
"instead.",
NiceTypeName::Get(*this)),
"2025-08-01");
}
return WrapPyTrajectory(std::move(copied));
}

MatrixX<T> do_value(const T& t) const final {
Expand All @@ -216,10 +263,18 @@ struct Impl {

std::unique_ptr<Trajectory<T>> DoMakeDerivative(
int derivative_order) const final {
PYBIND11_OVERLOAD_INT(std::unique_ptr<Trajectory<T>>, Trajectory<T>,
"DoMakeDerivative", derivative_order);
// If the macro did not return, use default functionality.
return Base::DoMakeDerivative(derivative_order);
py::gil_scoped_acquire guard;
// Because the PYBIND11_OVERLOAD_INT macro embeds a `return ...;`
// statement, we must wrap it in lambda so that we can post-process the
// return value.
auto make_python_derivative = [&]() -> py::object {
PYBIND11_OVERLOAD_INT(
py::object, Trajectory<T>, "DoMakeDerivative", derivative_order);
// If the macro did not return, use the base class error message.
Base::DoMakeDerivative(derivative_order);
DRAKE_UNREACHABLE();
};
return WrapPyTrajectory(make_python_derivative());
}

T do_start_time() const final {
Expand Down Expand Up @@ -760,6 +815,23 @@ struct Impl {
cls_doc.Append.doc);
DefCopyAndDeepCopy(&cls);
}

{
using Class = trajectories::internal::WrappedTrajectory<T>;
auto cls = DefineTemplateClassWithDefault<Class, Trajectory<T>>(
m, "_WrappedTrajectory", param, "(Internal use only)");
cls // BR
.def(py::init([](const Trajectory<T>& trajectory) {
// The keep_alive is responsible for object lifetime, so we'll give
// the constructor an unowned pointer.
return std::make_unique<Class>(
make_unowned_shared_ptr_from_raw(&trajectory));
}),
py::arg("trajectory"),
// Keep alive, ownership: `return` keeps `trajectory` alive.
py::keep_alive<0, 1>())
.def("unwrap", &Class::unwrap, py_rvp::reference_internal);
}
}
};
} // namespace
Expand Down
19 changes: 19 additions & 0 deletions common/trajectories/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ drake_cc_package_library(
":piecewise_trajectory",
":stacked_trajectory",
":trajectory",
":wrapped_trajectory",
],
)

Expand Down Expand Up @@ -219,6 +220,15 @@ drake_cc_library(
],
)

drake_cc_library(
name = "wrapped_trajectory",
srcs = ["wrapped_trajectory.cc"],
hdrs = ["wrapped_trajectory.h"],
deps = [
":trajectory",
],
)

# === test/ ===

drake_cc_googletest(
Expand Down Expand Up @@ -388,4 +398,13 @@ drake_cc_googletest(
],
)

drake_cc_googletest(
name = "wrapped_trajectory_test",
deps = [
":function_handle_trajectory",
":wrapped_trajectory",
"//common/test_utilities:expect_throws_message",
],
)

add_lint_tests()
60 changes: 60 additions & 0 deletions common/trajectories/test/wrapped_trajectory_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
#include "drake/common/trajectories/wrapped_trajectory.h"

#include <gtest/gtest.h>

#include "drake/common/test_utilities/expect_throws_message.h"
#include "drake/common/trajectories/function_handle_trajectory.h"

namespace drake {
namespace trajectories {
namespace internal {
namespace {

Eigen::Vector2d Circle(const double& t) {
return Eigen::Vector2d(std::sin(t), std::cos(t));
}

Eigen::Vector2d CircleDerivative(const double& t, int order) {
DRAKE_DEMAND(order == 2);
return -Circle(t);
}

std::shared_ptr<const Trajectory<double>> MakeFunctionHandleTrajectory() {
const double start_time = 0;
const double end_time = 1;
auto result = std::make_shared<FunctionHandleTrajectory<double>>(
&Circle, 2, 1, start_time, end_time);
result->set_derivative(&CircleDerivative);
return result;
}

GTEST_TEST(WrappedTrajectoryTest, BasicTest) {
const WrappedTrajectory<double> dut(MakeFunctionHandleTrajectory());
EXPECT_EQ(dut.rows(), 2);
EXPECT_EQ(dut.cols(), 1);
EXPECT_EQ(dut.start_time(), 0);
EXPECT_EQ(dut.end_time(), 1);
EXPECT_TRUE(dut.has_derivative());
const double t = 0.25;
EXPECT_EQ(dut.value(t), Circle(t));
EXPECT_EQ(dut.EvalDerivative(t, 2), -Circle(t));
EXPECT_EQ(dut.MakeDerivative(2)->value(t), -Circle(t));

auto clone = dut.Clone();
EXPECT_EQ(clone->rows(), 2);
EXPECT_EQ(clone->cols(), 1);
EXPECT_EQ(clone->start_time(), 0);
EXPECT_EQ(clone->end_time(), 1);
EXPECT_TRUE(clone->has_derivative());
EXPECT_EQ(clone->value(t), Circle(t));
EXPECT_EQ(clone->EvalDerivative(t, 2), -Circle(t));
EXPECT_EQ(clone->MakeDerivative(2)->value(t), -Circle(t));
// We want a FunctionHandleTrajectory, not wrapped. See comment in cc file.
EXPECT_TRUE(
dynamic_cast<const FunctionHandleTrajectory<double>*>(clone.get()));
}

} // namespace
} // namespace internal
} // namespace trajectories
} // namespace drake
Loading

0 comments on commit c13c2cc

Please sign in to comment.