diff --git a/fedn/fedn/network/clients/client.py b/fedn/fedn/network/clients/client.py index 9851b32ef..b003136e8 100644 --- a/fedn/fedn/network/clients/client.py +++ b/fedn/fedn/network/clients/client.py @@ -127,6 +127,28 @@ def _assign(self): print("Received combiner config: {}".format(client_config), flush=True) return client_config + def _add_grpc_metadata(self, key, value): + """Add metadata for gRPC calls. + + :param key: The key of the metadata. + :type key: str + :param value: The value of the metadata. + :type value: str + """ + # Check if metadata exists and add if not + if not hasattr(self, 'metadata'): + self.metadata = () + + # Check if metadata key already exists and replace value if so + for i, (k, v) in enumerate(self.metadata): + if k == key: + # Replace value + self.metadata = self.metadata[:i] + ((key, value),) + self.metadata[i + 1:] + return + + # Set metadata using tuple concatenation + self.metadata += ((key, value),) + def _connect(self, client_config): """Connect to assigned combiner. @@ -137,6 +159,9 @@ def _connect(self, client_config): # TODO use the client_config['certificate'] for setting up secure comms' host = client_config['host'] + # Add host to gRPC metadata + self._add_grpc_metadata('grpc-server', host) + print("CLIENT: Using metadata: {}".format(self.metadata), flush=True) port = client_config['port'] secure = False if client_config['fqdn'] is not None: @@ -331,7 +356,7 @@ def get_model(self, id): """ data = BytesIO() - for part in self.modelStub.Download(fedn.ModelRequest(id=id)): + for part in self.modelStub.Download(fedn.ModelRequest(id=id), metadata=self.metadata): if part.status == fedn.ModelStatus.IN_PROGRESS: data.write(part.data) @@ -386,7 +411,7 @@ def upload_request_generator(mdl): if not b: break - result = self.modelStub.Upload(upload_request_generator(bt)) + result = self.modelStub.Upload(upload_request_generator(bt), metadata=self.metadata) return result @@ -400,11 +425,12 @@ def _listen_to_model_update_request_stream(self): r = fedn.ClientAvailableMessage() r.sender.name = self.name r.sender.role = fedn.WORKER - metadata = [('client', r.sender.name)] + # Add client to metadata + self._add_grpc_metadata('client', self.name) while True: try: - for request in self.combinerStub.ModelUpdateRequestStream(r, metadata=metadata): + for request in self.combinerStub.ModelUpdateRequestStream(r, metadata=self.metadata): if request.sender.role == fedn.COMBINER: # Process training request self._send_status("Received model update request.", log_level=fedn.Status.AUDIT, @@ -438,7 +464,7 @@ def _listen_to_model_validation_request_stream(self): r.sender.role = fedn.WORKER while True: try: - for request in self.combinerStub.ModelValidationRequestStream(r): + for request in self.combinerStub.ModelValidationRequestStream(r, metadata=self.metadata): # Process validation request _ = request.model_id self._send_status("Recieved model validation request.", log_level=fedn.Status.AUDIT, @@ -589,7 +615,7 @@ def process_request(self): update.correlation_id = request.correlation_id update.meta = json.dumps(meta) # TODO: Check responses - _ = self.combinerStub.SendModelUpdate(update) + _ = self.combinerStub.SendModelUpdate(update, metadata=self.metadata) self._send_status("Model update completed.", log_level=fedn.Status.AUDIT, type=fedn.StatusType.MODEL_UPDATE, request=update) @@ -618,7 +644,7 @@ def process_request(self): validation.timestamp = self.str validation.correlation_id = request.correlation_id _ = self.combinerStub.SendModelValidation( - validation) + validation, metadata=self.metadata) # Set status type if request.is_inference: @@ -655,7 +681,7 @@ def _send_heartbeat(self, update_frequency=2.0): heartbeat = fedn.Heartbeat(sender=fedn.Client( name=self.name, role=fedn.WORKER)) try: - self.connectorStub.SendHeartbeat(heartbeat) + self.connectorStub.SendHeartbeat(heartbeat, metadata=self.metadata) self._missed_heartbeat = 0 except grpc.RpcError as e: status_code = e.code() @@ -694,7 +720,7 @@ def _send_status(self, msg, log_level=fedn.Status.INFO, type=None, request=None) self.logs.append( "{} {} LOG LEVEL {} MESSAGE {}".format(str(datetime.now()), status.sender.name, status.log_level, status.status)) - _ = self.connectorStub.SendStatus(status) + _ = self.connectorStub.SendStatus(status, metadata=self.metadata) def run(self): """ Run the client. """ diff --git a/fedn/fedn/network/clients/test_client.py b/fedn/fedn/network/clients/test_client.py new file mode 100644 index 000000000..e9a4fdfd4 --- /dev/null +++ b/fedn/fedn/network/clients/test_client.py @@ -0,0 +1,46 @@ +import unittest +from unittest.mock import MagicMock + +from fedn.network.clients.client import Client + + +class TestClient(unittest.TestCase): + """Test the Client class.""" + + def setUp(self): + self.client = Client() + + def test_add_grpc_metadata(self): + """Test the _add_grpc_metadata method.""" + + # Test adding metadata when it doesn't exist + self.client._add_grpc_metadata('key1', 'value1') + self.assertEqual(self.client.metadata, (('key1', 'value1'),)) + + # Test adding metadata when it already exists + self.client._add_grpc_metadata('key1', 'value2') + self.assertEqual(self.client.metadata, (('key1', 'value2'),)) + + # Test adding multiple metadata + self.client._add_grpc_metadata('key2', 'value3') + self.assertEqual(self.client.metadata, (('key1', 'value2'), ('key2', 'value3'))) + + # Test adding metadata with special characters + self.client._add_grpc_metadata('key3', 'value4!@#$%^&*()') + self.assertEqual(self.client.metadata, (('key1', 'value2'), ('key2', 'value3'), ('key3', 'value4!@#$%^&*()'))) + + # Test adding metadata with empty key + with self.assertRaises(ValueError): + self.client._add_grpc_metadata('', 'value5') + + # Test adding metadata with empty value + with self.assertRaises(ValueError): + self.client._add_grpc_metadata('key4', '') + + # Test adding metadata with None value + with self.assertRaises(ValueError): + self.client._add_grpc_metadata('key5', None) + + +if __name__ == '__main__': + unittest.main()