diff --git a/ros2param/ros2param/api/__init__.py b/ros2param/ros2param/api/__init__.py index 1cb102785..5e7daf52f 100644 --- a/ros2param/ros2param/api/__init__.py +++ b/ros2param/ros2param/api/__init__.py @@ -49,6 +49,10 @@ def load_parameter_file(*, node, node_name, parameter_file, use_wildcard): parameters = list(parameter_dict_from_yaml_file(parameter_file, use_wildcard).values()) rclpy.spin_until_future_complete(node, future) response = future.result() + if response is None: + raise RuntimeError('Exception while calling service of node ' + f'{node_name}: {future.exception()}') + assert len(response.results) == len(parameters), 'Not all parameters set' for i in range(0, len(response.results)): result = response.results[i] @@ -65,6 +69,26 @@ def load_parameter_file(*, node, node_name, parameter_file, use_wildcard): print(msg, file=sys.stderr) +def load_parameter_file_atomically(*, node, node_name, parameter_file, use_wildcard): + client = AsyncParameterClient(node, node_name) + ready = client.wait_for_services(timeout_sec=5.0) + if not ready: + raise RuntimeError('Wait for service timed out') + future = client.load_parameter_file_atomically(parameter_file, use_wildcard) + parameters = list(parameter_dict_from_yaml_file(parameter_file, use_wildcard).values()) + rclpy.spin_until_future_complete(node, future) + response = future.result() + if response is None: + raise RuntimeError('Exception while calling service of node ' + f'{node_name}: {future.exception()}') + + if response.result.successful: + msg = 'Set parameters {} successful'.format(' '.join([i.name for i in parameters])) + if response.result.reason: + msg += ': ' + response.result.reason + print(msg) + + def call_describe_parameters(*, node, node_name, parameter_names=None): client = AsyncParameterClient(node, node_name) ready = client.wait_for_services(timeout_sec=5.0) @@ -93,6 +117,18 @@ def call_get_parameters(*, node, node_name, parameter_names): return response +def call_set_parameters_atomically(*, node, node_name, parameters): + client = AsyncParameterClient(node, node_name) + client.wait_for_services(timeout_sec=5.0) + future = client.set_parameters_atomically(parameters) + rclpy.spin_until_future_complete(node, future) + response = future.result() + if response is None: + raise RuntimeError('Exception while calling service of node ' + f'{node_name}: {future.exception()}') + return response + + def call_set_parameters(*, node, node_name, parameters): client = AsyncParameterClient(node, node_name) ready = client.wait_for_services(timeout_sec=5.0) diff --git a/ros2param/ros2param/verb/load.py b/ros2param/ros2param/verb/load.py index e260f81f6..56041aa28 100644 --- a/ros2param/ros2param/verb/load.py +++ b/ros2param/ros2param/verb/load.py @@ -19,6 +19,7 @@ from ros2node.api import get_node_names from ros2node.api import NodeNameCompleter from ros2param.api import load_parameter_file +from ros2param.api import load_parameter_file_atomically from ros2param.verb import VerbExtension @@ -39,6 +40,9 @@ def add_arguments(self, parser, cli_name): # noqa: D102 parser.add_argument( '--no-use-wildcard', action='store_true', help="Do not load parameters in the '/**' namespace into the node") + parser.add_argument( + '--atomic', action='store_true', + help='Load parameters atomically') def main(self, *, args): # noqa: D102 with NodeStrategy(args) as node: @@ -50,5 +54,11 @@ def main(self, *, args): # noqa: D102 return 'Node not found' with DirectNode(args) as node: - load_parameter_file(node=node, node_name=node_name, parameter_file=args.parameter_file, - use_wildcard=not args.no_use_wildcard) + if args.atomic: + load_parameter_file_atomically(node=node, node_name=node_name, + parameter_file=args.parameter_file, + use_wildcard=not args.no_use_wildcard) + else: + load_parameter_file(node=node, node_name=node_name, + parameter_file=args.parameter_file, + use_wildcard=not args.no_use_wildcard) diff --git a/ros2param/ros2param/verb/set.py b/ros2param/ros2param/verb/set.py index 5e1eccf09..dbce9dd57 100644 --- a/ros2param/ros2param/verb/set.py +++ b/ros2param/ros2param/verb/set.py @@ -24,6 +24,7 @@ from ros2node.api import NodeNameCompleter from ros2param.api import call_set_parameters +from ros2param.api import call_set_parameters_atomically from ros2param.api import ParameterNameCompleter from ros2param.verb import VerbExtension @@ -47,6 +48,10 @@ def add_arguments(self, parser, cli_name): # noqa: D102 '--include-hidden-nodes', action='store_true', help='Consider hidden nodes as well') + parser.add_argument( + '--atomic', action='store_true', + help='Set parameters atomically') + def build_parameters(self, params): parameters = [] if len(params) % 2: @@ -72,9 +77,14 @@ def main(self, *, args): # noqa: D102 with DirectNode(args) as node: parameters = self.build_parameters(args.parameters) - response = call_set_parameters( - node=node, node_name=args.node_name, parameters=parameters) - results = response.results + if args.atomic: + response = call_set_parameters_atomically(node=node, node_name=args.node_name, + parameters=parameters) + results = [response.result] + else: + response = call_set_parameters(node=node, node_name=args.node_name, + parameters=parameters) + results = response.results for result in results: if result.successful: diff --git a/ros2param/test/test_verb_load.py b/ros2param/test/test_verb_load.py index 906a0e377..bbbeaa8fc 100644 --- a/ros2param/test/test_verb_load.py +++ b/ros2param/test/test_verb_load.py @@ -294,6 +294,31 @@ def test_verb_load(self): strict=True ) + def test_verb_load_atomic(self): + with tempfile.TemporaryDirectory() as tmpdir: + filepath = self._write_param_file(tmpdir, 'params.yaml') + with self.launch_param_load_command( + arguments=[f'{TEST_NAMESPACE}/{TEST_NODE}', filepath, '--atomic'] + ) as param_load_command: + assert param_load_command.wait_for_shutdown(timeout=TEST_TIMEOUT) + assert param_load_command.exit_code == launch_testing.asserts.EXIT_OK + assert launch_testing.tools.expect_output( + expected_lines=[''], + text=param_load_command.output, + strict=True + ) + # Dump with ros2 param dump and compare that output matches input file + with self.launch_param_dump_command( + arguments=[f'{TEST_NAMESPACE}/{TEST_NODE}'] + ) as param_dump_command: + assert param_dump_command.wait_for_shutdown(timeout=TEST_TIMEOUT) + assert param_dump_command.exit_code == launch_testing.asserts.EXIT_OK + assert launch_testing.tools.expect_output( + expected_text=INPUT_PARAMETER_FILE + '\n', + text=param_dump_command.output, + strict=True + ) + def test_verb_load_wildcard(self): with tempfile.TemporaryDirectory() as tmpdir: # Try param file with only wildcard