Skip to content

Commit

Permalink
multiple input images implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
pabloinigoblasco committed Sep 11, 2024
1 parent 711c2d2 commit 40d38e9
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 64 deletions.
15 changes: 14 additions & 1 deletion config/image_object_detection.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,20 @@ image_object_detection_node:
model.weights_file: yolov7-tiny.pt
model.device: '0'

selected_detections: ['person']
selected_detections: ['person'] # Classes to detect ['person', 'car']

show_image: False
publish_debug_image: True

# Lists of topics to subscribe
camera_topics:
- '/camera/image_raw'
# - '/camera1/image_raw'
# - '/camera2/image_raw'
# - '/camera3/image_raw'

# QoS policy for the image subscriber
subscribers.qos_policy: 'best_effort'

# QoS policy for the image debug publisher
image_debug_publisher.qos_policy: 'best_effort'
139 changes: 76 additions & 63 deletions src/image_object_detection/image_object_detection_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import std_srvs.srv
from sensor_msgs.msg import CompressedImage, Image
from vision_msgs.msg import Detection2D, ObjectHypothesisWithPose

import torch
import torch.backends.cudnn as cudnn
Expand All @@ -27,17 +28,17 @@
from utils.general import check_img_size, non_max_suppression, scale_coords, xyxy2xywh, set_logging
from utils.plots import plot_one_box
from utils.torch_utils import select_device
from vision_msgs.msg import Detection2D, Detection2DArray, ObjectHypothesisWithPose
from vision_msgs.msg import Detection2DArray, Detection2D
from ament_index_python.packages import get_package_share_directory

PACKAGE_NAME = "image_object_detection"


class ImageDetectObjectNode(Node):
def __init__(self):
super().__init__("image_object_detection_node")
super().__init__("image_object_detection_node")

# parametros
# Model parameters
self.declare_parameter("model.image_size", 640)
self.model_image_size = (
self.get_parameter("model.image_size").get_parameter_value().integer_value
Expand Down Expand Up @@ -129,60 +130,59 @@ def __init__(self):

self.bridge = cv_bridge.CvBridge()

self.image_sub = self.create_subscription(
msg_type=Image, topic="image", callback=self.image_callback, qos_profile=self.qos
# Get the list of camera topics from the config file
self.declare_parameter("camera_topics", [])
self.camera_topics = (
self.get_parameter("camera_topics").get_parameter_value().string_array_value
)

self.image_compressed_sub = self.create_subscription(
msg_type=CompressedImage,
topic="image/compressed",
callback=self.image_compressed_callback,
qos_profile=self.qos,
)

self.detection_publisher = self.create_publisher(
msg_type=Detection2DArray, topic="detections", qos_profile=self.qos
)

if self.enable_publish_debug_image:
if self.qos_policy == "best_effort":
self.get_logger().info("Using best effort qos policy for debug image publisher")
self.qos = QoSProfile(
reliability=QoSReliabilityPolicy.BEST_EFFORT,
history=QoSHistoryPolicy.KEEP_LAST,
depth=1,
)
else:
self.get_logger().info("Using reliable qos policy for debug image publisher")
self.qos = QoSProfile(
reliability=QoSReliabilityPolicy.RELIABLE,
history=QoSHistoryPolicy.KEEP_LAST,
depth=1,
self.get_logger().info(f"Subscribed to topics: {self.camera_topics}")

# Initialize subscribers and publishers for each camera topic
self.subscribers = []
self.detection_publishers = {}
self.debug_image_publishers = {}

for topic in self.camera_topics:
# Create a subscriber for each camera topic
self.subscribers.append(
self.create_subscription(
Image,
topic,
callback=self.image_callback_factory(topic),
qos_profile=self.qos,
)
)

self.debug_image_publisher = self.create_publisher(
msg_type=Image, topic="debug_image", qos_profile=self.qos
# Create a detection publisher for each camera
detection_topic = f"{topic}/detections"
self.detection_publishers[topic] = self.create_publisher(
Detection2DArray, detection_topic, self.qos
)

# Create a debug image publisher for each camera (if enabled)
if self.enable_publish_debug_image:
debug_image_topic = f"{topic}/debug_image"
self.debug_image_publishers[topic] = self.create_publisher(
Image, debug_image_topic, self.qos
)

self.initialize_model()

def initialize_model(self):
with torch.no_grad():
# Initialize
set_logging()
self.device = select_device(self.device)
self.half = self.device.type != "cpu"

# Load model
self.model = attempt_load(
self.model_weights_file, map_location=self.device
) # load FP32 model
)
self.stride = int(self.model.stride.max())

self.imgsz = check_img_size(self.model_image_size, s=self.stride)

if self.half:
self.model.half() # to FP16
self.model.half()

cudnn.benchmark = True

Expand Down Expand Up @@ -215,17 +215,13 @@ def accomodate_image_to_model(self, img0):
def image_compressed_callback(self, msg):
if not self.processing_enabled:
return

try:
self.cv_img = self.bridge.compressed_imgmsg_to_cv2(msg, self.debug_image_output_format)
img = self.accomodate_image_to_model(self.cv_img)

detections_msg, debugimg = self.predict(img, self.cv_img)
self.cv_img = self.bridge.compressed_imgmsg_to_cv2(msg, self.debug_image_output_format)
img = self.accomodate_image_to_model(self.cv_img)

self.detection_publisher.publish(detections_msg)
except CvBridgeError as e:
self.get_logger().error(f"Error converting image: {e}")
return
detections_msg, debugimg = self.predict(img, self.cv_img)

self.detection_publisher.publish(detections_msg)

if debugimg is not None:
self.publish_debug_image(debugimg)
Expand All @@ -234,25 +230,43 @@ def image_compressed_callback(self, msg):
cv2.imshow("Compressed Image", debugimg)
cv2.waitKey(1)

def image_callback(self, msg):
if not self.processing_enabled:
return
def image_callback_factory(self, topic):
def callback(msg):
try:
cv_img = self.bridge.imgmsg_to_cv2(msg, "bgr8")
self.image_queue[topic] = cv_img
except CvBridgeError as e:
self.get_logger().error(f"Error converting image from {topic}: {e}")

self.cv_img = self.bridge.imgmsg_to_cv2(msg, "bgr8")
img = self.accomodate_image_to_model(self.cv_img)
return callback

def image_callback_factory(self, topic):
def callback(msg):
if not self.processing_enabled:
return

detections_msg, debugimg = self.predict(img, self.cv_img)
try:
cv_img = self.bridge.imgmsg_to_cv2(msg, "bgr8")
img = self.accomodate_image_to_model(cv_img)

self.detection_publisher.publish(detections_msg)
detections_msg, debugimg = self.predict(img, cv_img)

if debugimg is not None:
self.publish_debug_image(debugimg)
# Publish detections for the current camera
self.detection_publishers[topic].publish(detections_msg)

if self.show_image:
cv2.imshow("Detection", debugimg)
cv2.waitKey(1)
# Publish debug image for the current camera (if enabled)
if self.enable_publish_debug_image and topic in self.debug_image_publishers:
self.publish_debug_image(debugimg, topic)

def publish_debug_image(self, debugimg):
if self.show_image:
cv2.imshow(f"Detection from {topic}", debugimg)
cv2.waitKey(1)
except CvBridgeError as e:
self.get_logger().error(f"Error converting image from {topic}: {e}")

return callback

def publish_debug_image(self, debugimg, topic):
if self.debug_image_output_format == "mono8":
debugimg = cv2.cvtColor(debugimg, cv2.COLOR_RGB2GRAY)
elif self.debug_image_output_format == "rgb8":
Expand All @@ -261,11 +275,12 @@ def publish_debug_image(self, debugimg):
debugimg = cv2.cvtColor(debugimg, cv2.COLOR_BGR2RGBA)
else:
self.get_logger().error(
"Unsupported debug image output format: {}".format(self.debug_image_output_format)
f"Unsupported debug image output format: {self.debug_image_output_format}"
)
return

self.debug_image_publisher.publish(
# Publish the debug image for the current camera
self.debug_image_publishers[topic].publish(
self.bridge.cv2_to_imgmsg(debugimg, self.debug_image_output_format)
)

Expand Down Expand Up @@ -294,7 +309,6 @@ def predict(self, model_img, original_image):
).round()

for *xyxy, conf, cls in reversed(det):
# clase clases deseadas
if self.names[int(cls)] in self.selected_detections:
detection2D_msg = Detection2D()
xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist()
Expand Down Expand Up @@ -322,7 +336,6 @@ def predict(self, model_img, original_image):

return detections_msg, original_image


def main(args=None):
print(args)
rclpy.init(args=sys.argv)
Expand Down

0 comments on commit 40d38e9

Please sign in to comment.