diff --git a/src/image_object_detection/image_object_detection_node.py b/src/image_object_detection/image_object_detection_node.py index 0ec44da..855ddab 100644 --- a/src/image_object_detection/image_object_detection_node.py +++ b/src/image_object_detection/image_object_detection_node.py @@ -12,13 +12,13 @@ import cv2 import cv_bridge +from cv_bridge import CvBridgeError import rclpy from rclpy.node import Node from rclpy.qos import QoSProfile, QoSReliabilityPolicy, QoSHistoryPolicy import std_srvs.srv from sensor_msgs.msg import CompressedImage, Image -from vision_msgs.msg import Detection2D, ObjectHypothesisWithPose import torch import torch.backends.cudnn as cudnn @@ -27,7 +27,7 @@ from utils.general import check_img_size, non_max_suppression, scale_coords, xyxy2xywh, set_logging from utils.plots import plot_one_box from utils.torch_utils import select_device -from vision_msgs.msg import Detection2DArray, Detection2D +from vision_msgs.msg import Detection2D, Detection2DArray, ObjectHypothesisWithPose from ament_index_python.packages import get_package_share_directory PACKAGE_NAME = "image_object_detection" @@ -102,7 +102,7 @@ def __init__(self): self.get_parameter("subscribers.qos_policy").get_parameter_value().string_value ) - self.debug_image_output_format = "mono8" # "bgr8" + self.debug_image_output_format = "mono8" # "mono8" # "bgr8" self.processing_enabled = self.get_parameter_or("processing_enabled", True) @@ -215,13 +215,17 @@ def accomodate_image_to_model(self, img0): def image_compressed_callback(self, msg): if not self.processing_enabled: return + + try: + self.cv_img = self.bridge.compressed_imgmsg_to_cv2(msg, self.debug_image_output_format) + img = self.accomodate_image_to_model(self.cv_img) - self.cv_img = self.bridge.compressed_imgmsg_to_cv2(msg, self.debug_image_output_format) - img = self.accomodate_image_to_model(self.cv_img) - - detections_msg, debugimg = self.predict(img, self.cv_img) + detections_msg, debugimg = self.predict(img, self.cv_img) - self.detection_publisher.publish(detections_msg) + self.detection_publisher.publish(detections_msg) + except CvBridgeError as e: + self.get_logger().error(f"Error converting image: {e}") + return if debugimg is not None: self.publish_debug_image(debugimg) @@ -256,7 +260,7 @@ def publish_debug_image(self, debugimg): elif self.debug_image_output_format == "rgba8": debugimg = cv2.cvtColor(debugimg, cv2.COLOR_BGR2RGBA) else: - self.logger.error( + self.get_logger().error( "Unsupported debug image output format: {}".format(self.debug_image_output_format) ) return @@ -295,8 +299,8 @@ def predict(self, model_img, original_image): detection2D_msg = Detection2D() xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist() - detection2D_msg.bbox.center.x = xywh[0] - detection2D_msg.bbox.center.y = xywh[1] + detection2D_msg.bbox.center.position.x = xywh[0] + detection2D_msg.bbox.center.position.y = xywh[1] detection2D_msg.bbox.size_x = xywh[2] detection2D_msg.bbox.size_y = xywh[3]