Skip to content

Commit

Permalink
Add example of calling matplotlib from c++
Browse files Browse the repository at this point in the history
  • Loading branch information
ewfuentes committed Nov 24, 2024
1 parent 8d32967 commit b28cae5
Show file tree
Hide file tree
Showing 5 changed files with 189 additions and 0 deletions.
26 changes: 26 additions & 0 deletions common/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package(features=["warning_compile_flags"])

load("@pip//:requirements.bzl", "requirement")
load("@rules_python//python:defs.bzl", "py_library")
load("@pybind11_bazel//:build_defs.bzl", "pybind_extension")
load("@//common/python:embedded_py.bzl", "cc_py_library", "cc_py_test")

cc_library(
name="argument_wrapper",
Expand Down Expand Up @@ -63,6 +65,30 @@ cc_test(
]
)

cc_py_library(
name = "matplotlib",
hdrs = ["matplotlib.hh"],
srcs = ["matplotlib.cc"],
py_deps = [
requirement("matplotlib"),
requirement("PyGObject"),
],
deps = [
"@pybind11",
"@rules_python//python/cc:current_py_cc_headers",
"@rules_python//python/cc:current_py_cc_libs",
],
)

cc_test(
name = "matplotlib_test",
srcs = ["matplotlib_test.cc"],
deps = [
":matplotlib",
"@com_google_googletest//:gtest_main",
],
)

py_library(
name = "torch",
srcs = ["torch.py"],
Expand Down
44 changes: 44 additions & 0 deletions common/matplotlib.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@

#include "common/matplotlib.hh"

#include <cstdlib>
#include <limits>

#include "pybind11/embed.h"
#include "pybind11/stl.h"

namespace py = pybind11;
using namespace pybind11::literals;

namespace robot {
namespace {
wchar_t *to_wchar(const char *str) {
const size_t num_chars = std::mbstowcs(nullptr, str, 0) + 1;
if (num_chars == std::numeric_limits<size_t>::max()) {
return nullptr;
}
wchar_t *out = static_cast<wchar_t *>(malloc(num_chars * sizeof(wchar_t)));

std::mbstowcs(out, str, num_chars);

return out;
}
} // namespace

void plot(const std::vector<double> &x, const std::vector<double> &y) {
PyConfig config;
PyConfig_InitPythonConfig(&config);
config.home = to_wchar(CPP_PYTHON_HOME);
config.pathconfig_warnings = 1;
config.program_name = to_wchar(CPP_PYVENV_LAUNCHER);
config.pythonpath_env = to_wchar(CPP_PYTHON_PATH);
config.user_site_directory = 0;
py::scoped_interpreter guard{&config};

py::module_ plt = py::module_::import("matplotlib.pyplot");

plt.attr("figure")();
plt.attr("plot")(x, y);
plt.attr("show")("block"_a = false);
}
} // namespace robot
10 changes: 10 additions & 0 deletions common/matplotlib.hh
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@

#pragma once

#include <vector>

namespace robot {

void plot(const std::vector<double> &x, const std::vector<double> &y);

}
9 changes: 9 additions & 0 deletions common/matplotlib_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@

#include "common/matplotlib.hh"

#include "gtest/gtest.h"

namespace robot {

TEST(MatplotlibTest, simple_plot) { EXPECT_NO_THROW(plot({0.0, 1.0, 2.0}, {10.0, 20.0, 15.0})); }
} // namespace robot
100 changes: 100 additions & 0 deletions common/python/embedded_py.bzl
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
def _cc_py_runtime_impl(ctx):
toolchain = ctx.toolchains["@bazel_tools//tools/python:toolchain_type"]
py3_runtime = toolchain.py3_runtime
imports = []
for dep in ctx.attr.deps:
imports.append(dep[PyInfo].imports)
python_path = ""
for path in depset(transitive = imports).to_list():
# print("Printing python path: " + str(path))
python_path += "external/" + path + ":"

py3_runfiles = ctx.runfiles(files = py3_runtime.files.to_list())
runfiles = [py3_runfiles]
for dep in ctx.attr.deps:
dep_runfiles = ctx.runfiles(files = dep[PyInfo].transitive_sources.to_list())
runfiles.append(dep_runfiles)
runfiles.append(dep[DefaultInfo].default_runfiles)

runfiles = ctx.runfiles().merge_all(runfiles)

# print("Printing interpreter path: " + str(py3_runtime.interpreter.path))
# print("Printing interpreter home: " + str(py3_runtime.interpreter.dirname.rstrip("bin")))

return [
DefaultInfo(runfiles = runfiles),
platform_common.TemplateVariableInfo({
"PYTHON3": str(py3_runtime.interpreter.path),
"PYTHONPATH": python_path,
"PYTHONHOME": str(py3_runtime.interpreter.dirname.rstrip("bin")),
}),
]

_cc_py_runtime = rule(
implementation = _cc_py_runtime_impl,
attrs = {
"deps": attr.label_list(providers = [PyInfo]),
},
toolchains = [
str(Label("@bazel_tools//tools/python:toolchain_type")),
],
)

def cc_py_test(name, py_deps = [], **kwargs):
py_runtime_target = name + "_py_runtime"
_cc_py_runtime(
name = py_runtime_target,
deps = py_deps,
)

kwargs.update({
"data": kwargs.get("data", []) + [":" + py_runtime_target],
"env": {"__PYVENV_LAUNCHER__": "$(PYTHON3)", "PYTHONPATH": "$(PYTHONPATH)", "PYTHONHOME": "$(PYTHONHOME)", "PYTHONNOUSERSITE": "1"},
"toolchains": kwargs.get("toolchains", []) + [":" + py_runtime_target],
})

native.cc_test(
name = name,
**kwargs
)

def cc_py_binary(name, py_deps = [], **kwargs):
py_runtime_target = name + "_py_runtime"
_cc_py_runtime(
name = py_runtime_target,
deps = py_deps,
)

kwargs.update({
"data": kwargs.get("data", []) + [":" + py_runtime_target],
"env": {"__PYVENV_LAUNCHER__": "$(PYTHON3)", "PYTHONPATH": "$(PYTHONPATH)", "PYTHONHOME": "$(PYTHONHOME)", "PYTHONNOUSERSITE": "1"},
"toolchains": kwargs.get("toolchains", []) + [":" + py_runtime_target],
})

native.cc_binary(
name = name,
**kwargs
)

def cc_py_library(name, py_deps = [], **kwargs):
py_runtime_target = name + "_py_runtime"
_cc_py_runtime(
name = py_runtime_target,
deps = py_deps,
)

kwargs.update({
"data": kwargs.get("data", []) + [":" + py_runtime_target],
"defines": [
"CPP_PYVENV_LAUNCHER=\\\"$(PYTHON3)\\\"",
"CPP_PYTHON_PATH=\\\"$(PYTHONPATH)\\\"",
"CPP_PYTHON_HOME=\\\"$(PYTHONHOME)\\\"",
"PYTHONNOUSERSITE=\\\"1\\\"",
],
"toolchains": kwargs.get("toolchains", []) + [":" + py_runtime_target],
})

native.cc_library(
name = name,
**kwargs
)

0 comments on commit b28cae5

Please sign in to comment.