diff --git a/rclpy/rclpy/wait_for_message.py b/rclpy/rclpy/wait_for_message.py index 9eac995e1..a26eb7516 100644 --- a/rclpy/rclpy/wait_for_message.py +++ b/rclpy/rclpy/wait_for_message.py @@ -12,8 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Union + from rclpy.impl.implementation_singleton import rclpy_implementation as _rclpy from rclpy.node import Node +from rclpy.qos import QoSProfile from rclpy.signals import SignalHandlerGuardCondition from rclpy.utilities import timeout_sec_to_nsec @@ -22,6 +25,8 @@ def wait_for_message( msg_type, node: 'Node', topic: str, + *, + qos_profile: Union[QoSProfile, int] = 1, time_to_wait=-1 ): """ @@ -30,6 +35,7 @@ def wait_for_message( :param msg_type: message type :param node: node to initialize the subscription on :param topic: topic name to wait for message + :param qos_profile: QoS profile to use for the subscription :param time_to_wait: seconds to wait before returning :returns: (True, msg) if a message was successfully received, (False, None) if message could not be obtained or shutdown was triggered asynchronously on the context. @@ -38,7 +44,7 @@ def wait_for_message( wait_set = _rclpy.WaitSet(1, 1, 0, 0, 0, 0, context.handle) wait_set.clear_entities() - sub = node.create_subscription(msg_type, topic, lambda _: None, 1) + sub = node.create_subscription(msg_type, topic, lambda _: None, qos_profile=qos_profile) try: wait_set.add_subscription(sub.handle) sigint_gc = SignalHandlerGuardCondition(context=context) diff --git a/rclpy/test/test_wait_for_message.py b/rclpy/test/test_wait_for_message.py index 8929f238c..1aa8175fd 100644 --- a/rclpy/test/test_wait_for_message.py +++ b/rclpy/test/test_wait_for_message.py @@ -17,6 +17,7 @@ import unittest import rclpy +from rclpy.qos import QoSProfile from rclpy.wait_for_message import wait_for_message from test_msgs.msg import BasicTypes @@ -51,13 +52,22 @@ def _publish_message(self): def test_wait_for_message(self): t = threading.Thread(target=self._publish_message) t.start() - ret, msg = wait_for_message(BasicTypes, self.node, TOPIC_NAME) + ret, msg = wait_for_message(BasicTypes, self.node, TOPIC_NAME, qos_profile=1) + self.assertTrue(ret) + self.assertEqual(msg.int32_value, MSG_DATA) + t.join() + + def test_wait_for_message_qos(self): + t = threading.Thread(target=self._publish_message) + t.start() + ret, msg = wait_for_message( + BasicTypes, self.node, TOPIC_NAME, qos_profile=QoSProfile(depth=1)) self.assertTrue(ret) self.assertEqual(msg.int32_value, MSG_DATA) t.join() def test_wait_for_message_timeout(self): - ret, _ = wait_for_message(BasicTypes, self.node, TOPIC_NAME, 1) + ret, _ = wait_for_message(BasicTypes, self.node, TOPIC_NAME, time_to_wait=1) self.assertFalse(ret)