From 797d9a7df80d13e7658f7e8fedfe80d1a3deabc1 Mon Sep 17 00:00:00 2001 From: niklastheman Date: Fri, 11 Oct 2024 14:06:04 +0200 Subject: [PATCH 01/18] Fixed spelling error --- fedn/network/clients/client_api.py | 4 ++-- fedn/network/clients/client_v2.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/fedn/network/clients/client_api.py b/fedn/network/clients/client_api.py index 332331309..53a9eabc3 100644 --- a/fedn/network/clients/client_api.py +++ b/fedn/network/clients/client_api.py @@ -28,7 +28,7 @@ def __init__(self, status: str, host: str, fqdn: str, package: str, ip: str, por # Enum for respresenting the result of connecting to the FEDn API class ConnectToApiResult(enum.Enum): Assigned = 0 - ComputePackgeMissing = 1 + ComputePackageMissing = 1 UnAuthorized = 2 UnMatchedConfig = 3 IncorrectUrl = 4 @@ -110,7 +110,7 @@ def connect_to_api(self, url: str, token: str, json: dict) -> Tuple[ConnectToApi elif response.status_code == 203: json_response = response.json() logger.info("Connect to FEDn Api - Remote compute package missing.") - return ConnectToApiResult.ComputePackgeMissing, json_response + return ConnectToApiResult.ComputePackageMissing, json_response elif response.status_code == 401: logger.warning("Connect to FEDn Api - Unauthorized") return ConnectToApiResult.UnAuthorized, "Unauthorized" diff --git a/fedn/network/clients/client_v2.py b/fedn/network/clients/client_v2.py index ab32f6116..047f82e92 100644 --- a/fedn/network/clients/client_v2.py +++ b/fedn/network/clients/client_v2.py @@ -72,8 +72,8 @@ def __init__(self, def _connect_to_api(self) -> Tuple[bool, dict]: result = None - while not result or result == ConnectToApiResult.ComputePackgeMissing: - if result == ConnectToApiResult.ComputePackgeMissing: + while not result or result == ConnectToApiResult.ComputePackageMissing: + if result == ConnectToApiResult.ComputePackageMissing: logger.info("Retrying in 3 seconds") time.sleep(3) result, response = self.client_api.connect_to_api(self.fedn_api_url, self.token, self.client_obj.to_json()) From 685236f4f58f9cc666cf2675e0797a724b5b83ba Mon Sep 17 00:00:00 2001 From: niklastheman Date: Fri, 11 Oct 2024 15:33:05 +0200 Subject: [PATCH 02/18] Added readme to display how to use client api --- fedn/network/clients/README.rst | 157 ++++++++++++++++++++++++++++++++ 1 file changed, 157 insertions(+) create mode 100644 fedn/network/clients/README.rst diff --git a/fedn/network/clients/README.rst b/fedn/network/clients/README.rst new file mode 100644 index 000000000..9d4adbf63 --- /dev/null +++ b/fedn/network/clients/README.rst @@ -0,0 +1,157 @@ +Creating Your Own Client +======================== + +This guide will help you create your own client for the FEDn network. + +Step-by-Step Instructions +------------------------- + +1. **Create a virutal environment**: Start by creating a virtual environment and activating it. + + ```bash + python3 -m venv fedn-env + source fedn-env/bin/activate + ``` + +2. **Install FEDn**: Install the FEDn package by running the following command: + + ```bash + pip install fedn + ``` + +3. **Create your client**: Copy and paste the code Below into a new file called `client_example.py`. + + ```python + import argparse + import json + import threading + import uuid + + import fedn.network.grpc.fedn_pb2 as fedn + from fedn.network.clients.client_api import ClientAPI, ConnectToApiResult + + client_api = ClientAPI() + + + def main(api_url: str, api_port: int, token: str = None, name: str = None): + print(f"API URL: {api_url}") + print(f"API Token: {token or "-"}") + print(f"API Port: {api_port or "-"}") + + if name is None: + name = input("Enter Client Name: ") + + url = f"{api_url}:{api_port}" if api_port else api_url + + if not url.endswith("/"): + url += "/" + + print(f"Client Name: {name}") + + client_id = str(uuid.uuid4()) + + print("Connecting to API...") + + client_options = { + "name": "client_example", + "client_id": client_id, + "package": "local", + "preferred_combiner": "", + } + + result, combiner_config = client_api.connect_to_api(url, token, client_options) + + if result != ConnectToApiResult.Assigned: + print("Failed to connect to API, exiting.") + return + + print("Connected to API") + + result: bool = client_api.init_grpchandler(config=combiner_config, client_name=client_id, token=token) + + if not result: + return + + threading.Thread(target=client_api.send_heartbeats, kwargs={"client_name": name, "client_id": client_id}, daemon=True).start() + + def on_train(request): + print("Received train request") + model_id: str = request.model_id + + model = client_api.get_model_from_combiner(id=str(model_id), client_name=name) + + # Do your training here + out_model = model + updated_model_id = uuid.uuid4() + client_api.send_model_to_combiner(out_model, str(updated_model_id)) + + # val metadataJson = buildJsonObject { + # put("num_examples", 1) + # put("batch_size", 1) + # put("epochs", 1) + # put("lr", 1) + # } + # val configJson = buildJsonObject { + # put("round_id", 1) + # } + + # val json = buildJsonObject { + # put("training_metadata", metadataJson) + # put("config", configJson.toString()) + # } + + training_metadata = { + "num_examples": 1, + "batch_size": 1, + "epochs": 1, + "lr": 1, + } + + config = { + "round_id": 1, + } + + client_api.send_model_update( + sender_name=name, + sender_role=fedn.WORKER, + client_id=client_id, + model_id=model_id, + model_update_id=str(updated_model_id), + receiver_name=request.sender.name, + receiver_role=request.sender.role, + meta={ + "training_metadata": training_metadata, + "config": json.dumps(config), + }, + ) + + client_api.subscribe("train", on_train) + + threading.Thread(target=client_api.listen_to_task_stream, kwargs={"client_name": name, "client_id": client_id}, daemon=True).start() + + stop_event = threading.Event() + try: + stop_event.wait() + except KeyboardInterrupt: + print("Client stopped by user.") + + + if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Client Example") + parser.add_argument("--api-url", type=str, required=True, help="The API URL") + parser.add_argument("--api-port", type=int, required=False, help="The API Port") + parser.add_argument("--token", type=str, required=False, help="The API Token") + parser.add_argument("--name", type=str, required=False, help="The Client Name") + + args = parser.parse_args() + main(args.api_url, args.api_port) + + ``` +4. **Run the client**: Run the client by executing the following command: + + ```bash + python client_example.py --api-url http(s):// --token + ``` + Replace `` and `` with the URL and token of the FEDn API. + +5. **Start training**: Create a session and start training by using either the FEDn CLI or the FEDn UI. \ No newline at end of file From b1dcdbedad86fac9e9cb5b39d963760216bbf9e7 Mon Sep 17 00:00:00 2001 From: Niklas Date: Fri, 11 Oct 2024 16:00:13 +0200 Subject: [PATCH 03/18] Update README.rst --- fedn/network/clients/README.rst | 107 ++++++++++++++------------------ 1 file changed, 46 insertions(+), 61 deletions(-) diff --git a/fedn/network/clients/README.rst b/fedn/network/clients/README.rst index 9d4adbf63..ae67f7f00 100644 --- a/fedn/network/clients/README.rst +++ b/fedn/network/clients/README.rst @@ -8,109 +8,95 @@ Step-by-Step Instructions 1. **Create a virutal environment**: Start by creating a virtual environment and activating it. - ```bash +.. code-block:: bash + python3 -m venv fedn-env source fedn-env/bin/activate - ``` + 2. **Install FEDn**: Install the FEDn package by running the following command: - ```bash +.. code-block:: bash + pip install fedn - ``` 3. **Create your client**: Copy and paste the code Below into a new file called `client_example.py`. - ```python +.. code-block:: python + import argparse import json import threading import uuid - + import fedn.network.grpc.fedn_pb2 as fedn from fedn.network.clients.client_api import ClientAPI, ConnectToApiResult - + client_api = ClientAPI() - - - def main(api_url: str, api_port: int, token: str = None, name: str = None): + + + def main(api_url: str, api_port: int, token: str = None): print(f"API URL: {api_url}") print(f"API Token: {token or "-"}") print(f"API Port: {api_port or "-"}") - - if name is None: - name = input("Enter Client Name: ") - + + name = input("Enter Client Name: ") + url = f"{api_url}:{api_port}" if api_port else api_url - + if not url.endswith("/"): url += "/" - + print(f"Client Name: {name}") - + client_id = str(uuid.uuid4()) - + print("Connecting to API...") - + client_options = { "name": "client_example", "client_id": client_id, "package": "local", "preferred_combiner": "", } - + result, combiner_config = client_api.connect_to_api(url, token, client_options) - + if result != ConnectToApiResult.Assigned: print("Failed to connect to API, exiting.") return - + print("Connected to API") - + result: bool = client_api.init_grpchandler(config=combiner_config, client_name=client_id, token=token) - + if not result: return - + threading.Thread(target=client_api.send_heartbeats, kwargs={"client_name": name, "client_id": client_id}, daemon=True).start() - + def on_train(request): print("Received train request") model_id: str = request.model_id - + model = client_api.get_model_from_combiner(id=str(model_id), client_name=name) - - # Do your training here + + # Do your training here, out_model is your result... out_model = model updated_model_id = uuid.uuid4() client_api.send_model_to_combiner(out_model, str(updated_model_id)) - - # val metadataJson = buildJsonObject { - # put("num_examples", 1) - # put("batch_size", 1) - # put("epochs", 1) - # put("lr", 1) - # } - # val configJson = buildJsonObject { - # put("round_id", 1) - # } - - # val json = buildJsonObject { - # put("training_metadata", metadataJson) - # put("config", configJson.toString()) - # } - + training_metadata = { "num_examples": 1, "batch_size": 1, "epochs": 1, "lr": 1, } - + config = { "round_id": 1, } - + client_api.send_model_update( sender_name=name, sender_role=fedn.WORKER, @@ -124,34 +110,33 @@ Step-by-Step Instructions "config": json.dumps(config), }, ) - + client_api.subscribe("train", on_train) - + threading.Thread(target=client_api.listen_to_task_stream, kwargs={"client_name": name, "client_id": client_id}, daemon=True).start() - + stop_event = threading.Event() try: stop_event.wait() except KeyboardInterrupt: print("Client stopped by user.") - - + + if __name__ == "__main__": parser = argparse.ArgumentParser(description="Client Example") parser.add_argument("--api-url", type=str, required=True, help="The API URL") parser.add_argument("--api-port", type=int, required=False, help="The API Port") parser.add_argument("--token", type=str, required=False, help="The API Token") - parser.add_argument("--name", type=str, required=False, help="The Client Name") - + args = parser.parse_args() - main(args.api_url, args.api_port) + main(args.api_url, args.api_port, args.token) - ``` 4. **Run the client**: Run the client by executing the following command: - ```bash - python client_example.py --api-url http(s):// --token - ``` - Replace `` and `` with the URL and token of the FEDn API. +.. code-block:: bash + + python client_example.py --api-url --token + +Replace `` and `` with the URL and token of the FEDn API. *Example when running a local FEDn instance: python client_example.py --api-url http://localhost:8092* -5. **Start training**: Create a session and start training by using either the FEDn CLI or the FEDn UI. \ No newline at end of file +5. **Start training**: Create a session and start training by using either the FEDn CLI or the FEDn UI. From f3ec3623ba79218a1768def8cee9891f5c51dc6f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Benjamin=20=C3=85strand?= <112588563+benjaminastrand@users.noreply.github.com> Date: Wed, 30 Oct 2024 08:46:08 +0100 Subject: [PATCH 04/18] Refactor/SK-1144 | Simplify on_train and on_validate (#731) --- fedn/network/clients/README.rst | 142 ++++++++---------- fedn/network/clients/client_api.py | 206 ++++++++++++++++++++------- fedn/network/clients/client_v2.py | 189 +++++------------------- fedn/network/clients/grpc_handler.py | 7 +- 4 files changed, 258 insertions(+), 286 deletions(-) diff --git a/fedn/network/clients/README.rst b/fedn/network/clients/README.rst index ae67f7f00..013a6a75a 100644 --- a/fedn/network/clients/README.rst +++ b/fedn/network/clients/README.rst @@ -26,111 +26,95 @@ Step-by-Step Instructions import argparse import json - import threading import uuid - - import fedn.network.grpc.fedn_pb2 as fedn + from fedn.network.clients.client_api import ClientAPI, ConnectToApiResult - - client_api = ClientAPI() - - + + + def get_api_url(api_url: str, api_port: int): + url = f"{api_url}:{api_port}" if api_port else api_url + if not url.endswith("/"): + url += "/" + return url + + def on_train(in_model): + training_metadata = { + "num_examples": 1, + "batch_size": 1, + "epochs": 1, + "lr": 1, + } + + config = { + "round_id": 1, + } + + metadata = { + "training_metadata": training_metadata, + "config": json.dumps(config), + } + + # Do your training here, out_model is your result... + out_model = in_model + + return out_model, metadata + + def on_validate(in_model): + + # Calculate metrics here... + metrics = { + "test_accuracy": 0.9, + "test_loss": 0.1, + "train_accuracy": 0.8, + "train_loss": 0.2, + } + return metrics + def main(api_url: str, api_port: int, token: str = None): print(f"API URL: {api_url}") print(f"API Token: {token or "-"}") print(f"API Port: {api_port or "-"}") - + + client_api = ClientAPI(train_callback=on_train, validate_callback=on_validate) + + url = get_api_url(api_url, api_port) + name = input("Enter Client Name: ") - - url = f"{api_url}:{api_port}" if api_port else api_url - - if not url.endswith("/"): - url += "/" - - print(f"Client Name: {name}") - + client_api.set_name(name) + client_id = str(uuid.uuid4()) - - print("Connecting to API...") - - client_options = { - "name": "client_example", + client_api.set_client_id(client_id) + + controller_config = { + "name": name, "client_id": client_id, "package": "local", "preferred_combiner": "", } - - result, combiner_config = client_api.connect_to_api(url, token, client_options) - + + result, combiner_config = client_api.connect_to_api(url, token, controller_config) + if result != ConnectToApiResult.Assigned: print("Failed to connect to API, exiting.") return - - print("Connected to API") - + result: bool = client_api.init_grpchandler(config=combiner_config, client_name=client_id, token=token) - + if not result: return - - threading.Thread(target=client_api.send_heartbeats, kwargs={"client_name": name, "client_id": client_id}, daemon=True).start() - - def on_train(request): - print("Received train request") - model_id: str = request.model_id - - model = client_api.get_model_from_combiner(id=str(model_id), client_name=name) - - # Do your training here, out_model is your result... - out_model = model - updated_model_id = uuid.uuid4() - client_api.send_model_to_combiner(out_model, str(updated_model_id)) - - training_metadata = { - "num_examples": 1, - "batch_size": 1, - "epochs": 1, - "lr": 1, - } - - config = { - "round_id": 1, - } - - client_api.send_model_update( - sender_name=name, - sender_role=fedn.WORKER, - client_id=client_id, - model_id=model_id, - model_update_id=str(updated_model_id), - receiver_name=request.sender.name, - receiver_role=request.sender.role, - meta={ - "training_metadata": training_metadata, - "config": json.dumps(config), - }, - ) - - client_api.subscribe("train", on_train) - - threading.Thread(target=client_api.listen_to_task_stream, kwargs={"client_name": name, "client_id": client_id}, daemon=True).start() - - stop_event = threading.Event() - try: - stop_event.wait() - except KeyboardInterrupt: - print("Client stopped by user.") - - + + client_api.run() + if __name__ == "__main__": parser = argparse.ArgumentParser(description="Client Example") parser.add_argument("--api-url", type=str, required=True, help="The API URL") parser.add_argument("--api-port", type=int, required=False, help="The API Port") parser.add_argument("--token", type=str, required=False, help="The API Token") - + args = parser.parse_args() main(args.api_url, args.api_port, args.token) + 4. **Run the client**: Run the client by executing the following command: .. code-block:: bash diff --git a/fedn/network/clients/client_api.py b/fedn/network/clients/client_api.py index 53a9eabc3..88d1eb77c 100644 --- a/fedn/network/clients/client_api.py +++ b/fedn/network/clients/client_api.py @@ -1,8 +1,11 @@ import enum import os import time +import threading from io import BytesIO from typing import Any, Tuple +import uuid +import json import requests @@ -51,44 +54,20 @@ def get_compute_package_dir_path(): class ClientAPI: - def __init__(self): - self._subscribers = {"train": [], "validation": []} + def __init__(self, train_callback: callable = None, validate_callback: callable = None): + self.train_callback: callable = train_callback + self.validate_callback: callable = validate_callback path = get_compute_package_dir_path() self._package_runtime = PackageRuntime(path) self.dispatcher: Dispatcher = None self.grpc_handler: GrpcHandler = None - def subscribe(self, event_type: str, callback: callable): - """Subscribe to a specific event.""" - if event_type in self._subscribers: - self._subscribers[event_type].append(callback) - else: - raise ValueError(f"Unsupported event type: {event_type}") - - def notify_subscribers(self, event_type: str, *args, **kwargs): - """Notify all subscribers about a specific event.""" - if event_type in self._subscribers: - for callback in self._subscribers[event_type]: - callback(*args, **kwargs) - else: - raise ValueError(f"Unsupported event type: {event_type}") - - def train(self, *args, **kwargs): - """Function to be triggered from the server via gRPC.""" - # Perform training logic here - logger.info("Training started") - - # Notify all subscribers about the train event - self.notify_subscribers("train", *args, **kwargs) + def set_train_callback(self, callback: callable): + self.train_callback = callback - def validate(self, *args, **kwargs): - """Function to be triggered from the server via gRPC.""" - # Perform validation logic here - logger.info("Validation started") - - # Notify all subscribers about the validation event - self.notify_subscribers("validation", *args, **kwargs) + def set_validate_callback(self, callback: callable): + self.validate_callback = callback def connect_to_api(self, url: str, token: str, json: dict) -> Tuple[ConnectToApiResult, Any]: # TODO: Use new API endpoint (v1) @@ -216,9 +195,136 @@ def listen_to_task_stream(self, client_name: str, client_id: str): def _task_stream_callback(self, request): if request.type == fedn.StatusType.MODEL_UPDATE: - self.train(request) + self.update_local_model(request) elif request.type == fedn.StatusType.MODEL_VALIDATION: - self.validate(request) + self.validate_global_model(request) + + def update_local_model(self, request): + model_id = request.model_id + model_update_id = str(uuid.uuid4()) + + tic = time.time() + in_model = self.get_model_from_combiner(id=model_id, client_name=self.name) + + if in_model is None: + logger.error("Could not retrieve model from combiner. Aborting training request.") + return + + fetch_model_time = time.time() - tic + + if not self.train_callback: + logger.error("No train callback set") + return + + self.send_status( + f"\t Starting processing of training request for model_id {model_id}", + sesssion_id=request.session_id, + sender_name=self.name + ) + + logger.info(f"Running train callback with model ID: {model_id}") + tic = time.time() + out_model, meta = self.train_callback(in_model) + meta["processing_time"] = time.time() - tic + + tic = time.time() + self.send_model_to_combiner( + model=out_model, + id=model_update_id + ) + meta["upload_model"] = time.time() - tic + + meta["fetch_model"] = fetch_model_time + meta["config"] = request.data + + self.send_model_update( + model_id=model_id, + model_update_id=model_update_id, + meta=meta, + request=request + ) + + self.send_status( + "Model update completed.", + log_level=fedn.Status.AUDIT, + type=fedn.StatusType.MODEL_UPDATE, + request=request, + sesssion_id=request.session_id, + sender_name=self.name + ) + + def validate_global_model(self, request): + model_id = request.model_id + + self.send_status( + f"Processing validate request for model_id {model_id}", + sesssion_id=request.session_id, + sender_name=self.name + ) + + in_model = self.get_model_from_combiner(id=model_id, client_name=self.name) + + if in_model is None: + logger.error("Could not retrieve model from combiner. Aborting validation request.") + return + + if not self.validate_callback: + logger.error("No validate callback set") + return + + logger.info(f"Running validate callback with model ID: {model_id}") + metrics = self.validate_callback(in_model) + + if metrics is not None: + # Send validation + validation = fedn.ModelValidation() + validation.sender.name = self.name + validation.sender.role = fedn.WORKER + validation.receiver.name = request.sender.name + validation.receiver.role = request.sender.role + validation.model_id = str(request.model_id) + validation.data = json.dumps(metrics) + validation.timestamp.GetCurrentTime() + validation.correlation_id = request.correlation_id + validation.session_id = request.session_id + + result: bool = self.send_model_validation( + metrics=metrics, + request=request + ) + + if result: + self.send_status( + "Model validation completed.", + log_level=fedn.Status.AUDIT, + type=fedn.StatusType.MODEL_VALIDATION, + request=validation, + sesssion_id=request.session_id, + sender_name=self.name + ) + else: + self.send_status( + "Client {} failed to complete model validation.".format(self.name), + log_level=fedn.Status.WARNING, + request=request, + sesssion_id=request.session_id, + sender_name=self.name + ) + + def set_name(self, name: str): + logger.info(f"Setting client name to: {name}") + self.name = name + + def set_client_id(self, client_id: str): + logger.info(f"Setting client ID to: {client_id}") + self.client_id = client_id + + def run(self): + threading.Thread(target=self.send_heartbeats, kwargs={"client_name": self.name, "client_id": self.client_id}, daemon=True).start() + try: + self.listen_to_task_stream(client_name=self.name, client_id=self.client_id) + except KeyboardInterrupt: + logger.info("Client stopped by user.") def get_model_from_combiner(self, id: str, client_name: str, timeout: int = 20) -> BytesIO: return self.grpc_handler.get_model_from_combiner(id=id, client_name=client_name, timeout=timeout) @@ -230,36 +336,34 @@ def send_status(self, msg: str, log_level=fedn.Status.INFO, type=None, request=N return self.grpc_handler.send_status(msg, log_level, type, request, sesssion_id, sender_name) def send_model_update(self, - sender_name: str, - sender_role: fedn.Role, - client_id: str, model_id: str, model_update_id: str, - receiver_name: str, - receiver_role: fedn.Role, - meta: dict + meta: dict, + request: fedn.TaskRequest ) -> bool: + return self.grpc_handler.send_model_update( - sender_name=sender_name, - sender_role=sender_role, - client_id=client_id, + sender_name=self.name, model_id=model_id, model_update_id=model_update_id, - receiver_name=receiver_name, - receiver_role=receiver_role, + receiver_name=request.sender.name, + receiver_role=request.sender.role, meta=meta ) def send_model_validation(self, - sender_name: str, - receiver_name: str, - receiver_role: fedn.Role, - model_id: str, metrics: dict, - correlation_id: str, - session_id: str + request: fedn.TaskRequest ) -> bool: - return self.grpc_handler.send_model_validation(sender_name, receiver_name, receiver_role, model_id, metrics, correlation_id, session_id) + return self.grpc_handler.send_model_validation( + sender_name=self.name, + receiver_name=request.sender.name, + receiver_role=request.sender.role, + model_id=request.model_id, + metrics=json.dumps(metrics), + correlation_id=request.correlation_id, + session_id=request.session_id + ) # Init functions def init_remote_compute_package(self, url: str, token: str, package_checksum: str = None) -> bool: diff --git a/fedn/network/clients/client_v2.py b/fedn/network/clients/client_v2.py index 047f82e92..fa1bf8e44 100644 --- a/fedn/network/clients/client_v2.py +++ b/fedn/network/clients/client_v2.py @@ -1,12 +1,11 @@ import io import json import os -import threading import time import uuid +from io import BytesIO from typing import Tuple -import fedn.network.grpc.fedn_pb2 as fedn from fedn.common.config import FEDN_CUSTOM_URL_PREFIX from fedn.common.log_config import logger from fedn.network.clients.client_api import ClientAPI, ConnectToApiResult, GrpcConnectionOptions @@ -44,16 +43,17 @@ def to_json(self): class Client: - def __init__(self, - api_url: str, - api_port: int, - client_obj: ClientOptions, - combiner_host: str = None, - combiner_port: int = None, - token: str = None, - package_checksum: str = None, - helper_type: str = None - ): + def __init__( + self, + api_url: str, + api_port: int, + client_obj: ClientOptions, + combiner_host: str = None, + combiner_port: int = None, + token: str = None, + package_checksum: str = None, + helper_type: str = None, + ): self.api_url = api_url self.api_port = api_port self.combiner_host = combiner_host @@ -114,22 +114,13 @@ def start(self): logger.info("-----------------------------") - threading.Thread( - target=self.client_api.send_heartbeats, kwargs={"client_name": self.client_obj.name, "client_id": self.client_obj.client_id}, daemon=True - ).start() + self.client_api.set_train_callback(self.on_train) + self.client_api.set_validate_callback(self.on_validation) - self.client_api.subscribe("train", self.on_train) - self.client_api.subscribe("validation", self.on_validation) + self.client_api.set_name(self.client_obj.name) + self.client_api.set_client_id(self.client_obj.client_id) - threading.Thread( - target=self.client_api.listen_to_task_stream, kwargs={"client_name": self.client_obj.name, "client_id": self.client_obj.client_id}, daemon=True - ).start() - - stop_event = threading.Event() - try: - stop_event.wait() - except KeyboardInterrupt: - logger.info("Client stopped by user.") + self.client_api.run() def set_helper(self, response: GrpcConnectionOptions = None): helper_type = response.get("helper_type", None) @@ -141,50 +132,29 @@ def set_helper(self, response: GrpcConnectionOptions = None): # Priority: helper_type from constructor > helper_type from response > default helper_type self.helper = get_helper(helper_type_to_use) - def on_train(self, request): - logger.info("Received train request") - self._process_training_request(request) + def on_train(self, in_model): + out_model, meta = self._process_training_request(in_model) + return out_model, meta - def on_validation(self, request): - logger.info("Received validation request") - self._process_validation_request(request) + def on_validation(self, in_model): + metrics = self._process_validation_request(in_model) + return metrics - - def _process_training_request(self, request) -> Tuple[str, dict]: + def _process_training_request(self, in_model: BytesIO) -> Tuple[BytesIO, dict]: """Process a training (model update) request. - :param model_id: The model id of the model to be updated. - :type model_id: str - :param session_id: The id of the current session - :type session_id: str - :return: The model id of the updated model, or None if the update failed. And a dict with metadata. + :param in_model: The model to be updated. + :type in_model: BytesIO + :return: The updated model, or None if the update failed. And a dict with metadata. :rtype: tuple """ - model_id: str = request.model_id - session_id: str = request.session_id - - self.client_api.send_status( - f"\t Starting processing of training request for model_id {model_id}", - sesssion_id=session_id, - sender_name=self.client_obj.name - ) - try: meta = {} - tic = time.time() - - model = self.client_api.get_model_from_combiner(id=str(model_id), client_name=self.client_obj.client_id) - - if model is None: - logger.error("Could not retrieve model from combiner. Aborting training request.") - return None, None - - meta["fetch_model"] = time.time() - tic inpath = self.helper.get_tmp_path() with open(inpath, "wb") as fh: - fh.write(model.getbuffer()) + fh.write(in_model.getbuffer()) outpath = self.helper.get_tmp_path() @@ -194,17 +164,11 @@ def _process_training_request(self, request) -> Tuple[str, dict]: meta["exec_training"] = time.time() - tic - tic = time.time() out_model = None with open(outpath, "rb") as fr: out_model = io.BytesIO(fr.read()) - # Stream model update to combiner server - updated_model_id = uuid.uuid4() - self.client_api.send_model_to_combiner(out_model, str(updated_model_id)) - meta["upload_model"] = time.time() - tic - # Read the metadata file with open(outpath + "-metadata", "r") as fh: training_metadata = json.loads(fh.read()) @@ -218,65 +182,27 @@ def _process_training_request(self, request) -> Tuple[str, dict]: except Exception as e: logger.error("Could not process training request due to error: {}".format(e)) - updated_model_id = None + out_model = None meta = {"status": "failed", "error": str(e)} - if meta is not None: - processing_time = time.time() - tic - meta["processing_time"] = processing_time - meta["config"] = request.data - - if model_id is not None: - # Send model update to combiner - - self.client_api.send_model_update( - sender_name=self.client_obj.name, - sender_role=fedn.WORKER, - client_id=self.client_obj.client_id, - model_id=model_id, - model_update_id=str(updated_model_id), - receiver_name=request.sender.name, - receiver_role=request.sender.role, - meta=meta, - ) - - self.client_api.send_status( - "Model update completed.", - log_level=fedn.Status.AUDIT, - type=fedn.StatusType.MODEL_UPDATE, - request=request, - sesssion_id=session_id, - sender_name=self.client_obj.name - ) - - def _process_validation_request(self, request): + return out_model, meta + + def _process_validation_request(self, in_model: BytesIO) -> dict: """Process a validation request. - :param model_id: The model id of the model to be validated. - :type model_id: str - :param session_id: The id of the current session. - :type session_id: str + :param in_model: The model to be validated. + :type in_model: BytesIO :return: The validation metrics, or None if validation failed. :rtype: dict """ - model_id: str = request.model_id - session_id: str = request.session_id - cmd = "validate" - - self.client_api.send_status(f"Processing {cmd} request for model_id {model_id}", sesssion_id=session_id, sender_name=self.client_obj.name) - try: - model = self.client_api.get_model_from_combiner(id=str(model_id), client_name=self.client_obj.client_id) - if model is None: - logger.error("Could not retrieve model from combiner. Aborting validation request.") - return inpath = self.helper.get_tmp_path() with open(inpath, "wb") as fh: - fh.write(model.getbuffer()) + fh.write(in_model.getbuffer()) outpath = get_tmp_path() - self.client_api.dispatcher.run_cmd(f"{cmd} {inpath} {outpath}") + self.client_api.dispatcher.run_cmd(f"validate {inpath} {outpath}") with open(outpath, "r") as fh: metrics = json.loads(fh.read()) @@ -286,45 +212,6 @@ def _process_validation_request(self, request): except Exception as e: logger.warning("Validation failed with exception {}".format(e)) + metrics = None - if metrics is not None: - # Send validation - validation = fedn.ModelValidation() - validation.sender.name = self.client_obj.name - validation.sender.role = fedn.WORKER - validation.receiver.name = request.sender.name - validation.receiver.role = request.sender.role - validation.model_id = str(request.model_id) - validation.data = json.dumps(metrics) - validation.timestamp.GetCurrentTime() - validation.correlation_id = request.correlation_id - validation.session_id = request.session_id - - # sender_name: str, sender_role: fedn.Role, model_id: str, model_update_id: str - result: bool = self.client_api.send_model_validation( - sender_name=self.client_obj.name, - receiver_name=request.sender.name, - receiver_role=request.sender.role, - model_id=str(request.model_id), - metrics=json.dumps(metrics), - correlation_id=request.correlation_id, - session_id=request.session_id, - ) - - if result: - self.client_api.send_status( - "Model validation completed.", - log_level=fedn.Status.AUDIT, - type=fedn.StatusType.MODEL_VALIDATION, - request=validation, - sesssion_id=request.session_id, - sender_name=self.client_obj.name - ) - else: - self.client_api.send_status( - "Client {} failed to complete model validation.".format(self.name), - log_level=fedn.Status.WARNING, - request=request, - sesssion_id=request.session_id, - sender_name=self.client_obj.name - ) + return metrics diff --git a/fedn/network/clients/grpc_handler.py b/fedn/network/clients/grpc_handler.py index 9b8550344..e06ad2d61 100644 --- a/fedn/network/clients/grpc_handler.py +++ b/fedn/network/clients/grpc_handler.py @@ -236,8 +236,6 @@ def send_model_to_combiner(self, model: BytesIO, id: str): def send_model_update(self, sender_name: str, - sender_role: fedn.Role, - client_id: str, model_id: str, model_update_id: str, receiver_name: str, @@ -246,8 +244,8 @@ def send_model_update(self, ): update = fedn.ModelUpdate() update.sender.name = sender_name - update.sender.role = sender_role - update.sender.client_id = client_id + update.sender.role = fedn.WORKER + update.sender.client_id = self.metadata[0][1] update.receiver.name = receiver_name update.receiver.role = receiver_role update.model_id = model_id @@ -264,7 +262,6 @@ def send_model_update(self, "SendModelUpdate", lambda: self.send_model_update( sender_name, - sender_role, model_id, model_update_id, receiver_name, From 257d27ae888e8b92781fa2fca5ed0f86dbb7fd2f Mon Sep 17 00:00:00 2001 From: benjaminastrand Date: Mon, 4 Nov 2024 13:07:17 +0100 Subject: [PATCH 05/18] Add send_model_prediction to GrpcHandler --- fedn/network/clients/grpc_handler.py | 29 ++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/fedn/network/clients/grpc_handler.py b/fedn/network/clients/grpc_handler.py index c950603b2..644ddf50b 100644 --- a/fedn/network/clients/grpc_handler.py +++ b/fedn/network/clients/grpc_handler.py @@ -312,6 +312,35 @@ def send_model_validation( return True + def send_model_prediction( + self, sender_name: str, receiver_name: str, receiver_role: fedn.Role, model_id: str, prediction_output: str, correlation_id: str, session_id: str + ) -> bool: + prediction = fedn.ModelPrediction() + prediction.sender.name = sender_name + prediction.sender.role = fedn.WORKER + prediction.receiver.name = receiver_name + prediction.receiver.role = receiver_role + prediction.model_id = model_id + prediction.data = prediction_output + prediction.timestamp.GetCurrentTime() + prediction.correlation_id = correlation_id + prediction.prediction_id = session_id + + try: + logger.info("Sending model prediction to combiner.") + _ = self.combinerStub.SendModelPrediction(prediction, metadata=self.metadata) + except grpc.RpcError as e: + return self._handle_grpc_error( + e, + "SendModelPrediction", + lambda: self.send_model_prediction(sender_name, receiver_name, receiver_role, model_id, prediction_output, correlation_id, session_id), + ) + except Exception as e: + logger.error(f"GRPC (SendModelPrediction): An error occurred: {e}") + self._disconnect() + + return True + def _handle_grpc_error(self, e, method_name: str, sender_function: Callable): status_code = e.code() if status_code == grpc.StatusCode.UNAVAILABLE: From 8c3de60d3c812328ff27b8e119ff72e0304cbb62 Mon Sep 17 00:00:00 2001 From: benjaminastrand Date: Mon, 4 Nov 2024 13:08:30 +0100 Subject: [PATCH 06/18] Add predict to ClientAPI + Ruff --- fedn/network/clients/client_api.py | 96 ++++++++++++++++-------------- 1 file changed, 51 insertions(+), 45 deletions(-) diff --git a/fedn/network/clients/client_api.py b/fedn/network/clients/client_api.py index 88d1eb77c..263b93881 100644 --- a/fedn/network/clients/client_api.py +++ b/fedn/network/clients/client_api.py @@ -1,11 +1,11 @@ import enum +import json import os -import time import threading +import time +import uuid from io import BytesIO from typing import Any, Tuple -import uuid -import json import requests @@ -54,9 +54,11 @@ def get_compute_package_dir_path(): class ClientAPI: - def __init__(self, train_callback: callable = None, validate_callback: callable = None): + def __init__(self, train_callback: callable = None, validate_callback: callable = None, predict_callback: callable = None): self.train_callback: callable = train_callback self.validate_callback: callable = validate_callback + self.predict_callback: callable = predict_callback + path = get_compute_package_dir_path() self._package_runtime = PackageRuntime(path) @@ -69,6 +71,9 @@ def set_train_callback(self, callback: callable): def set_validate_callback(self, callback: callable): self.validate_callback = callback + def set_predict_callback(self, callback: callable): + self.predict_callback = callback + def connect_to_api(self, url: str, token: str, json: dict) -> Tuple[ConnectToApiResult, Any]: # TODO: Use new API endpoint (v1) url_endpoint = f"{url}add_client" @@ -186,7 +191,6 @@ def init_grpchandler(self, config: GrpcConnectionOptions, client_name: str, toke logger.error("Error: Could not initialize GRPC connection") return False - def send_heartbeats(self, client_name: str, client_id: str, update_frequency: float = 2.0): self.grpc_handler.send_heartbeats(client_name=client_name, client_id=client_id, update_frequency=update_frequency) @@ -198,6 +202,8 @@ def _task_stream_callback(self, request): self.update_local_model(request) elif request.type == fedn.StatusType.MODEL_VALIDATION: self.validate_global_model(request) + elif request.type == fedn.StatusType.MODEL_PREDICTION: + self.predict_global_model(request) def update_local_model(self, request): model_id = request.model_id @@ -216,11 +222,7 @@ def update_local_model(self, request): logger.error("No train callback set") return - self.send_status( - f"\t Starting processing of training request for model_id {model_id}", - sesssion_id=request.session_id, - sender_name=self.name - ) + self.send_status(f"\t Starting processing of training request for model_id {model_id}", sesssion_id=request.session_id, sender_name=self.name) logger.info(f"Running train callback with model ID: {model_id}") tic = time.time() @@ -228,21 +230,13 @@ def update_local_model(self, request): meta["processing_time"] = time.time() - tic tic = time.time() - self.send_model_to_combiner( - model=out_model, - id=model_update_id - ) + self.send_model_to_combiner(model=out_model, id=model_update_id) meta["upload_model"] = time.time() - tic meta["fetch_model"] = fetch_model_time meta["config"] = request.data - self.send_model_update( - model_id=model_id, - model_update_id=model_update_id, - meta=meta, - request=request - ) + self.send_model_update(model_id=model_id, model_update_id=model_update_id, meta=meta, request=request) self.send_status( "Model update completed.", @@ -250,17 +244,13 @@ def update_local_model(self, request): type=fedn.StatusType.MODEL_UPDATE, request=request, sesssion_id=request.session_id, - sender_name=self.name + sender_name=self.name, ) def validate_global_model(self, request): model_id = request.model_id - self.send_status( - f"Processing validate request for model_id {model_id}", - sesssion_id=request.session_id, - sender_name=self.name - ) + self.send_status(f"Processing validate request for model_id {model_id}", sesssion_id=request.session_id, sender_name=self.name) in_model = self.get_model_from_combiner(id=model_id, client_name=self.name) @@ -288,10 +278,7 @@ def validate_global_model(self, request): validation.correlation_id = request.correlation_id validation.session_id = request.session_id - result: bool = self.send_model_validation( - metrics=metrics, - request=request - ) + result: bool = self.send_model_validation(metrics=metrics, request=request) if result: self.send_status( @@ -300,7 +287,7 @@ def validate_global_model(self, request): type=fedn.StatusType.MODEL_VALIDATION, request=validation, sesssion_id=request.session_id, - sender_name=self.name + sender_name=self.name, ) else: self.send_status( @@ -308,9 +295,26 @@ def validate_global_model(self, request): log_level=fedn.Status.WARNING, request=request, sesssion_id=request.session_id, - sender_name=self.name + sender_name=self.name, ) + def predict_global_model(self, request): + model_id = request.model_id + model = self.get_model_from_combiner(id=model_id, client_name=self.name) + + if model is None: + logger.error("Could not retrieve model from combiner. Aborting prediction request.") + return + + if not self.predict_callback: + logger.error("No predict callback set") + return + + logger.info(f"Running predict callback with model ID: {model_id}") + prediction = self.predict_callback(model) + + self.send_model_prediction(prediction=prediction, request=request) + def set_name(self, name: str): logger.info(f"Setting client name to: {name}") self.name = name @@ -335,26 +339,17 @@ def send_model_to_combiner(self, model: BytesIO, id: str): def send_status(self, msg: str, log_level=fedn.Status.INFO, type=None, request=None, sesssion_id: str = None, sender_name: str = None): return self.grpc_handler.send_status(msg, log_level, type, request, sesssion_id, sender_name) - def send_model_update(self, - model_id: str, - model_update_id: str, - meta: dict, - request: fedn.TaskRequest - ) -> bool: - + def send_model_update(self, model_id: str, model_update_id: str, meta: dict, request: fedn.TaskRequest) -> bool: return self.grpc_handler.send_model_update( sender_name=self.name, model_id=model_id, model_update_id=model_update_id, receiver_name=request.sender.name, receiver_role=request.sender.role, - meta=meta + meta=meta, ) - def send_model_validation(self, - metrics: dict, - request: fedn.TaskRequest - ) -> bool: + def send_model_validation(self, metrics: dict, request: fedn.TaskRequest) -> bool: return self.grpc_handler.send_model_validation( sender_name=self.name, receiver_name=request.sender.name, @@ -362,7 +357,18 @@ def send_model_validation(self, model_id=request.model_id, metrics=json.dumps(metrics), correlation_id=request.correlation_id, - session_id=request.session_id + session_id=request.session_id, + ) + + def send_model_prediction(self, prediction: dict, request: fedn.TaskRequest) -> bool: + return self.grpc_handler.send_model_prediction( + sender_name=self.name, + receiver_name=request.sender.name, + receiver_role=request.sender.role, + model_id=request.model_id, + prediction_output=json.dumps(prediction), + correlation_id=request.correlation_id, + session_id=request.session_id, ) # Init functions From d2d8cc7625434bd2d0f7d5a05c5b0840f3239924 Mon Sep 17 00:00:00 2001 From: benjaminastrand Date: Tue, 5 Nov 2024 13:34:35 +0100 Subject: [PATCH 07/18] Bug fix - send ModelUpdate in status --- fedn/network/clients/client_api.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/fedn/network/clients/client_api.py b/fedn/network/clients/client_api.py index 263b93881..6e2d41bd9 100644 --- a/fedn/network/clients/client_api.py +++ b/fedn/network/clients/client_api.py @@ -4,6 +4,7 @@ import threading import time import uuid +from datetime import datetime from io import BytesIO from typing import Any, Tuple @@ -236,13 +237,24 @@ def update_local_model(self, request): meta["fetch_model"] = fetch_model_time meta["config"] = request.data + update = fedn.ModelUpdate() + update.sender.name = self.name + update.sender.role = fedn.WORKER + update.sender.client_id = self.client_id + update.receiver.name = request.sender.name + update.receiver.role = request.sender.role + update.model_id = model_id + update.model_update_id = model_update_id + update.timestamp = str(datetime.now()) + update.meta = json.dumps(meta) + self.send_model_update(model_id=model_id, model_update_id=model_update_id, meta=meta, request=request) self.send_status( "Model update completed.", log_level=fedn.Status.AUDIT, type=fedn.StatusType.MODEL_UPDATE, - request=request, + request=update, sesssion_id=request.session_id, sender_name=self.name, ) From 40cb3d5391d8c0105c34438c7bc33fc2ed94e21e Mon Sep 17 00:00:00 2001 From: benjaminastrand Date: Tue, 5 Nov 2024 14:35:10 +0100 Subject: [PATCH 08/18] Create messages for update, validate and predict in GrpcHandler methods --- fedn/network/clients/client_api.py | 100 ++++++++++++--------------- fedn/network/clients/grpc_handler.py | 85 ++++++++++++++--------- 2 files changed, 99 insertions(+), 86 deletions(-) diff --git a/fedn/network/clients/client_api.py b/fedn/network/clients/client_api.py index 6e2d41bd9..36294ab06 100644 --- a/fedn/network/clients/client_api.py +++ b/fedn/network/clients/client_api.py @@ -4,7 +4,6 @@ import threading import time import uuid -from datetime import datetime from io import BytesIO from typing import Any, Tuple @@ -237,18 +236,9 @@ def update_local_model(self, request): meta["fetch_model"] = fetch_model_time meta["config"] = request.data - update = fedn.ModelUpdate() - update.sender.name = self.name - update.sender.role = fedn.WORKER - update.sender.client_id = self.client_id - update.receiver.name = request.sender.name - update.receiver.role = request.sender.role - update.model_id = model_id - update.model_update_id = model_update_id - update.timestamp = str(datetime.now()) - update.meta = json.dumps(meta) + update = self.create_update_message(model_id=model_id, model_update_id=model_update_id, meta=meta, request=request) - self.send_model_update(model_id=model_id, model_update_id=model_update_id, meta=meta, request=request) + self.send_model_update(update) self.send_status( "Model update completed.", @@ -279,18 +269,9 @@ def validate_global_model(self, request): if metrics is not None: # Send validation - validation = fedn.ModelValidation() - validation.sender.name = self.name - validation.sender.role = fedn.WORKER - validation.receiver.name = request.sender.name - validation.receiver.role = request.sender.role - validation.model_id = str(request.model_id) - validation.data = json.dumps(metrics) - validation.timestamp.GetCurrentTime() - validation.correlation_id = request.correlation_id - validation.session_id = request.session_id - - result: bool = self.send_model_validation(metrics=metrics, request=request) + validation = self.create_validation_message(metrics=metrics, request=request) + + result: bool = self.send_model_validation(validation) if result: self.send_status( @@ -325,34 +306,12 @@ def predict_global_model(self, request): logger.info(f"Running predict callback with model ID: {model_id}") prediction = self.predict_callback(model) - self.send_model_prediction(prediction=prediction, request=request) - - def set_name(self, name: str): - logger.info(f"Setting client name to: {name}") - self.name = name - - def set_client_id(self, client_id: str): - logger.info(f"Setting client ID to: {client_id}") - self.client_id = client_id - - def run(self): - threading.Thread(target=self.send_heartbeats, kwargs={"client_name": self.name, "client_id": self.client_id}, daemon=True).start() - try: - self.listen_to_task_stream(client_name=self.name, client_id=self.client_id) - except KeyboardInterrupt: - logger.info("Client stopped by user.") - - def get_model_from_combiner(self, id: str, client_name: str, timeout: int = 20) -> BytesIO: - return self.grpc_handler.get_model_from_combiner(id=id, client_name=client_name, timeout=timeout) + prediction_message = self.create_prediction_message(prediction=prediction, request=request) - def send_model_to_combiner(self, model: BytesIO, id: str): - return self.grpc_handler.send_model_to_combiner(model, id) - - def send_status(self, msg: str, log_level=fedn.Status.INFO, type=None, request=None, sesssion_id: str = None, sender_name: str = None): - return self.grpc_handler.send_status(msg, log_level, type, request, sesssion_id, sender_name) + self.send_model_prediction(prediction_message) - def send_model_update(self, model_id: str, model_update_id: str, meta: dict, request: fedn.TaskRequest) -> bool: - return self.grpc_handler.send_model_update( + def create_update_message(self, model_id: str, model_update_id: str, meta: dict, request: fedn.TaskRequest): + return self.grpc_handler.create_update_message( sender_name=self.name, model_id=model_id, model_update_id=model_update_id, @@ -361,8 +320,8 @@ def send_model_update(self, model_id: str, model_update_id: str, meta: dict, req meta=meta, ) - def send_model_validation(self, metrics: dict, request: fedn.TaskRequest) -> bool: - return self.grpc_handler.send_model_validation( + def create_validation_message(self, metrics: dict, request: fedn.TaskRequest): + return self.grpc_handler.create_validation_message( sender_name=self.name, receiver_name=request.sender.name, receiver_role=request.sender.role, @@ -372,8 +331,8 @@ def send_model_validation(self, metrics: dict, request: fedn.TaskRequest) -> boo session_id=request.session_id, ) - def send_model_prediction(self, prediction: dict, request: fedn.TaskRequest) -> bool: - return self.grpc_handler.send_model_prediction( + def create_prediction_message(self, prediction: dict, request: fedn.TaskRequest): + return self.grpc_handler.create_prediction_message( sender_name=self.name, receiver_name=request.sender.name, receiver_role=request.sender.role, @@ -383,6 +342,39 @@ def send_model_prediction(self, prediction: dict, request: fedn.TaskRequest) -> session_id=request.session_id, ) + def set_name(self, name: str): + logger.info(f"Setting client name to: {name}") + self.name = name + + def set_client_id(self, client_id: str): + logger.info(f"Setting client ID to: {client_id}") + self.client_id = client_id + + def run(self): + threading.Thread(target=self.send_heartbeats, kwargs={"client_name": self.name, "client_id": self.client_id}, daemon=True).start() + try: + self.listen_to_task_stream(client_name=self.name, client_id=self.client_id) + except KeyboardInterrupt: + logger.info("Client stopped by user.") + + def get_model_from_combiner(self, id: str, client_name: str, timeout: int = 20) -> BytesIO: + return self.grpc_handler.get_model_from_combiner(id=id, client_name=client_name, timeout=timeout) + + def send_model_to_combiner(self, model: BytesIO, id: str): + return self.grpc_handler.send_model_to_combiner(model, id) + + def send_status(self, msg: str, log_level=fedn.Status.INFO, type=None, request=None, sesssion_id: str = None, sender_name: str = None): + return self.grpc_handler.send_status(msg, log_level, type, request, sesssion_id, sender_name) + + def send_model_update(self, update: fedn.ModelUpdate) -> bool: + return self.grpc_handler.send_model_update(update) + + def send_model_validation(self, validation: fedn.ModelValidation) -> bool: + return self.grpc_handler.send_model_validation(validation) + + def send_model_prediction(self, prediction: fedn.ModelPrediction) -> bool: + return self.grpc_handler.send_model_prediction(prediction) + # Init functions def init_remote_compute_package(self, url: str, token: str, package_checksum: str = None) -> bool: result: bool = self.download_compute_package(url, token) diff --git a/fedn/network/clients/grpc_handler.py b/fedn/network/clients/grpc_handler.py index 644ddf50b..0d3f472ea 100644 --- a/fedn/network/clients/grpc_handler.py +++ b/fedn/network/clients/grpc_handler.py @@ -250,7 +250,7 @@ def send_model_to_combiner(self, model: BytesIO, id: str): return result - def send_model_update( + def create_update_message( self, sender_name: str, model_id: str, @@ -270,22 +270,18 @@ def send_model_update( update.timestamp = str(datetime.now()) update.meta = json.dumps(meta) - try: - logger.info("Sending model update to combiner.") - _ = self.combinerStub.SendModelUpdate(update, metadata=self.metadata) - except grpc.RpcError as e: - return self._handle_grpc_error( - e, "SendModelUpdate", lambda: self.send_model_update(sender_name, model_id, model_update_id, receiver_name, receiver_role, meta) - ) - except Exception as e: - logger.error(f"GRPC (SendModelUpdate): An error occurred: {e}") - self._disconnect() - - return True + return update - def send_model_validation( - self, sender_name: str, receiver_name: str, receiver_role: fedn.Role, model_id: str, metrics: str, correlation_id: str, session_id: str - ) -> bool: + def create_validation_message( + self, + sender_name: str, + receiver_name: str, + receiver_role: fedn.Role, + model_id: str, + metrics: str, + correlation_id: str, + session_id: str, + ): validation = fedn.ModelValidation() validation.sender.name = sender_name validation.sender.role = fedn.WORKER @@ -297,6 +293,44 @@ def send_model_validation( validation.correlation_id = correlation_id validation.session_id = session_id + return validation + + def create_prediction_message( + self, + sender_name: str, + receiver_name: str, + receiver_role: fedn.Role, + model_id: str, + prediction_output: str, + correlation_id: str, + session_id: str, + ): + prediction = fedn.ModelPrediction() + prediction.sender.name = sender_name + prediction.sender.role = fedn.WORKER + prediction.receiver.name = receiver_name + prediction.receiver.role = receiver_role + prediction.model_id = model_id + prediction.data = prediction_output + prediction.timestamp.GetCurrentTime() + prediction.correlation_id = correlation_id + prediction.prediction_id = session_id + + return prediction + + def send_model_update(self, update: fedn.ModelUpdate): + try: + logger.info("Sending model update to combiner.") + _ = self.combinerStub.SendModelUpdate(update, metadata=self.metadata) + except grpc.RpcError as e: + return self._handle_grpc_error(e, "SendModelUpdate", lambda: self.send_model_update(update)) + except Exception as e: + logger.error(f"GRPC (SendModelUpdate): An error occurred: {e}") + self._disconnect() + + return True + + def send_model_validation(self, validation: fedn.ModelValidation) -> bool: try: logger.info("Sending model validation to combiner.") _ = self.combinerStub.SendModelValidation(validation, metadata=self.metadata) @@ -304,7 +338,7 @@ def send_model_validation( return self._handle_grpc_error( e, "SendModelValidation", - lambda: self.send_model_validation(sender_name, receiver_name, receiver_role, model_id, metrics, correlation_id, session_id), + lambda: self.send_model_validation(validation), ) except Exception as e: logger.error(f"GRPC (SendModelValidation): An error occurred: {e}") @@ -312,20 +346,7 @@ def send_model_validation( return True - def send_model_prediction( - self, sender_name: str, receiver_name: str, receiver_role: fedn.Role, model_id: str, prediction_output: str, correlation_id: str, session_id: str - ) -> bool: - prediction = fedn.ModelPrediction() - prediction.sender.name = sender_name - prediction.sender.role = fedn.WORKER - prediction.receiver.name = receiver_name - prediction.receiver.role = receiver_role - prediction.model_id = model_id - prediction.data = prediction_output - prediction.timestamp.GetCurrentTime() - prediction.correlation_id = correlation_id - prediction.prediction_id = session_id - + def send_model_prediction(self, prediction: fedn.ModelPrediction) -> bool: try: logger.info("Sending model prediction to combiner.") _ = self.combinerStub.SendModelPrediction(prediction, metadata=self.metadata) @@ -333,7 +354,7 @@ def send_model_prediction( return self._handle_grpc_error( e, "SendModelPrediction", - lambda: self.send_model_prediction(sender_name, receiver_name, receiver_role, model_id, prediction_output, correlation_id, session_id), + lambda: self.send_model_prediction(prediction), ) except Exception as e: logger.error(f"GRPC (SendModelPrediction): An error occurred: {e}") From d3b6998a557f3cfaecf43477e0ea893aee42174a Mon Sep 17 00:00:00 2001 From: benjaminastrand Date: Tue, 5 Nov 2024 14:52:25 +0100 Subject: [PATCH 09/18] Update README - remove printouts and round_id --- fedn/network/clients/README.rst | 27 +++++++++++---------------- 1 file changed, 11 insertions(+), 16 deletions(-) diff --git a/fedn/network/clients/README.rst b/fedn/network/clients/README.rst index 013a6a75a..5dbc63ff0 100644 --- a/fedn/network/clients/README.rst +++ b/fedn/network/clients/README.rst @@ -25,8 +25,6 @@ Step-by-Step Instructions .. code-block:: python import argparse - import json - import uuid from fedn.network.clients.client_api import ClientAPI, ConnectToApiResult @@ -45,14 +43,7 @@ Step-by-Step Instructions "lr": 1, } - config = { - "round_id": 1, - } - - metadata = { - "training_metadata": training_metadata, - "config": json.dumps(config), - } + metadata = {"training_metadata": training_metadata} # Do your training here, out_model is your result... out_model = in_model @@ -60,7 +51,6 @@ Step-by-Step Instructions return out_model, metadata def on_validate(in_model): - # Calculate metrics here... metrics = { "test_accuracy": 0.9, @@ -70,12 +60,17 @@ Step-by-Step Instructions } return metrics - def main(api_url: str, api_port: int, token: str = None): - print(f"API URL: {api_url}") - print(f"API Token: {token or "-"}") - print(f"API Port: {api_port or "-"}") + def on_predict(in_model): + # Do your prediction here... + prediction = { + "prediction": 1, + "confidence": 0.9, + } + return prediction - client_api = ClientAPI(train_callback=on_train, validate_callback=on_validate) + + def main(api_url: str, api_port: int, token: str = None): + client_api = ClientAPI(train_callback=on_train, validate_callback=on_validate, predict_callback=on_predict) url = get_api_url(api_url, api_port) From 9b0e82a3b056218c01754ec74900244dc42beaae Mon Sep 17 00:00:00 2001 From: benjaminastrand Date: Wed, 6 Nov 2024 17:31:53 +0100 Subject: [PATCH 10/18] Changed name ClientAPI -> FednClient --- fedn/__init__.py | 1 + fedn/network/clients/README.rst | 14 +++++----- fedn/network/clients/client_v2.py | 26 +++++++++---------- .../clients/{client_api.py => fedn_client.py} | 2 +- 4 files changed, 22 insertions(+), 21 deletions(-) rename fedn/network/clients/{client_api.py => fedn_client.py} (99%) diff --git a/fedn/__init__.py b/fedn/__init__.py index 703eab7b2..122c2e2bb 100644 --- a/fedn/__init__.py +++ b/fedn/__init__.py @@ -3,6 +3,7 @@ from os.path import basename, dirname, isfile from fedn.network.api.client import APIClient +from fedn.network.clients.fedn_client import FednClient # flake8: noqa diff --git a/fedn/network/clients/README.rst b/fedn/network/clients/README.rst index 5dbc63ff0..af5e2ac97 100644 --- a/fedn/network/clients/README.rst +++ b/fedn/network/clients/README.rst @@ -26,7 +26,7 @@ Step-by-Step Instructions import argparse - from fedn.network.clients.client_api import ClientAPI, ConnectToApiResult + from fedn.network.clients.fedn_client import FednClient, ConnectToApiResult def get_api_url(api_url: str, api_port: int): @@ -70,15 +70,15 @@ Step-by-Step Instructions def main(api_url: str, api_port: int, token: str = None): - client_api = ClientAPI(train_callback=on_train, validate_callback=on_validate, predict_callback=on_predict) + fedn_client = FednClient(train_callback=on_train, validate_callback=on_validate, predict_callback=on_predict) url = get_api_url(api_url, api_port) name = input("Enter Client Name: ") - client_api.set_name(name) + fedn_client.set_name(name) client_id = str(uuid.uuid4()) - client_api.set_client_id(client_id) + fedn_client.set_client_id(client_id) controller_config = { "name": name, @@ -87,18 +87,18 @@ Step-by-Step Instructions "preferred_combiner": "", } - result, combiner_config = client_api.connect_to_api(url, token, controller_config) + result, combiner_config = fedn_client.connect_to_api(url, token, controller_config) if result != ConnectToApiResult.Assigned: print("Failed to connect to API, exiting.") return - result: bool = client_api.init_grpchandler(config=combiner_config, client_name=client_id, token=token) + result: bool = fedn_client.init_grpchandler(config=combiner_config, client_name=client_id, token=token) if not result: return - client_api.run() + fedn_client.run() if __name__ == "__main__": parser = argparse.ArgumentParser(description="Client Example") diff --git a/fedn/network/clients/client_v2.py b/fedn/network/clients/client_v2.py index fa1bf8e44..6d1f52fb4 100644 --- a/fedn/network/clients/client_v2.py +++ b/fedn/network/clients/client_v2.py @@ -8,7 +8,7 @@ from fedn.common.config import FEDN_CUSTOM_URL_PREFIX from fedn.common.log_config import logger -from fedn.network.clients.client_api import ClientAPI, ConnectToApiResult, GrpcConnectionOptions +from fedn.network.clients.fedn_client import ConnectToApiResult, FednClient, GrpcConnectionOptions from fedn.network.combiner.modelservice import get_tmp_path from fedn.utils.helpers.helpers import get_helper @@ -65,7 +65,7 @@ def __init__( self.fedn_api_url = get_url(self.api_url, self.api_port) - self.client_api: ClientAPI = ClientAPI() + self.fedn_client: FednClient = FednClient() self.helper = None @@ -76,7 +76,7 @@ def _connect_to_api(self) -> Tuple[bool, dict]: if result == ConnectToApiResult.ComputePackageMissing: logger.info("Retrying in 3 seconds") time.sleep(3) - result, response = self.client_api.connect_to_api(self.fedn_api_url, self.token, self.client_obj.to_json()) + result, response = self.fedn_client.connect_to_api(self.fedn_api_url, self.token, self.client_obj.to_json()) if result == ConnectToApiResult.Assigned: return True, response @@ -95,32 +95,32 @@ def start(self): return if self.client_obj.package == "remote": - result = self.client_api.init_remote_compute_package(url=self.fedn_api_url, token=self.token, package_checksum=self.package_checksum) + result = self.fedn_client.init_remote_compute_package(url=self.fedn_api_url, token=self.token, package_checksum=self.package_checksum) if not result: return else: - result = self.client_api.init_local_compute_package() + result = self.fedn_client.init_local_compute_package() if not result: return self.set_helper(combiner_config) - result: bool = self.client_api.init_grpchandler(config=combiner_config, client_name=self.client_obj.client_id, token=self.token) + result: bool = self.fedn_client.init_grpchandler(config=combiner_config, client_name=self.client_obj.client_id, token=self.token) if not result: return logger.info("-----------------------------") - self.client_api.set_train_callback(self.on_train) - self.client_api.set_validate_callback(self.on_validation) + self.fedn_client.set_train_callback(self.on_train) + self.fedn_client.set_validate_callback(self.on_validation) - self.client_api.set_name(self.client_obj.name) - self.client_api.set_client_id(self.client_obj.client_id) + self.fedn_client.set_name(self.client_obj.name) + self.fedn_client.set_client_id(self.client_obj.client_id) - self.client_api.run() + self.fedn_client.run() def set_helper(self, response: GrpcConnectionOptions = None): helper_type = response.get("helper_type", None) @@ -160,7 +160,7 @@ def _process_training_request(self, in_model: BytesIO) -> Tuple[BytesIO, dict]: tic = time.time() - self.client_api.dispatcher.run_cmd("train {} {}".format(inpath, outpath)) + self.fedn_client.dispatcher.run_cmd("train {} {}".format(inpath, outpath)) meta["exec_training"] = time.time() - tic @@ -202,7 +202,7 @@ def _process_validation_request(self, in_model: BytesIO) -> dict: fh.write(in_model.getbuffer()) outpath = get_tmp_path() - self.client_api.dispatcher.run_cmd(f"validate {inpath} {outpath}") + self.fedn_client.dispatcher.run_cmd(f"validate {inpath} {outpath}") with open(outpath, "r") as fh: metrics = json.loads(fh.read()) diff --git a/fedn/network/clients/client_api.py b/fedn/network/clients/fedn_client.py similarity index 99% rename from fedn/network/clients/client_api.py rename to fedn/network/clients/fedn_client.py index 36294ab06..aab3df10e 100644 --- a/fedn/network/clients/client_api.py +++ b/fedn/network/clients/fedn_client.py @@ -53,7 +53,7 @@ def get_compute_package_dir_path(): return result -class ClientAPI: +class FednClient: def __init__(self, train_callback: callable = None, validate_callback: callable = None, predict_callback: callable = None): self.train_callback: callable = train_callback self.validate_callback: callable = validate_callback From 1d3cb0a6623965e36a113d906fa91c58f59fa8a4 Mon Sep 17 00:00:00 2001 From: benjaminastrand Date: Thu, 7 Nov 2024 14:24:54 +0100 Subject: [PATCH 11/18] Fix - replace client_name with client_id in get_model_from_combiner --- fedn/network/clients/fedn_client.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/fedn/network/clients/fedn_client.py b/fedn/network/clients/fedn_client.py index a6a26b245..828758131 100644 --- a/fedn/network/clients/fedn_client.py +++ b/fedn/network/clients/fedn_client.py @@ -210,7 +210,7 @@ def update_local_model(self, request): model_update_id = str(uuid.uuid4()) tic = time.time() - in_model = self.get_model_from_combiner(id=model_id, client_name=self.name) + in_model = self.get_model_from_combiner(id=model_id, client_id=self.client_id) if in_model is None: logger.error("Could not retrieve model from combiner. Aborting training request.") @@ -254,7 +254,7 @@ def validate_global_model(self, request): self.send_status(f"Processing validate request for model_id {model_id}", sesssion_id=request.session_id, sender_name=self.name) - in_model = self.get_model_from_combiner(id=model_id, client_name=self.name) + in_model = self.get_model_from_combiner(id=model_id, client_id=self.client_id) if in_model is None: logger.error("Could not retrieve model from combiner. Aborting validation request.") @@ -293,7 +293,7 @@ def validate_global_model(self, request): def predict_global_model(self, request): model_id = request.model_id - model = self.get_model_from_combiner(id=model_id, client_name=self.name) + model = self.get_model_from_combiner(id=model_id, client_id=self.client_id) if model is None: logger.error("Could not retrieve model from combiner. Aborting prediction request.") @@ -358,7 +358,7 @@ def run(self): logger.info("Client stopped by user.") def get_model_from_combiner(self, id: str, client_id: str, timeout: int = 20) -> BytesIO: - return self.grpc_handler.get_model_from_combiner(id=id, client_name=client_id, timeout=timeout) + return self.grpc_handler.get_model_from_combiner(id=id, client_id=client_id, timeout=timeout) def send_model_to_combiner(self, model: BytesIO, id: str): return self.grpc_handler.send_model_to_combiner(model, id) From c25bb2220dcea2ac15885e88b34e73b0b8f1454d Mon Sep 17 00:00:00 2001 From: Fredrik Wrede Date: Thu, 7 Nov 2024 15:19:52 +0000 Subject: [PATCH 12/18] change start-v2 -> start --- fedn/cli/client_cmd.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/fedn/cli/client_cmd.py b/fedn/cli/client_cmd.py index b6c33d8af..5f5c979fa 100644 --- a/fedn/cli/client_cmd.py +++ b/fedn/cli/client_cmd.py @@ -70,7 +70,7 @@ def list_clients(ctx, protocol: str, host: str, port: str, token: str = None, n_ click.echo(f"Error: Could not connect to {url}") -@client_cmd.command("start") +@client_cmd.command("start-v1") @click.option("-d", "--discoverhost", required=False, help="Hostname for discovery services(reducer).") @click.option("-p", "--discoverport", required=False, help="Port for discovery services (reducer).") @click.option("--token", required=False, help="Set token provided by reducer if enabled") @@ -208,7 +208,7 @@ def _complement_client_params(config: dict): click.echo(f"Protocol missing, complementing api_url with protocol: {result}") -@client_cmd.command("start-v2") +@client_cmd.command("start") @click.option("-u", "--api-url", required=False, help="Hostname for fedn api.") @click.option("-p", "--api-port", required=False, help="Port for discovery services (reducer).") @click.option("--token", required=False, help="Set token provided by reducer if enabled") From 7e11b81328f65b5ea38911bf7dd8b21ca5c4bac1 Mon Sep 17 00:00:00 2001 From: Fredrik Wrede Date: Fri, 8 Nov 2024 11:09:50 +0000 Subject: [PATCH 13/18] fix print_logs --- .ci/tests/examples/print_logs.sh | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.ci/tests/examples/print_logs.sh b/.ci/tests/examples/print_logs.sh index 4cb7f650e..27823f126 100755 --- a/.ci/tests/examples/print_logs.sh +++ b/.ci/tests/examples/print_logs.sh @@ -1,7 +1,7 @@ #!/bin/bash -service = $1 -example = $2 -helper = $3 +service = "$1" +example = "$2" +helper = "$3" if [ "$service" == "minio" ]; then echo "Minio logs" From 9bdb95befefbcb213dccd5646d5a2dce855b9fdf Mon Sep 17 00:00:00 2001 From: Fredrik Wrede Date: Fri, 8 Nov 2024 11:11:17 +0000 Subject: [PATCH 14/18] fix --- .ci/tests/examples/print_logs.sh | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.ci/tests/examples/print_logs.sh b/.ci/tests/examples/print_logs.sh index 27823f126..e2b079026 100755 --- a/.ci/tests/examples/print_logs.sh +++ b/.ci/tests/examples/print_logs.sh @@ -1,7 +1,7 @@ #!/bin/bash -service = "$1" -example = "$2" -helper = "$3" +service="$1" +example="$2" +helper="$3" if [ "$service" == "minio" ]; then echo "Minio logs" From fd2b3df62dfb36d7c109c1aa7fbef73161dcd5ed Mon Sep 17 00:00:00 2001 From: Fredrik Wrede Date: Fri, 8 Nov 2024 11:23:54 +0000 Subject: [PATCH 15/18] fix --- docker-compose.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docker-compose.yaml b/docker-compose.yaml index e92a018bc..09d24e927 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -143,7 +143,7 @@ services: - ${HOST_REPO_DIR:-.}/fedn:/app/fedn entrypoint: [ "sh", "-c" ] command: - - "/venv/bin/pip install --no-cache-dir -e . && /venv/bin/fedn client start --init config/settings-client.yaml" + - "/venv/bin/pip install --no-cache-dir -e . && /venv/bin/fedn client start --api-url localhost --api-port 8092" deploy: replicas: 0 depends_on: From 082bc8ce9e6d5a4ff0483cbcb5624fdd08cb40aa Mon Sep 17 00:00:00 2001 From: Fredrik Wrede Date: Fri, 8 Nov 2024 14:00:08 +0000 Subject: [PATCH 16/18] remove echo --- fedn/cli/client_cmd.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/fedn/cli/client_cmd.py b/fedn/cli/client_cmd.py index 5f5c979fa..666bc6545 100644 --- a/fedn/cli/client_cmd.py +++ b/fedn/cli/client_cmd.py @@ -239,8 +239,6 @@ def client_start_v2_cmd( helper_type: str, init: str, ): - click.echo(click.style("\n*** fedn client start-v2 is experimental ***\n", blink=True, bold=True, fg="red")) - package = "local" if local_package else "remote" config = { From 21f9d05e2f01fbe1a2da054633027d3ffc709721 Mon Sep 17 00:00:00 2001 From: Fredrik Wrede Date: Fri, 8 Nov 2024 14:53:58 +0000 Subject: [PATCH 17/18] fix --- docker-compose.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docker-compose.yaml b/docker-compose.yaml index 09d24e927..f020d9a0c 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -143,7 +143,7 @@ services: - ${HOST_REPO_DIR:-.}/fedn:/app/fedn entrypoint: [ "sh", "-c" ] command: - - "/venv/bin/pip install --no-cache-dir -e . && /venv/bin/fedn client start --api-url localhost --api-port 8092" + - "/venv/bin/pip install --no-cache-dir -e . && /venv/bin/fedn client start --api-url http://api-server:8092" deploy: replicas: 0 depends_on: From 38969e84bc2320645c20c61d5d56624eeb2f3d70 Mon Sep 17 00:00:00 2001 From: Fredrik Wrede Date: Mon, 11 Nov 2024 13:03:49 +0000 Subject: [PATCH 18/18] fix endpoint in wait_for --- .ci/tests/examples/wait_for.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.ci/tests/examples/wait_for.py b/.ci/tests/examples/wait_for.py index 1564eadfe..3c4fa9383 100644 --- a/.ci/tests/examples/wait_for.py +++ b/.ci/tests/examples/wait_for.py @@ -40,7 +40,7 @@ def _test_rounds(n_rounds): def _test_nodes(n_nodes, node_type, reducer_host='localhost', reducer_port='8092'): try: - endpoint = "list_clients" if node_type == "client" else "list_combiners" + endpoint = "api/v1/clients/" if node_type == "client" else "api/v1/combiners/" response = requests.get( f'http://{reducer_host}:{reducer_port}/{endpoint}', verify=False)