Skip to content

Commit

Permalink
fixing syntax issues for cv
Browse files Browse the repository at this point in the history
  • Loading branch information
Rshah2004 committed Jun 9, 2024
1 parent f416680 commit 42ae2f8
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 3 deletions.
4 changes: 4 additions & 0 deletions src/triton_object_recognition/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ find_package(sensor_msgs REQUIRED)
find_package(image_transport REQUIRED)
find_package(OpenCV 4.5.3 REQUIRED)
find_package(Boost REQUIRED COMPONENTS filesystem)
find_package(Torch REQUIRED)


include_directories(include ${catkin_INCLUDE_DIRS})
include_directories(${OpenCV_INCLUDE_DIRS} ${Boost_INCLUDE_DIRS})
Expand All @@ -35,6 +37,8 @@ add_library(triton_object_recognition SHARED
)

target_link_libraries(triton_object_recognition ${OpenCV_LIBRARIES} ${Boost_LIBRARIES})
target_link_libraries(your_executable_name "${TORCH_LIBRARIES}")


rclcpp_components_register_nodes(triton_object_recognition
"triton_object_recognition::ObjectRecognizer"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ namespace triton_object_recognition
#endif

std::shared_ptr<cv::dnn::Net> net_;
std::shared_ptr<torch::jit::script::Module> module_;


//Default Neural Net Parameters (overriden by parameters)
// THESE ARE RELATED TO DARKNET SO I HAVE COMMENTED THEM OUT
Expand All @@ -66,7 +68,7 @@ namespace triton_object_recognition
// std::string cfg_url_ = "https://raw.githubusercontent.com/pjreddie/darknet/master/cfg/yolov3-tiny.cfg";
// std::string cfg_filename_ = "yolov3-tiny.cfg";

std::string weights_filename_ = "config/weights/best.pt"; // Path to the PyTorch model
std::string weights_filename_ = "/config/weights/best.pt"; // Path to the PyTorch model

cv::dnn::Backend backend_ = cv::dnn::DNN_BACKEND_OPENCV;
cv::dnn::Target target_ = cv::dnn::DNN_TARGET_CPU;
Expand Down
7 changes: 5 additions & 2 deletions src/triton_object_recognition/src/object_recognizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
#include <boost/filesystem.hpp>
#include <rcl_yaml_param_parser/parser.h>
#include "ament_index_cpp/get_package_share_directory.hpp"
#include <torch/script.h> // One-stop header for TorchScript
#include <torch/torch.h> // General purpose header for PyTorch

using std::placeholders::_1;
using std::placeholders::_2;
using std::placeholders::_3;
Expand Down Expand Up @@ -118,15 +121,15 @@ namespace triton_object_recognition

// Load PyTorch model
try {
module_ = std::make_shared<torch::jit::script::Module>(torch::jit::load(model_weights.string()));
module_ = torch::jit::load(model_weights.string());
module_->eval(); // Set the model to evaluation mode
} catch (const c10::Error &e) {
RCLCPP_ERROR(this->get_logger(), "Error loading the model: %s", e.what());
throw;
}

RCLCPP_INFO(get_logger(), "Object Recognizer successfully started!");

}
void ObjectRecognizer::subscriberCallback(const sensor_msgs::msg::Image::ConstSharedPtr & msg) const
{
#if DEBUG_VISUALIZE
Expand Down

0 comments on commit 42ae2f8

Please sign in to comment.