diff --git a/src/triton_object_recognition/CMakeLists.txt b/src/triton_object_recognition/CMakeLists.txt index f692e34e..bfaad342 100644 --- a/src/triton_object_recognition/CMakeLists.txt +++ b/src/triton_object_recognition/CMakeLists.txt @@ -26,6 +26,8 @@ find_package(sensor_msgs REQUIRED) find_package(image_transport REQUIRED) find_package(OpenCV 4.5.3 REQUIRED) find_package(Boost REQUIRED COMPONENTS filesystem) +find_package(Torch REQUIRED) + include_directories(include ${catkin_INCLUDE_DIRS}) include_directories(${OpenCV_INCLUDE_DIRS} ${Boost_INCLUDE_DIRS}) @@ -35,6 +37,8 @@ add_library(triton_object_recognition SHARED ) target_link_libraries(triton_object_recognition ${OpenCV_LIBRARIES} ${Boost_LIBRARIES}) +target_link_libraries(your_executable_name "${TORCH_LIBRARIES}") + rclcpp_components_register_nodes(triton_object_recognition "triton_object_recognition::ObjectRecognizer" diff --git a/src/triton_object_recognition/include/triton_object_recognition/object_recognizer.hpp b/src/triton_object_recognition/include/triton_object_recognition/object_recognizer.hpp index 8fa86fb3..a2ef99c3 100644 --- a/src/triton_object_recognition/include/triton_object_recognition/object_recognizer.hpp +++ b/src/triton_object_recognition/include/triton_object_recognition/object_recognizer.hpp @@ -58,6 +58,8 @@ namespace triton_object_recognition #endif std::shared_ptr net_; + std::shared_ptr module_; + //Default Neural Net Parameters (overriden by parameters) // THESE ARE RELATED TO DARKNET SO I HAVE COMMENTED THEM OUT @@ -66,7 +68,7 @@ namespace triton_object_recognition // std::string cfg_url_ = "https://raw.githubusercontent.com/pjreddie/darknet/master/cfg/yolov3-tiny.cfg"; // std::string cfg_filename_ = "yolov3-tiny.cfg"; - std::string weights_filename_ = "config/weights/best.pt"; // Path to the PyTorch model + std::string weights_filename_ = "/config/weights/best.pt"; // Path to the PyTorch model cv::dnn::Backend backend_ = cv::dnn::DNN_BACKEND_OPENCV; cv::dnn::Target target_ = cv::dnn::DNN_TARGET_CPU; diff --git a/src/triton_object_recognition/src/object_recognizer.cpp b/src/triton_object_recognition/src/object_recognizer.cpp index 920eec67..3b82dfcb 100644 --- a/src/triton_object_recognition/src/object_recognizer.cpp +++ b/src/triton_object_recognition/src/object_recognizer.cpp @@ -3,6 +3,9 @@ #include #include #include "ament_index_cpp/get_package_share_directory.hpp" +#include // One-stop header for TorchScript +#include // General purpose header for PyTorch + using std::placeholders::_1; using std::placeholders::_2; using std::placeholders::_3; @@ -118,7 +121,7 @@ namespace triton_object_recognition // Load PyTorch model try { - module_ = std::make_shared(torch::jit::load(model_weights.string())); + module_ = torch::jit::load(model_weights.string()); module_->eval(); // Set the model to evaluation mode } catch (const c10::Error &e) { RCLCPP_ERROR(this->get_logger(), "Error loading the model: %s", e.what()); @@ -126,7 +129,7 @@ namespace triton_object_recognition } RCLCPP_INFO(get_logger(), "Object Recognizer successfully started!"); - + } void ObjectRecognizer::subscriberCallback(const sensor_msgs::msg::Image::ConstSharedPtr & msg) const { #if DEBUG_VISUALIZE