From 580fe6012677c97017cfa1570bde3edaa07623da Mon Sep 17 00:00:00 2001 From: Erick Fuentes Date: Fri, 6 Dec 2024 18:50:36 -0500 Subject: [PATCH] Add example of calling matplotlib from c++ (#346) --- common/BUILD | 28 +++++ common/matplotlib.cc | 49 +++++++++ common/matplotlib.hh | 17 +++ common/matplotlib_test.cc | 13 +++ common/python/embedded_py.bzl | 100 ++++++++++++++++++ experimental/overhead_matching/BUILD | 1 + .../overhead_matching/spectacular_log_test.cc | 31 +++++- 7 files changed, 238 insertions(+), 1 deletion(-) create mode 100644 common/matplotlib.cc create mode 100644 common/matplotlib.hh create mode 100644 common/matplotlib_test.cc create mode 100644 common/python/embedded_py.bzl diff --git a/common/BUILD b/common/BUILD index 9d2bf3da..ed0bddf6 100644 --- a/common/BUILD +++ b/common/BUILD @@ -3,6 +3,8 @@ package(features=["warning_compile_flags"]) load("@pip//:requirements.bzl", "requirement") load("@rules_python//python:defs.bzl", "py_library") +load("@pybind11_bazel//:build_defs.bzl", "pybind_extension") +load("@//common/python:embedded_py.bzl", "cc_py_library", "cc_py_test") cc_library( name="argument_wrapper", @@ -63,6 +65,32 @@ cc_test( ] ) +cc_py_library( + name = "matplotlib", + hdrs = ["matplotlib.hh"], + srcs = ["matplotlib.cc"], + visibility= ["//visibility:public"], + py_deps = [ + requirement("matplotlib"), + requirement("PyGObject"), + ], + deps = [ + "@pybind11", + "@rules_python//python/cc:current_py_cc_headers", + "@rules_python//python/cc:current_py_cc_libs", + ], +) + +cc_test( + name = "matplotlib_test", + srcs = ["matplotlib_test.cc"], + tags = ["manual"], + deps = [ + ":matplotlib", + "@com_google_googletest//:gtest_main", + ], +) + py_library( name = "torch", srcs = ["torch.py"], diff --git a/common/matplotlib.cc b/common/matplotlib.cc new file mode 100644 index 00000000..12696dd0 --- /dev/null +++ b/common/matplotlib.cc @@ -0,0 +1,49 @@ + +#include "common/matplotlib.hh" + +#include +#include + +#include "pybind11/embed.h" +#include "pybind11/stl.h" + +namespace py = pybind11; +using namespace pybind11::literals; + +namespace robot { +namespace { +wchar_t *to_wchar(const char *str) { + const size_t num_chars = std::mbstowcs(nullptr, str, 0) + 1; + if (num_chars == std::numeric_limits::max()) { + return nullptr; + } + wchar_t *out = static_cast(malloc(num_chars * sizeof(wchar_t))); + + std::mbstowcs(out, str, num_chars); + + return out; +} +} // namespace + +void plot(const std::vector &signals, const bool block) { + PyConfig config; + PyConfig_InitPythonConfig(&config); + config.home = to_wchar(CPP_PYTHON_HOME); + config.pathconfig_warnings = 1; + config.program_name = to_wchar(CPP_PYVENV_LAUNCHER); + config.pythonpath_env = to_wchar(CPP_PYTHON_PATH); + config.user_site_directory = 0; + py::scoped_interpreter guard{&config}; + + py::module_ mpl = py::module_::import("matplotlib"); + mpl.attr("use")("GTK3Agg"); + py::module_ plt = py::module_::import("matplotlib.pyplot"); + + plt.attr("figure")(); + for (const auto &signal : signals) { + plt.attr("plot")(signal.x, signal.y, signal.marker, "label"_a = signal.label); + } + plt.attr("legend")(); + plt.attr("show")("block"_a = block); +} +} // namespace robot diff --git a/common/matplotlib.hh b/common/matplotlib.hh new file mode 100644 index 00000000..b3590388 --- /dev/null +++ b/common/matplotlib.hh @@ -0,0 +1,17 @@ + +#pragma once + +#include +#include + +namespace robot { +struct PlotSignal { + std::vector x; + std::vector y; + std::string label = ""; + std::string marker = "-"; +}; + +void plot(const std::vector &signals, const bool block = true); + +} // namespace robot diff --git a/common/matplotlib_test.cc b/common/matplotlib_test.cc new file mode 100644 index 00000000..02493c49 --- /dev/null +++ b/common/matplotlib_test.cc @@ -0,0 +1,13 @@ + +#include "common/matplotlib.hh" + +#include "gtest/gtest.h" + +namespace robot { + +TEST(MatplotlibTest, simple_plot) { + EXPECT_NO_THROW(plot( + {{.x = std::vector{0.0, 1.0, 2.0}, .y = std::vector{10.0, 20.0, 15.0}, .label = "label"}}, + false)); +} +} // namespace robot diff --git a/common/python/embedded_py.bzl b/common/python/embedded_py.bzl new file mode 100644 index 00000000..32f5d0a1 --- /dev/null +++ b/common/python/embedded_py.bzl @@ -0,0 +1,100 @@ +def _cc_py_runtime_impl(ctx): + toolchain = ctx.toolchains["@bazel_tools//tools/python:toolchain_type"] + py3_runtime = toolchain.py3_runtime + imports = [] + for dep in ctx.attr.deps: + imports.append(dep[PyInfo].imports) + python_path = "" + for path in depset(transitive = imports).to_list(): + # print("Printing python path: " + str(path)) + python_path += "external/" + path + ":" + + py3_runfiles = ctx.runfiles(files = py3_runtime.files.to_list()) + runfiles = [py3_runfiles] + for dep in ctx.attr.deps: + dep_runfiles = ctx.runfiles(files = dep[PyInfo].transitive_sources.to_list()) + runfiles.append(dep_runfiles) + runfiles.append(dep[DefaultInfo].default_runfiles) + + runfiles = ctx.runfiles().merge_all(runfiles) + + # print("Printing interpreter path: " + str(py3_runtime.interpreter.path)) + # print("Printing interpreter home: " + str(py3_runtime.interpreter.dirname.rstrip("bin"))) + + return [ + DefaultInfo(runfiles = runfiles), + platform_common.TemplateVariableInfo({ + "PYTHON3": str(py3_runtime.interpreter.path), + "PYTHONPATH": python_path, + "PYTHONHOME": str(py3_runtime.interpreter.dirname.rstrip("bin")), + }), + ] + +_cc_py_runtime = rule( + implementation = _cc_py_runtime_impl, + attrs = { + "deps": attr.label_list(providers = [PyInfo]), + }, + toolchains = [ + str(Label("@bazel_tools//tools/python:toolchain_type")), + ], +) + +def cc_py_test(name, py_deps = [], **kwargs): + py_runtime_target = name + "_py_runtime" + _cc_py_runtime( + name = py_runtime_target, + deps = py_deps, + ) + + kwargs.update({ + "data": kwargs.get("data", []) + [":" + py_runtime_target], + "env": {"__PYVENV_LAUNCHER__": "$(PYTHON3)", "PYTHONPATH": "$(PYTHONPATH)", "PYTHONHOME": "$(PYTHONHOME)", "PYTHONNOUSERSITE": "1"}, + "toolchains": kwargs.get("toolchains", []) + [":" + py_runtime_target], + }) + + native.cc_test( + name = name, + **kwargs + ) + +def cc_py_binary(name, py_deps = [], **kwargs): + py_runtime_target = name + "_py_runtime" + _cc_py_runtime( + name = py_runtime_target, + deps = py_deps, + ) + + kwargs.update({ + "data": kwargs.get("data", []) + [":" + py_runtime_target], + "env": {"__PYVENV_LAUNCHER__": "$(PYTHON3)", "PYTHONPATH": "$(PYTHONPATH)", "PYTHONHOME": "$(PYTHONHOME)", "PYTHONNOUSERSITE": "1"}, + "toolchains": kwargs.get("toolchains", []) + [":" + py_runtime_target], + }) + + native.cc_binary( + name = name, + **kwargs + ) + +def cc_py_library(name, py_deps = [], **kwargs): + py_runtime_target = name + "_py_runtime" + _cc_py_runtime( + name = py_runtime_target, + deps = py_deps, + ) + + kwargs.update({ + "data": kwargs.get("data", []) + [":" + py_runtime_target], + "defines": [ + "CPP_PYVENV_LAUNCHER=\\\"$(PYTHON3)\\\"", + "CPP_PYTHON_PATH=\\\"$(PYTHONPATH)\\\"", + "CPP_PYTHON_HOME=\\\"$(PYTHONHOME)\\\"", + "PYTHONNOUSERSITE=\\\"1\\\"", + ], + "toolchains": kwargs.get("toolchains", []) + [":" + py_runtime_target], + }) + + native.cc_library( + name = name, + **kwargs + ) diff --git a/experimental/overhead_matching/BUILD b/experimental/overhead_matching/BUILD index ca15a04d..f807768c 100644 --- a/experimental/overhead_matching/BUILD +++ b/experimental/overhead_matching/BUILD @@ -27,6 +27,7 @@ cc_test( srcs = ["spectacular_log_test.cc"], data = ["@spectacular_log_snippet//:files"], deps = [ + "//common:matplotlib", ":spectacular_log", "@com_google_googletest//:gtest_main", "@fmt", diff --git a/experimental/overhead_matching/spectacular_log_test.cc b/experimental/overhead_matching/spectacular_log_test.cc index a1f91317..a0c2dd43 100644 --- a/experimental/overhead_matching/spectacular_log_test.cc +++ b/experimental/overhead_matching/spectacular_log_test.cc @@ -4,6 +4,7 @@ #include #include +#include "common/matplotlib.hh" #include "fmt/format.h" #include "gtest/gtest.h" #include "opencv2/opencv.hpp" @@ -39,18 +40,46 @@ TEST(SpectacularLogTest, happy_case) { << " frame time: (" << log.min_frame_time() << ", " << log.max_frame_time() << ")" << std::endl; + std::vector ts; + std::vector axs; + std::vector ays; + std::vector azs; + std::vector gxs; + std::vector gys; + std::vector gzs; std::cout << "IMU Samples" << std::endl; for (time::RobotTimestamp t = log.min_imu_time(); - t < std::min(log.min_imu_time() + time::as_duration(5.0), log.max_imu_time()); + t < std::min(log.min_imu_time() + time::as_duration(20.0), log.max_imu_time()); t += time::as_duration(0.1)) { const auto sample = log.get_imu_sample(t); if (sample.has_value()) { + ts.push_back(std::chrono::duration(t.time_since_epoch()).count()); + axs.push_back(sample->accel_mpss.x()); + ays.push_back(sample->accel_mpss.y()); + azs.push_back(sample->accel_mpss.z()); + gxs.push_back(sample->gyro_radps.x()); + gys.push_back(sample->gyro_radps.y()); + gzs.push_back(sample->gyro_radps.z()); std::cout << "t: " << sample->time_of_validity << " accel: " << sample->accel_mpss.transpose() << " gyro: " << sample->gyro_radps.transpose() << std::endl; } } + if (false) { + const bool should_block = false; + plot( + { + {ts, axs, "ax"}, + {ts, ays, "ay"}, + {ts, azs, "az"}, + {ts, gxs, "gx"}, + {ts, gys, "gy"}, + {ts, gzs, "gz"}, + }, + should_block); + } + cv::VideoCapture video(log_path / "data.mov", cv::CAP_FFMPEG); constexpr int FRAME_SKIP = 50; cv::Mat expected_frame;