From 33f47332eb4a6dd03c1aeaa84e417141c772f579 Mon Sep 17 00:00:00 2001
From: Fredrik Wrede
Date: Mon, 11 Nov 2024 16:23:02 +0000
Subject: [PATCH 1/4] fix

---
 fedn/network/clients/grpc_handler.py | 32 ++++++++++++++++++++++++----
 1 file changed, 28 insertions(+), 4 deletions(-)

diff --git a/fedn/network/clients/grpc_handler.py b/fedn/network/clients/grpc_handler.py
index 4b7d9874c..440657847 100644
--- a/fedn/network/clients/grpc_handler.py
+++ b/fedn/network/clients/grpc_handler.py
@@ -65,16 +65,25 @@ def __init__(self, host: str, port: int, name: str, token: str, combiner_name: s
             ("client", name),
             ("grpc-server", combiner_name),
         ]
+        self.host = host
+        self.port = port
+        self.token = token
 
-        if port == 443:
-            self._init_secure_channel(host, port, token)
-        else:
-            self._init_insecure_channel(host, port)
+        self._init_channel(host, port, token)
+
+        self._init_stubs()
 
+    def _init_stubs(self):
         self.connectorStub = rpc.ConnectorStub(self.channel)
         self.combinerStub = rpc.CombinerStub(self.channel)
         self.modelStub = rpc.ModelServiceStub(self.channel)
 
+    def _init_channel(self, host: str, port: int, token: str):
+        if port == 443:
+            self._init_secure_channel(host, port, token)
+        else:
+            self._init_insecure_channel(host, port)
+
     def _init_secure_channel(self, host: str, port: int, token: str):
         url = f"{host}:{port}"
         logger.info(f"Connecting (GRPC) to {url}")
@@ -97,6 +106,7 @@ def _init_secure_channel(self, host: str, port: int, token: str):
         )
 
     def _init_insecure_channel(self, host: str, port: int):
+        self.secure = False
         url = f"{host}:{port}"
         logger.info(f"Connecting (GRPC) to {url}")
         self.channel = grpc.insecure_channel(
@@ -116,6 +126,7 @@ def heartbeat(self, client_name: str, client_id: str):
             logger.info("Sending heartbeat to combiner")
             response = self.connectorStub.SendHeartbeat(heartbeat, metadata=self.metadata)
         except grpc.RpcError as e:
+            logger.error(f"GRPC (SendHeartbeat): An error occurred: {e}")
             raise e
         except Exception as e:
             logger.error(f"GRPC (SendHeartbeat): An error occurred: {e}")
@@ -130,6 +141,8 @@ def send_heartbeats(self, client_name: str, client_id: str, update_frequency: fl
                 response = self.heartbeat(client_name, client_id)
             except grpc.RpcError as e:
                 return self._handle_grpc_error(e, "SendHeartbeat", lambda: self.send_heartbeats(client_name, client_id, update_frequency))
+            except Exception as e:
+                return self._handle_unknown_error(e, "SendHeartbeat", lambda: self.send_heartbeats(client_name, client_id, update_frequency))
             if isinstance(response, fedn.Response):
                 logger.info("Heartbeat successful.")
             else:
@@ -166,10 +179,12 @@ def listen_to_task_stream(self, client_name: str, client_id: str, callback: Call
                 callback(request)
 
         except grpc.RpcError as e:
+            logger.error(f"GRPC (TaskStream): An error occurred: {e}")
             return self._handle_grpc_error(e, "TaskStream", lambda: self.listen_to_task_stream(client_name, client_id, callback))
         except Exception as e:
             logger.error(f"GRPC (TaskStream): An error occurred: {e}")
             self._disconnect()
+            self._handle_unknown_error(e, "TaskStream", lambda: self.listen_to_task_stream(client_name, client_id, callback))
 
     def send_status(self, msg: str, log_level=fedn.Status.INFO, type=None, request=None, sesssion_id: str = None, sender_name: str = None):
         """Send status message.
@@ -406,6 +421,15 @@ def _handle_grpc_error(self, e, method_name: str, sender_function: Callable):
         self._disconnect()
         logger.error(f"GRPC ({method_name}): An error occurred: {e}")
 
+    def _handle_unknown_error(self, e, method_name: str, sender_function: Callable):
+        # Try to reconnect
+        logger.warning(f"GRPC ({method_name}): An unknown error occurred: {e}.")
+        logger.warning(f"GRPC ({method_name}): Reconnecting to channel.")
+        # recreate the channel
+        self._init_channel(self.host, self.port, self.token)
+        self._init_stubs()
+        return sender_function()
+
     def _disconnect(self):
         """Disconnect from the combiner."""
         self.channel.close()

From 23fc5653a2607aa1b97153c766695398a1ea0b89 Mon Sep 17 00:00:00 2001
From: Fredrik Wrede
Date: Mon, 11 Nov 2024 16:23:12 +0000
Subject: [PATCH 2/4] add chaos tests

---
 .ci/tests/chaos_test.py | 210 ++++++++++++++++++++++++++++++++++++++++
 1 file changed, 210 insertions(+)
 create mode 100644 .ci/tests/chaos_test.py

diff --git a/.ci/tests/chaos_test.py b/.ci/tests/chaos_test.py
new file mode 100644
index 000000000..09841474f
--- /dev/null
+++ b/.ci/tests/chaos_test.py
@@ -0,0 +1,210 @@
+from toxiproxy import Toxiproxy
+import unittest
+import grpc
+import time
+from fedn.network.clients.grpc_handler import GrpcHandler
+import fedn.network.grpc.fedn_pb2 as fedn
+
+
+class TestGRPCWithToxiproxy(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        client_name = 'test-client'
+        client_id = 'test-client-id'
+        host = 'localhost'
+        port_proxy = 12081
+        port_server = 12080
+        token = ""
+        combiner_name = 'combiner'
+
+        cls.toxiproxy = Toxiproxy()
+        if cls.toxiproxy.proxies():
+            cls.toxiproxy.destroy_all()
+
+    @classmethod
+    def tearDownClass(cls):
+        # Close the proxy and gRPC channel when done
+        cls.toxiproxy.destroy_all()
+
+    @unittest.skip("Not implemented")
+    def test_normal_heartbeat(self):
+        # Test the heartbeat without any toxic
+        client_name = 'test-client'
+        client_id = 'test-client-id'
+        # Connect directly to the combiner port (no proxy in between)
+        grpc_handler = GrpcHandler(host='localhost', port=12080, name=client_name, token='', combiner_name='combiner')
+        try:
+            response = grpc_handler.heartbeat(client_name, client_id)
+            self.assertIsInstance(response, fedn.Response)
+        except grpc.RpcError as e:
+            self.fail(f'gRPC error: {e.code()} {e.details()}')
+        finally:
+            grpc_handler.channel.close()
+
+    @unittest.skip("Not implemented")
+    def test_latency_2s_toxic_heartbeat(self):
+        # Add latency of 2000ms
+        client_name = 'test-client'
+        client_id = 'test-client-id'
+
+        proxy = self.toxiproxy.create(name='test_latency_toxic_heartbeat', listen='localhost:12082', upstream='localhost:12080')
+        grpc_handler = GrpcHandler(host='localhost', port=12082, name=client_name, token='', combiner_name='combiner')
+        proxy.add_toxic(name='latency', type='latency', attributes={'latency': 2000})
+
+        start_time = time.time()
+        try:
+            response = grpc_handler.heartbeat(client_name, client_id)
+        finally:
+            grpc_handler.channel.close()
+            proxy.destroy()
+        end_time = time.time()
+
+        # Check that the latency delay is present
+        self.assertGreaterEqual(end_time - start_time, 2)  # Expect at least a 2 second delay
+        self.assertIsInstance(response, fedn.Response)
+
+    def test_latency_long_toxic_heartbeat(self):
+        """Test gRPC request with a simulated latency of 20s.
+        Should timeout based on KEEPALIVE_TIMEOUT_MS (default set to 20000)."""
+        client_name = 'test-client'
+        client_id = 'test-client-id'
+        latency = 20  # 20s latency
+
+        proxy = self.toxiproxy.create(name='test_latency_toxic_heartbeat', listen='localhost:12083', upstream='localhost:12080')
+        grpc_handler = GrpcHandler(host='localhost', port=12083, name=client_name, token='', combiner_name='combiner')
+        proxy.add_toxic(name='latency', type='latency', attributes={'latency': latency * 1000})
+
+        start_time = time.time()
+        try:
+            response = grpc_handler.heartbeat(client_name, client_id)
+        except grpc.RpcError as e:
+            response = e
+        finally:
+            grpc_handler.channel.close()
+            proxy.destroy()
+        end_time = time.time()
+
+        # Check that the latency delay is present and that the call failed as expected
+        self.assertGreaterEqual(end_time - start_time, latency)  # Expect at least `latency` seconds of delay
+        self.assertIsInstance(response, grpc.RpcError)
+        self.assertEqual(response.code(), grpc.StatusCode.UNAVAILABLE)
+        self.assertEqual(response.details(), 'failed to connect to all addresses; last error: UNKNOWN: ipv4:127.0.0.1:12083: connection attempt timed out before receiving SETTINGS frame')
+
+    def test_close_channel(self):
+        """
+        Test closing the gRPC channel and trying to send a heartbeat.
+        Expect a ValueError to be raised.
+        """
+        client_name = 'test-client'
+        client_id = 'test-client-id'
+
+        grpc_handler = GrpcHandler(host='localhost', port=12080, name=client_name, token='', combiner_name='combiner')
+
+        # Close the channel
+        grpc_handler._disconnect()
+
+        # Try to send heartbeat
+        with self.assertRaises(ValueError) as context:
+            grpc_handler.heartbeat(client_name, client_id)
+        self.assertEqual(str(context.exception), 'Cannot invoke RPC on closed channel!')
+
+    @unittest.skip("Not implemented")
+    def test_disconnect_toxic_heartbeat(self):
+        """Test gRPC request with a simulated disconnection."""
+        # Add a timeout toxic to simulate network disconnection
+        client_name = 'test-client'
+        client_id = 'test-client-id'
+
+        proxy = self.toxiproxy.create(name='test_disconnect_toxic_heartbeat', listen='localhost:12084', upstream='localhost:12080')
+        grpc_handler = GrpcHandler(host='localhost', port=12084, name=client_name, token='', combiner_name='combiner')
+        proxy.add_toxic(name='timeout', type='timeout', attributes={'timeout': 1000})
+
+        try:
+            response = grpc_handler.heartbeat(client_name, client_id)
+        except grpc.RpcError as e:
+            response = e
+        finally:
+            grpc_handler.channel.close()
+            proxy.destroy()
+
+        # Assert that the response is a gRPC error with status code UNAVAILABLE
+        self.assertEqual(response.code(), grpc.StatusCode.UNAVAILABLE)
+        self.assertEqual(response.details(), 'failed to connect to all addresses; last error: UNAVAILABLE: ipv4:127.0.0.1:12084: Socket closed')
+
+    @unittest.skip("Not implemented")
+    def test_timeout_toxic_heartbeat(self):
+        """Stops all data from getting through and closes the connection after the timeout.
+        If timeout is 0, the connection won't close, and data will be delayed until the toxic is removed.
+        """
+        # Add a timeout toxic to simulate network disconnection
+        client_name = 'test-client'
+        client_id = 'test-client-id'
+
+        proxy = self.toxiproxy.create(name='test_timeout_toxic_heartbeat', listen='localhost:12085', upstream='localhost:12080')
+        grpc_handler = GrpcHandler(host='localhost', port=12085, name=client_name, token='', combiner_name='combiner')
+        proxy.add_toxic(name='timeout', type='timeout', attributes={'timeout': 0})
+
+        try:
+            response = grpc_handler.heartbeat(client_name, client_id)
+        except grpc.RpcError as e:
+            response = e
+        finally:
+            grpc_handler.channel.close()
+            proxy.destroy()
+
+        # Assert that the response is a gRPC error with status code UNAVAILABLE
+        self.assertEqual(response.code(), grpc.StatusCode.UNAVAILABLE)
+        self.assertEqual(response.details(), 'failed to connect to all addresses; last error: UNKNOWN: ipv4:127.0.0.1:12085: connection attempt timed out before receiving SETTINGS frame')
+
+    @unittest.skip("Not implemented")
+    def test_rate_limit_toxic_heartbeat(self):
+        # Purpose: Limits the number of connections that can be established within a certain time frame.
+        # Toxic: rate_limit
+        # Use Case: Useful for testing how the client behaves under strict rate limits. For example, in Federated Learning,
+        # this could simulate constraints in networks with multiple clients trying to access the server.
+
+        # Add a rate limit toxic to the proxy
+        self.proxy.add_rate_limit(rate=1000)
+
+    @unittest.skip("Not implemented")
+    def test_bandwidth_toxic_heartbeat(self):
+        # Purpose: Limits the bandwidth of the connection.
+        # Toxic: bandwidth
+        # Use Case: Useful for testing how the client behaves under limited bandwidth. For example, in Federated Learning,
+        # this could simulate a slow network connection between the client and the server.
+
+        # Add a bandwidth toxic to the proxy
+        self.proxy.add_bandwidth(rate=1000)  # 1 KB/s
+
+    @unittest.skip("Not implemented")
+    def test_connection_reset(self):
+        # Purpose: Immediately resets the connection, simulating an abrupt network drop.
+        # Toxic: reset_peer
+        # Use Case: This is helpful for testing error-handling logic on sudden network failures,
+        # ensuring the client retries appropriately or fails gracefully.
+
+        # Add a connection_reset toxic to the proxy
+        self.proxy.add_reset()
+
+    @unittest.skip("Not implemented")
+    def test_slow_close(self):
+        # Purpose: Simulates a slow closing of the connection.
+        # Toxic: slow_close
+        # Use Case: Useful for testing how the client behaves when the server closes the connection slowly.
+        # This can help ensure that the client handles slow network disconnections gracefully.
+
+        # Add a slow_close toxic to the proxy
+        self.proxy.add_slow_close(delay=1000)  # Delay closing the connection by 1 second
+
+    @unittest.skip("Not implemented")
+    def test_slicer(self):
+        # Purpose: Slices the data into smaller chunks.
+        # Toxic: slicer
+        # Use Case: Useful for testing how the client handles fragmented data.
+        # This can help ensure that the client can reassemble the data correctly and handle partial data gracefully.
+
+        # Add a slicer toxic to the proxy
+        self.proxy.add_slicer(average_size=1000, size_variation=100)  # Slice data into chunks of 1 KB with 100 bytes variation
\ No newline at end of file

From c047d388615e8dc565f5b4573fdb5895b1e6fd5a Mon Sep 17 00:00:00 2001
From: Fredrik Wrede
Date: Mon, 11 Nov 2024 16:26:28 +0000
Subject: [PATCH 3/4] remove self.secure

---
 fedn/network/clients/grpc_handler.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/fedn/network/clients/grpc_handler.py b/fedn/network/clients/grpc_handler.py
index 440657847..37a0d4a46 100644
--- a/fedn/network/clients/grpc_handler.py
+++ b/fedn/network/clients/grpc_handler.py
@@ -106,7 +106,6 @@ def _init_secure_channel(self, host: str, port: int, token: str):
         )
 
     def _init_insecure_channel(self, host: str, port: int):
-        self.secure = False
         url = f"{host}:{port}"
         logger.info(f"Connecting (GRPC) to {url}")
         self.channel = grpc.insecure_channel(

From 3d49eff76c2d09e0b0d3ef29aa9ff1fb3b9e396a Mon Sep 17 00:00:00 2001
From: Fredrik Wrede
Date: Tue, 12 Nov 2024 15:06:05 +0000
Subject: [PATCH 4/4] handle unknown error and disconnect

---
 fedn/network/clients/grpc_handler.py | 48 +++++++++++++++++-----------
 1 file changed, 29 insertions(+), 19 deletions(-)

diff --git a/fedn/network/clients/grpc_handler.py b/fedn/network/clients/grpc_handler.py
index 37a0d4a46..0aeedf344 100644
--- a/fedn/network/clients/grpc_handler.py
+++ b/fedn/network/clients/grpc_handler.py
@@ -129,7 +129,6 @@ def heartbeat(self, client_name: str, client_id: str):
             raise e
         except Exception as e:
             logger.error(f"GRPC (SendHeartbeat): An error occurred: {e}")
-            self._disconnect()
             raise e
         return response
 
@@ -182,7 +181,6 @@ def listen_to_task_stream(self, client_name: str, client_id: str, callback: Call
             return self._handle_grpc_error(e, "TaskStream", lambda: self.listen_to_task_stream(client_name, client_id, callback))
         except Exception as e:
             logger.error(f"GRPC (TaskStream): An error occurred: {e}")
-            self._disconnect()
             self._handle_unknown_error(e, "TaskStream", lambda: self.listen_to_task_stream(client_name, client_id, callback))
 
     def send_status(self, msg: str, log_level=fedn.Status.INFO, type=None, request=None, sesssion_id: str = None, sender_name: str = None):
         """Send status message.
@@ -218,7 +216,7 @@
             return self._handle_grpc_error(e, "SendStatus", lambda: self.send_status(msg, log_level, type, request, sesssion_id, sender_name))
         except Exception as e:
             logger.error(f"GRPC (SendStatus): An error occurred: {e}")
-            self._disconnect()
+            self._handle_unknown_error(e, "SendStatus", lambda: self.send_status(msg, log_level, type, request, sesssion_id, sender_name))
 
     def get_model_from_combiner(self, id: str, client_id: str, timeout: int = 20) -> BytesIO:
         """Fetch a model from the assigned combiner.
@@ -255,8 +253,7 @@ def get_model_from_combiner(self, id: str, client_id: str, timeout: int = 20) ->
             return self._handle_grpc_error(e, "Download", lambda: self.get_model_from_combiner(id, client_id, timeout))
         except Exception as e:
             logger.error(f"GRPC (Download): An error occurred: {e}")
-            self._disconnect()
-
+            self._handle_unknown_error(e, "Download", lambda: self.get_model_from_combiner(id, client_id, timeout))
         return data
 
     def send_model_to_combiner(self, model: BytesIO, id: str):
@@ -287,8 +284,7 @@ def send_model_to_combiner(self, model: BytesIO, id: str):
             return self._handle_grpc_error(e, "Upload", lambda: self.send_model_to_combiner(model, id))
         except Exception as e:
             logger.error(f"GRPC (Upload): An error occurred: {e}")
-            self._disconnect()
-
+            self._handle_unknown_error(e, "Upload", lambda: self.send_model_to_combiner(model, id))
         return result
 
     def create_update_message(
@@ -367,8 +363,7 @@ def send_model_update(self, update: fedn.ModelUpdate):
             return self._handle_grpc_error(e, "SendModelUpdate", lambda: self.send_model_update(update))
         except Exception as e:
             logger.error(f"GRPC (SendModelUpdate): An error occurred: {e}")
-            self._disconnect()
-
+            self._handle_unknown_error(e, "SendModelUpdate", lambda: self.send_model_update(update))
         return True
 
     def send_model_validation(self, validation: fedn.ModelValidation) -> bool:
@@ -383,8 +378,7 @@ def send_model_validation(self, validation: fedn.ModelValidation) -> bool:
             )
         except Exception as e:
             logger.error(f"GRPC (SendModelValidation): An error occurred: {e}")
-            self._disconnect()
-
+            self._handle_unknown_error(e, "SendModelValidation", lambda: self.send_model_validation(validation))
         return True
 
     def send_model_prediction(self, prediction: fedn.ModelPrediction) -> bool:
@@ -399,8 +393,7 @@ def send_model_prediction(self, prediction: fedn.ModelPrediction) -> bool:
             )
         except Exception as e:
             logger.error(f"GRPC (SendModelPrediction): An error occurred: {e}")
-            self._disconnect()
-
+            self._handle_unknown_error(e, "SendModelPrediction", lambda: self.send_model_prediction(prediction))
         return True
 
     def _handle_grpc_error(self, e, method_name: str, sender_function: Callable):
@@ -413,21 +406,38 @@ def _handle_grpc_error(self, e, method_name: str, sender_function: Callable):
             logger.warning(f"GRPC ({method_name}): connection cancelled. Retrying in 5 seconds.")
             time.sleep(5)
             return sender_function()
-        if status_code == grpc.StatusCode.UNAUTHENTICATED:
+        elif status_code == grpc.StatusCode.UNAUTHENTICATED:
             details = e.details()
             if details == "Token expired":
                 logger.warning(f"GRPC ({method_name}): Token expired.")
+                raise e
+        elif status_code == grpc.StatusCode.UNKNOWN:
+            logger.warning(f"GRPC ({method_name}): An unknown error occurred: {e}.")
+            details = e.details()
+            if details == "Stream removed":
+                logger.warning(f"GRPC ({method_name}): Stream removed. Reconnecting")
+                self._disconnect()
+                self._init_channel(self.host, self.port, self.token)
+                self._init_stubs()
+                return sender_function()
+            raise e
         self._disconnect()
         logger.error(f"GRPC ({method_name}): An error occurred: {e}")
+        raise e
 
     def _handle_unknown_error(self, e, method_name: str, sender_function: Callable):
         # Try to reconnect
         logger.warning(f"GRPC ({method_name}): An unknown error occurred: {e}.")
-        logger.warning(f"GRPC ({method_name}): Reconnecting to channel.")
-        # recreate the channel
-        self._init_channel(self.host, self.port, self.token)
-        self._init_stubs()
-        return sender_function()
+        if isinstance(e, ValueError):
+            # ValueError is raised when the channel is closed
+            self._disconnect()
+            logger.warning(f"GRPC ({method_name}): Reconnecting to channel.")
+            # recreate the channel
+            self._init_channel(self.host, self.port, self.token)
+            self._init_stubs()
+            return sender_function()
+        else:
+            raise e
 
     def _disconnect(self):
        """Disconnect from the combiner."""
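
For reference, a minimal sketch of driving the fault-injection setup (and the
reconnect path these patches add) by hand, outside the test suite. Assumptions,
not part of the patches: a toxiproxy-server running on its default port, a
combiner reachable on localhost:12080, and the toxiproxy Python client used by
the tests; the proxy name, listen port, and client names below are illustrative.

    import grpc
    from toxiproxy import Toxiproxy

    from fedn.network.clients.grpc_handler import GrpcHandler

    toxiproxy = Toxiproxy()
    # Route client traffic through a proxy so faults can be injected.
    proxy = toxiproxy.create(name="combiner_proxy", listen="localhost:12090", upstream="localhost:12080")
    handler = GrpcHandler(host="localhost", port=12090, name="demo-client", token="", combiner_name="combiner")

    try:
        # Baseline heartbeat through the (toxic-free) proxy.
        handler.heartbeat("demo-client", "demo-client-id")

        # Inject 2s of latency and observe the slowed round trip; longer delays
        # should surface as grpc.RpcError, exercising the retry/reconnect logic.
        proxy.add_toxic(name="latency", type="latency", attributes={"latency": 2000})
        handler.heartbeat("demo-client", "demo-client-id")
    except grpc.RpcError as e:
        print(f"RPC failed: {e.code()} {e.details()}")
    finally:
        handler.channel.close()
        proxy.destroy()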