Bugfix/SK-1193 | Handle unknown error, reconnect channel #743

Merged · 4 commits · Nov 19, 2024
210 changes: 210 additions & 0 deletions .ci/tests/chaos_test.py
@@ -0,0 +1,210 @@
import time
import unittest

import grpc
from toxiproxy import Toxiproxy

import fedn.network.grpc.fedn_pb2 as fedn
from fedn.network.clients.grpc_handler import GrpcHandler


class TestGRPCWithToxiproxy(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        # Start from a clean Toxiproxy state so stale proxies from earlier
        # runs cannot interfere with the ports used below.
        cls.toxiproxy = Toxiproxy()
        if cls.toxiproxy.proxies():
            cls.toxiproxy.destroy_all()

@classmethod
def tearDownClass(cls):
# Close the proxy and gRPC channel when done
cls.toxiproxy.destroy_all()
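
    # NOTE (editorial assumption, not part of this PR): these tests expect a
    # Toxiproxy server reachable on its default API port (localhost:8474) and
    # a FEDn combiner gRPC endpoint on localhost:12080; the proxies created
    # below listen on ports 12082-12085.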

@unittest.skip("Not implemented")
def test_normal_heartbeat(self):
# Test the heartbeat without any toxic
client_name = 'test-client'
client_id = 'test-client-id'
        # Connect directly to the combiner's gRPC port; no proxy or toxics in the path
        grpc_handler = GrpcHandler(host='localhost', port=12080, name=client_name, token='', combiner_name='combiner')
try:
response = grpc_handler.heartbeat(client_name, client_id)
self.assertIsInstance(response, fedn.Response)
except grpc.RpcError as e:
self.fail(f'gRPC error: {e.code()} {e.details()}')
finally:
grpc_handler.channel.close()

@unittest.skip("Not implemented")
def test_latency_2s_toxic_heartbeat(self):
        # Add 2000 ms of latency via Toxiproxy
client_name = 'test-client'
client_id = 'test-client-id'

proxy = self.toxiproxy.create(name='test_latency_toxic_heartbeat', listen='localhost:12082', upstream='localhost:12080')
grpc_handler = GrpcHandler(host='localhost', port=12082, name=client_name, token='', combiner_name='combiner')
proxy.add_toxic(name='latency', type='latency', attributes={'latency': 2000})

start_time = time.time()
try:
response = grpc_handler.heartbeat(client_name, client_id)
finally:
grpc_handler.channel.close()
proxy.destroy()
end_time = time.time()

        # Check that the latency delay was applied
        self.assertGreaterEqual(end_time - start_time, 2)  # expect at least 2 seconds of delay
self.assertIsInstance(response, fedn.Response)

def test_latency_long_toxic_heartbeat(self):
"""Test gRPC request with a simulated latency of 25s. Should timeout based on KEEPALIVE_TIMEOUT_MS (default set to 20000)."""
client_name = 'test-client'
client_id = 'test-client-id'
latency = 20 # 15s latency

proxy = self.toxiproxy.create(name='test_latency_toxic_heartbeat', listen='localhost:12083', upstream='localhost:12080')
grpc_handler = GrpcHandler(host='localhost', port=12083, name=client_name, token='', combiner_name='combiner')
proxy.add_toxic(name='latency', type='latency', attributes={'latency': latency * 1000})

start_time = time.time()
try:
response = grpc_handler.heartbeat(client_name, client_id)
except grpc.RpcError as e:
response = e
finally:
grpc_handler.channel.close()
proxy.destroy()
end_time = time.time()

        # Check that the latency delay was applied
        self.assertGreaterEqual(end_time - start_time, latency)  # expect at least `latency` seconds of delay
self.assertIsInstance(response, grpc.RpcError)
self.assertEqual(response.code(), grpc.StatusCode.UNAVAILABLE)
self.assertEqual(response.details(), 'failed to connect to all addresses; last error: UNKNOWN: ipv4:127.0.0.1:12083: connection attempt timed out before receiving SETTINGS frame')

def test_close_channel(self):
"""
Test closing the gRPC channel and trying to send a heartbeat.
Expect a ValueError to be raised.
"""

client_name = 'test-client'
client_id = 'test-client-id'

grpc_handler = GrpcHandler(host='localhost', port=12080, name=client_name, token='', combiner_name='combiner')

# Close the channel
grpc_handler._disconnect()

        # Try to send a heartbeat on the closed channel
        with self.assertRaises(ValueError) as context:
            grpc_handler.heartbeat(client_name, client_id)
self.assertEqual(str(context.exception), 'Cannot invoke RPC on closed channel!')


@unittest.skip("Not implemented")
def test_disconnect_toxic_heartbeat(self):
"""Test gRPC request with a simulated disconnection."""
# Add a timeout toxic to simulate network disconnection
client_name = 'test-client'
client_id = 'test-client-id'

proxy = self.toxiproxy.create(name='test_disconnect_toxic_heartbeat', listen='localhost:12084', upstream='localhost:12080')
grpc_handler = GrpcHandler(host='localhost', port=12084, name=client_name, token='', combiner_name='combiner')
proxy.add_toxic(name='timeout', type='timeout', attributes={'timeout': 1000})

try:
response = grpc_handler.heartbeat(client_name, client_id)
except grpc.RpcError as e:
response = e
finally:
grpc_handler.channel.close()
proxy.destroy()

# Assert that the response is a gRPC error with status code UNAVAILABLE
self.assertEqual(response.code(), grpc.StatusCode.UNAVAILABLE)
self.assertEqual(response.details(), 'failed to connect to all addresses; last error: UNAVAILABLE: ipv4:127.0.0.1:12084: Socket closed')

@unittest.skip("Not implemented")
def test_timeout_toxic_heartbeat(self):
"""Stops all data from getting through, and closes the connection after timeout. timeout is 0,
the connection won't close, and data will be delayed until the toxic is removed.
"""
# Add a timeout toxic to simulate network disconnection
client_name = 'test-client'
client_id = 'test-client-id'

proxy = self.toxiproxy.create(name='test_timeout_toxic_heartbeat', listen='localhost:12085', upstream='localhost:12080')
grpc_handler = GrpcHandler(host='localhost', port=12085, name=client_name, token='', combiner_name='combiner')
proxy.add_toxic(name='timeout', type='timeout', attributes={'timeout': 0})

try:
response = grpc_handler.heartbeat(client_name, client_id)
except grpc.RpcError as e:
response = e
finally:
grpc_handler.channel.close()
proxy.destroy()

# Assert that the response is a gRPC error with status code UNAVAILABLE
self.assertEqual(response.code(), grpc.StatusCode.UNAVAILABLE)
self.assertEqual(response.details(), 'failed to connect to all addresses; last error: UNKNOWN: ipv4:127.0.0.1:12085: connection attempt timed out before receiving SETTINGS frame')

@unittest.skip("Not implemented")
def test_rate_limit_toxic_heartbeat(self):
# Purpose: Limits the number of connections that can be established within a certain time frame.
# Toxic: rate_limit
# Use Case: Useful for testing how the client behaves under strict rate limits. For example, in Federated Learning,
# this could simulate constraints in networks with multiple clients trying to access the server.

# Add a rate limit toxic to the proxy
self.proxy.add_rate_limit(rate=1000)
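
    # Hedged sketch (not part of this PR): Toxiproxy has no toxic literally
    # named 'rate_limit'; the closest built-ins are 'bandwidth' (KB/s throttle)
    # and 'limit_data' (close the connection after N bytes). The sketch below
    # uses 'limit_data'; port 12086 and the method name are illustrative
    # assumptions, following the pattern of the implemented tests above.
    @unittest.skip("Illustrative sketch only")
    def test_limit_data_toxic_heartbeat_sketch(self):
        proxy = self.toxiproxy.create(name='test_limit_data_sketch', listen='localhost:12086', upstream='localhost:12080')
        grpc_handler = GrpcHandler(host='localhost', port=12086, name='test-client', token='', combiner_name='combiner')
        proxy.add_toxic(name='limit_data', type='limit_data', attributes={'bytes': 64})  # cut the stream after 64 bytes
        try:
            response = grpc_handler.heartbeat('test-client', 'test-client-id')
        except grpc.RpcError as e:
            response = e
        finally:
            grpc_handler.channel.close()
            proxy.destroy()
        # Assumption: the byte budget is exhausted before the response completes,
        # so the RPC should surface as an error rather than a fedn.Response
        self.assertIsInstance(response, grpc.RpcError)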

@unittest.skip("Not implemented")
def test_bandwidth_toxic_heartbeat(self):
# Purpose: Limits the bandwidth of the connection.
# Toxic: bandwidth
# Use Case: Useful for testing how the client behaves under limited bandwidth. For example, in Federated Learning,
# this could simulate a slow network connection between the client and the server.

# Add a bandwidth toxic to the proxy
self.proxy.add_bandwidth(rate=1000) # 1 KB/s
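
    # Hedged sketch (not part of this PR): the same scenario using the standard
    # 'bandwidth' toxic, whose 'rate' attribute is in KB/s. Port 12087 and the
    # method name are illustrative assumptions.
    @unittest.skip("Illustrative sketch only")
    def test_bandwidth_toxic_heartbeat_sketch(self):
        proxy = self.toxiproxy.create(name='test_bandwidth_sketch', listen='localhost:12087', upstream='localhost:12080')
        grpc_handler = GrpcHandler(host='localhost', port=12087, name='test-client', token='', combiner_name='combiner')
        proxy.add_toxic(name='bandwidth', type='bandwidth', attributes={'rate': 1})  # throttle to 1 KB/s
        try:
            response = grpc_handler.heartbeat('test-client', 'test-client-id')
            # A small heartbeat payload should still succeed, just slowly
            self.assertIsInstance(response, fedn.Response)
        finally:
            grpc_handler.channel.close()
            proxy.destroy()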

@unittest.skip("Not implemented")
def test_connection_reset(self):
# Purpose: Immediately resets the connection, simulating an abrupt network drop.
# Toxic: add_reset
# Use Case: This is helpful for testing error-handling logic on sudden network failures,
# ensuring the client retries appropriately or fails gracefully

# Add a connection_reset toxic to the proxy
self.proxy.add_reset()
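
    # Hedged sketch (not part of this PR): the built-in toxic is 'reset_peer';
    # its 'timeout' attribute is in ms, and 0 resets the connection immediately.
    # Port 12088 and the method name are illustrative assumptions.
    @unittest.skip("Illustrative sketch only")
    def test_reset_peer_toxic_heartbeat_sketch(self):
        proxy = self.toxiproxy.create(name='test_reset_peer_sketch', listen='localhost:12088', upstream='localhost:12080')
        grpc_handler = GrpcHandler(host='localhost', port=12088, name='test-client', token='', combiner_name='combiner')
        proxy.add_toxic(name='reset_peer', type='reset_peer', attributes={'timeout': 0})
        try:
            response = grpc_handler.heartbeat('test-client', 'test-client-id')
        except grpc.RpcError as e:
            response = e
        finally:
            grpc_handler.channel.close()
            proxy.destroy()
        # An abrupt TCP RST is expected to surface as UNAVAILABLE
        self.assertIsInstance(response, grpc.RpcError)
        self.assertEqual(response.code(), grpc.StatusCode.UNAVAILABLE)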

@unittest.skip("Not implemented")
def test_slow_close(self):
# Purpose: Simulates a slow closing of the connection.
# Toxic: slow_close
# Use Case: Useful for testing how the client behaves when the server closes the connection slowly.
# This can help ensure that the client handles slow network disconnections gracefully.

# Add a slow_close toxic to the proxy
self.proxy.add_slow_close(delay=1000) # Delay closing the connection by 1 second
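
    # Hedged sketch (not part of this PR): the 'slow_close' toxic delays the
    # TCP close by 'delay' ms once the server is done, so a short heartbeat
    # should still complete. Port 12089 and the method name are assumptions.
    @unittest.skip("Illustrative sketch only")
    def test_slow_close_toxic_heartbeat_sketch(self):
        proxy = self.toxiproxy.create(name='test_slow_close_sketch', listen='localhost:12089', upstream='localhost:12080')
        grpc_handler = GrpcHandler(host='localhost', port=12089, name='test-client', token='', combiner_name='combiner')
        proxy.add_toxic(name='slow_close', type='slow_close', attributes={'delay': 1000})  # delay close by 1 s
        try:
            response = grpc_handler.heartbeat('test-client', 'test-client-id')
            self.assertIsInstance(response, fedn.Response)
        finally:
            grpc_handler.channel.close()
            proxy.destroy()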

@unittest.skip("Not implemented")
def test_slicer(self):
# Purpose: Slices the data into smaller chunks.
# Toxic: slicer
# Use Case: Useful for testing how the client handles fragmented data.
# This can help ensure that the client can reassemble the data correctly and handle partial data gracefully.

# Add a slicer toxic to the proxy
self.proxy.add_slicer(average_size=1000, size_variation=100) # Slice data into chunks of 1 KB with 100 bytes variation
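
    # Hedged sketch (not part of this PR): the 'slicer' toxic fragments the
    # TCP stream ('average_size'/'size_variation' in bytes, 'delay' in
    # microseconds); gRPC should reassemble the frames transparently.
    # Port 12090 and the method name are illustrative assumptions.
    @unittest.skip("Illustrative sketch only")
    def test_slicer_toxic_heartbeat_sketch(self):
        proxy = self.toxiproxy.create(name='test_slicer_sketch', listen='localhost:12090', upstream='localhost:12080')
        grpc_handler = GrpcHandler(host='localhost', port=12090, name='test-client', token='', combiner_name='combiner')
        proxy.add_toxic(name='slicer', type='slicer', attributes={'average_size': 64, 'size_variation': 32, 'delay': 100})
        try:
            response = grpc_handler.heartbeat('test-client', 'test-client-id')
            self.assertIsInstance(response, fedn.Response)
        finally:
            grpc_handler.channel.close()
            proxy.destroy()
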
69 changes: 51 additions & 18 deletions fedn/network/clients/grpc_handler.py
@@ -65,16 +65,25 @@ def __init__(self, host: str, port: int, name: str, token: str, combiner_name: s
("client", name),
("grpc-server", combiner_name),
]
self.host = host
self.port = port
self.token = token

if port == 443:
self._init_secure_channel(host, port, token)
else:
self._init_insecure_channel(host, port)
self._init_channel(host, port, token)

self._init_stubs()

def _init_stubs(self):
self.connectorStub = rpc.ConnectorStub(self.channel)
self.combinerStub = rpc.CombinerStub(self.channel)
self.modelStub = rpc.ModelServiceStub(self.channel)

def _init_channel(self, host: str, port: int, token: str):
if port == 443:
self._init_secure_channel(host, port, token)
else:
self._init_insecure_channel(host, port)

def _init_secure_channel(self, host: str, port: int, token: str):
url = f"{host}:{port}"
logger.info(f"Connecting (GRPC) to {url}")
@@ -116,10 +125,10 @@ def heartbeat(self, client_name: str, client_id: str):
logger.info("Sending heartbeat to combiner")
response = self.connectorStub.SendHeartbeat(heartbeat, metadata=self.metadata)
except grpc.RpcError as e:
logger.error(f"GRPC (SendHeartbeat): An error occurred: {e}")
raise e
except Exception as e:
logger.error(f"GRPC (SendHeartbeat): An error occurred: {e}")
self._disconnect()
raise e
return response

@@ -130,6 +139,8 @@ def send_heartbeats(self, client_name: str, client_id: str, update_frequency: fl
response = self.heartbeat(client_name, client_id)
except grpc.RpcError as e:
return self._handle_grpc_error(e, "SendHeartbeat", lambda: self.send_heartbeats(client_name, client_id, update_frequency))
except Exception as e:
return self._handle_unknown_error(e, "SendHeartbeat", lambda: self.send_heartbeats(client_name, client_id, update_frequency))
if isinstance(response, fedn.Response):
logger.info("Heartbeat successful.")
else:
@@ -166,10 +177,11 @@ def listen_to_task_stream(self, client_name: str, client_id: str, callback: Call
callback(request)

except grpc.RpcError as e:
self.logger.error(f"GRPC (TaskStream): An error occurred: {e}")
return self._handle_grpc_error(e, "TaskStream", lambda: self.listen_to_task_stream(client_name, client_id, callback))
except Exception as e:
logger.error(f"GRPC (TaskStream): An error occurred: {e}")
self._disconnect()
self._handle_unknown_error(e, "TaskStream", lambda: self.listen_to_task_stream(client_name, client_id, callback))

def send_status(self, msg: str, log_level=fedn.Status.INFO, type=None, request=None, sesssion_id: str = None, sender_name: str = None):
"""Send status message.
@@ -204,7 +216,7 @@ def send_status(self, msg: str, log_level=fedn.Status.INFO, type=None, request=N
return self._handle_grpc_error(e, "SendStatus", lambda: self.send_status(msg, log_level, type, request, sesssion_id, sender_name))
except Exception as e:
logger.error(f"GRPC (SendStatus): An error occurred: {e}")
self._disconnect()
self._handle_unknown_error(e, "SendStatus", lambda: self.send_status(msg, log_level, type, request, sesssion_id, sender_name))

def get_model_from_combiner(self, id: str, client_id: str, timeout: int = 20) -> BytesIO:
"""Fetch a model from the assigned combiner.
@@ -241,8 +253,7 @@ def get_model_from_combiner(self, id: str, client_id: str, timeout: int = 20) ->
return self._handle_grpc_error(e, "Download", lambda: self.get_model_from_combiner(id, client_id, timeout))
except Exception as e:
logger.error(f"GRPC (Download): An error occurred: {e}")
self._disconnect()

self._handle_unknown_error(e, "Download", lambda: self.get_model_from_combiner(id, client_id, timeout))
return data

def send_model_to_combiner(self, model: BytesIO, id: str):
@@ -273,8 +284,7 @@ def send_model_to_combiner(self, model: BytesIO, id: str):
return self._handle_grpc_error(e, "Upload", lambda: self.send_model_to_combiner(model, id))
except Exception as e:
logger.error(f"GRPC (Upload): An error occurred: {e}")
self._disconnect()

self._handle_unknown_error(e, "Upload", lambda: self.send_model_to_combiner(model, id))
return result

def create_update_message(
@@ -353,8 +363,7 @@ def send_model_update(self, update: fedn.ModelUpdate):
return self._handle_grpc_error(e, "SendModelUpdate", lambda: self.send_model_update(update))
except Exception as e:
logger.error(f"GRPC (SendModelUpdate): An error occurred: {e}")
self._disconnect()

self._handle_unknown_error(e, "SendModelUpdate", lambda: self.send_model_update(update))
return True

def send_model_validation(self, validation: fedn.ModelValidation) -> bool:
@@ -369,8 +378,7 @@ def send_model_validation(self, validation: fedn.ModelValidation) -> bool:
)
except Exception as e:
logger.error(f"GRPC (SendModelValidation): An error occurred: {e}")
self._disconnect()

self._handle_unknown_error(e, "SendModelValidation", lambda: self.send_model_validation(validation))
return True

def send_model_prediction(self, prediction: fedn.ModelPrediction) -> bool:
@@ -385,8 +393,7 @@ def send_model_prediction(self, prediction: fedn.ModelPrediction) -> bool:
)
except Exception as e:
logger.error(f"GRPC (SendModelPrediction): An error occurred: {e}")
self._disconnect()

self._handle_unknown_error(e, "SendModelPrediction", lambda: self.send_model_prediction(prediction))
return True

def _handle_grpc_error(self, e, method_name: str, sender_function: Callable):
@@ -399,12 +406,38 @@ def _handle_grpc_error(self, e, method_name: str, sender_function: Callable):
logger.warning(f"GRPC ({method_name}): connection cancelled. Retrying in 5 seconds.")
time.sleep(5)
return sender_function()
if status_code == grpc.StatusCode.UNAUTHENTICATED:
elif status_code == grpc.StatusCode.UNAUTHENTICATED:
details = e.details()
if details == "Token expired":
logger.warning(f"GRPC ({method_name}): Token expired.")
raise e
elif status_code == grpc.StatusCode.UNKNOWN:
logger.warning(f"GRPC ({method_name}): An unknown error occurred: {e}.")
details = e.details()
if details == "Stream removed":
logger.warning(f"GRPC ({method_name}): Stream removed. Reconnecting")
self._disconnect()
self._init_channel(self.host, self.port, self.token)
self._init_stubs()
return sender_function()
raise e
self._disconnect()
logger.error(f"GRPC ({method_name}): An error occurred: {e}")
raise e

def _handle_unknown_error(self, e, method_name: str, sender_function: Callable):
# Try to reconnect
logger.warning(f"GRPC ({method_name}): An unknown error occurred: {e}.")
if isinstance(e, ValueError):
# ValueError is raised when the channel is closed
self._disconnect()
logger.warning(f"GRPC ({method_name}): Reconnecting to channel.")
# recreate the channel
self._init_channel(self.host, self.port, self.token)
self._init_stubs()
return sender_function()
else:
raise e

def _disconnect(self):
"""Disconnect from the combiner."""
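
For context, the reconnect logic this PR introduces follows a common retry pattern: the failing call is passed in as a `sender_function` closure, the dead channel is torn down and rebuilt, and the closure is re-invoked. Below is a minimal, self-contained sketch of that pattern; the class and names are illustrative, not FEDn API, and `ConnectionError` stands in for a gRPC UNAVAILABLE error.

    import time

    class ReconnectingCaller:
        """Toy illustration of the sender_function retry pattern used above."""

        def __init__(self, make_channel):
            self._make_channel = make_channel
            self.channel = make_channel()

        def _reconnect(self):
            # Mirrors _disconnect() + _init_channel() + _init_stubs(): drop the
            # dead channel and build a fresh one before retrying the closure.
            self.channel.close()
            self.channel = self._make_channel()

        def call_with_retry(self, sender_function, retries=3):
            for _ in range(retries):
                try:
                    return sender_function()
                except ValueError:
                    # Raised when invoking an RPC on a closed channel:
                    # reconnect, then retry the original call.
                    self._reconnect()
                except ConnectionError:
                    # Stand-in for grpc.StatusCode.UNAVAILABLE: back off, then retry.
                    time.sleep(5)
            return sender_function()  # final attempt propagates any error

A caller would wrap its RPC as a closure, e.g. `caller.call_with_retry(lambda: handler.heartbeat(name, client_id))`, which is the same shape as the `lambda: self.send_heartbeats(...)` closures in the diff above.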