Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added 'position' sub-attribute to get coordinates from Detection2D objects #10

Merged
merged 1 commit
Jul 17, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 15 additions & 11 deletions src/image_object_detection/image_object_detection_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,13 @@
import cv2

import cv_bridge
from cv_bridge import CvBridgeError
import rclpy
from rclpy.node import Node
from rclpy.qos import QoSProfile, QoSReliabilityPolicy, QoSHistoryPolicy

import std_srvs.srv
from sensor_msgs.msg import CompressedImage, Image
from vision_msgs.msg import Detection2D, ObjectHypothesisWithPose

import torch
import torch.backends.cudnn as cudnn
Expand All @@ -27,7 +27,7 @@
from utils.general import check_img_size, non_max_suppression, scale_coords, xyxy2xywh, set_logging
from utils.plots import plot_one_box
from utils.torch_utils import select_device
from vision_msgs.msg import Detection2DArray, Detection2D
from vision_msgs.msg import Detection2D, Detection2DArray, ObjectHypothesisWithPose
from ament_index_python.packages import get_package_share_directory

PACKAGE_NAME = "image_object_detection"
Expand Down Expand Up @@ -102,7 +102,7 @@ def __init__(self):
self.get_parameter("subscribers.qos_policy").get_parameter_value().string_value
)

self.debug_image_output_format = "mono8" # "bgr8"
self.debug_image_output_format = "mono8" # "mono8" # "bgr8"

self.processing_enabled = self.get_parameter_or("processing_enabled", True)

Expand Down Expand Up @@ -215,13 +215,17 @@ def accomodate_image_to_model(self, img0):
def image_compressed_callback(self, msg):
if not self.processing_enabled:
return

try:
self.cv_img = self.bridge.compressed_imgmsg_to_cv2(msg, self.debug_image_output_format)
img = self.accomodate_image_to_model(self.cv_img)

self.cv_img = self.bridge.compressed_imgmsg_to_cv2(msg, self.debug_image_output_format)
img = self.accomodate_image_to_model(self.cv_img)

detections_msg, debugimg = self.predict(img, self.cv_img)
detections_msg, debugimg = self.predict(img, self.cv_img)

self.detection_publisher.publish(detections_msg)
self.detection_publisher.publish(detections_msg)
except CvBridgeError as e:
self.get_logger().error(f"Error converting image: {e}")
return

if debugimg is not None:
self.publish_debug_image(debugimg)
Expand Down Expand Up @@ -256,7 +260,7 @@ def publish_debug_image(self, debugimg):
elif self.debug_image_output_format == "rgba8":
debugimg = cv2.cvtColor(debugimg, cv2.COLOR_BGR2RGBA)
else:
self.logger.error(
self.get_logger().error(
"Unsupported debug image output format: {}".format(self.debug_image_output_format)
)
return
Expand Down Expand Up @@ -295,8 +299,8 @@ def predict(self, model_img, original_image):
detection2D_msg = Detection2D()
xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist()

detection2D_msg.bbox.center.x = xywh[0]
detection2D_msg.bbox.center.y = xywh[1]
detection2D_msg.bbox.center.position.x = xywh[0]
detection2D_msg.bbox.center.position.y = xywh[1]
detection2D_msg.bbox.size_x = xywh[2]
detection2D_msg.bbox.size_y = xywh[3]

Expand Down
Loading