diff --git a/.github/workflows/checks.yaml b/.github/workflows/checks.yaml index 02946ff..a8e7329 100644 --- a/.github/workflows/checks.yaml +++ b/.github/workflows/checks.yaml @@ -30,7 +30,7 @@ jobs: conda activate cpp-conda mkdir build cd build - cmake .. + cmake -DCMAKE_PREFIX_PATH=`python -c 'import torch;print(torch.utils.cmake_prefix_path)'` .. make - name: Check Code Format diff --git a/README.md b/README.md index 12b94d1..24d8d1e 100644 --- a/README.md +++ b/README.md @@ -39,6 +39,8 @@ it _might_ work on other systems. - [xtensor](https://xtensor.readthedocs.io/en/latest/) - [nlohmann/json](https://github.com/nlohmann/json#serialization--deserialization) - [protobuf](https://github.com/protocolbuffers/protobuf) + - [PyTorch](https://pytorch.org) -> For CUDA support, you will need to install + the [CUDA toolkit](https://developer.nvidia.com/cuda-toolkit) yourself - Dev tools - Formatting (Google style) with [clang-format](https://clang.llvm.org/docs/ClangFormat.html) @@ -77,7 +79,7 @@ conda activate cpp-conda ```bash mkdir build cd build -cmake .. +cmake -DCMAKE_PREFIX_PATH=`python -c 'import torch;print(torch.utils.cmake_prefix_path)'` .. make ``` diff --git a/environment.yaml b/environment.yaml index 8b18ecf..77dc995 100644 --- a/environment.yaml +++ b/environment.yaml @@ -1,6 +1,7 @@ name: cpp-conda channels: - conda-forge + - pytorch - defaults dependencies: - python=3.8.* @@ -17,5 +18,6 @@ dependencies: - xsimd=7.* - nlohmann_json=3.* - protobuf=3.* # includes Protobuf compiler, C++ headers, Python libraries + - pytorch=1.7.* # Development - cmake_format diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index f84a093..9627e92 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -3,6 +3,13 @@ find_package(xtensor REQUIRED) find_package(absl REQUIRED) find_package(nlohmann_json REQUIRED) +# +# PyTorch +# + +find_package(Torch REQUIRED) +set(CMAKE_CXX_FLAGS "${TORCH_CXX_FLAGS} ${CMAKE_CXX_FLAGS}") + # # Protobuf # @@ -23,8 +30,14 @@ message("PROTO_HDRS = ${PROTO_HDRS}") add_executable(main main.cpp ${PROTO_SRCS} ${PROTO_HDRS}) target_link_libraries( - main PRIVATE xtensor xtensor::optimize xtensor::use_xsimd absl::flat_hash_map - private_library ${PROTOBUF_LIBRARIES}) + main + PRIVATE xtensor + xtensor::optimize + xtensor::use_xsimd + absl::flat_hash_map + private_library + ${PROTOBUF_LIBRARIES} + ${TORCH_LIBRARIES}) target_include_directories( main PRIVATE . ${PROTOBUF_INCLUDE_DIRS} ${CMAKE_CURRENT_BINARY_DIR} # In order to capture the *.pb.h diff --git a/src/main.cpp b/src/main.cpp index e3d9d5e..c55953b 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -2,6 +2,7 @@ #include #include +#include #include #include #include @@ -50,5 +51,10 @@ int main() { std::cout << "Protobuf\n"; projectname::AddressBook address_book; + std::cout << "Random torch tensor\n"; + torch::Tensor tensor = torch::rand({2, 3}); // NOLINT + std::cout << tensor << std::endl; + std::cout << "cuda is available: " << torch::cuda::is_available() << '\n'; + return 0; }