diff --git a/.ci/tests/examples/print_logs.sh b/.ci/tests/examples/print_logs.sh index 4cb7f650e..e2b079026 100755 --- a/.ci/tests/examples/print_logs.sh +++ b/.ci/tests/examples/print_logs.sh @@ -1,7 +1,7 @@ #!/bin/bash -service = $1 -example = $2 -helper = $3 +service="$1" +example="$2" +helper="$3" if [ "$service" == "minio" ]; then echo "Minio logs" diff --git a/.ci/tests/examples/wait_for.py b/.ci/tests/examples/wait_for.py index 1564eadfe..3c4fa9383 100644 --- a/.ci/tests/examples/wait_for.py +++ b/.ci/tests/examples/wait_for.py @@ -40,7 +40,7 @@ def _test_rounds(n_rounds): def _test_nodes(n_nodes, node_type, reducer_host='localhost', reducer_port='8092'): try: - endpoint = "list_clients" if node_type == "client" else "list_combiners" + endpoint = "api/v1/clients/" if node_type == "client" else "api/v1/combiners/" response = requests.get( f'http://{reducer_host}:{reducer_port}/{endpoint}', verify=False) diff --git a/docker-compose.yaml b/docker-compose.yaml index e92a018bc..f020d9a0c 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -143,7 +143,7 @@ services: - ${HOST_REPO_DIR:-.}/fedn:/app/fedn entrypoint: [ "sh", "-c" ] command: - - "/venv/bin/pip install --no-cache-dir -e . && /venv/bin/fedn client start --init config/settings-client.yaml" + - "/venv/bin/pip install --no-cache-dir -e . 
&& /venv/bin/fedn client start --api-url http://api-server:8092" deploy: replicas: 0 depends_on: diff --git a/fedn/__init__.py b/fedn/__init__.py index 703eab7b2..122c2e2bb 100644 --- a/fedn/__init__.py +++ b/fedn/__init__.py @@ -3,6 +3,7 @@ from os.path import basename, dirname, isfile from fedn.network.api.client import APIClient +from fedn.network.clients.fedn_client import FednClient # flake8: noqa diff --git a/fedn/cli/client_cmd.py b/fedn/cli/client_cmd.py index b6c33d8af..666bc6545 100644 --- a/fedn/cli/client_cmd.py +++ b/fedn/cli/client_cmd.py @@ -70,7 +70,7 @@ def list_clients(ctx, protocol: str, host: str, port: str, token: str = None, n_ click.echo(f"Error: Could not connect to {url}") -@client_cmd.command("start") +@client_cmd.command("start-v1") @click.option("-d", "--discoverhost", required=False, help="Hostname for discovery services(reducer).") @click.option("-p", "--discoverport", required=False, help="Port for discovery services (reducer).") @click.option("--token", required=False, help="Set token provided by reducer if enabled") @@ -208,7 +208,7 @@ def _complement_client_params(config: dict): click.echo(f"Protocol missing, complementing api_url with protocol: {result}") -@client_cmd.command("start-v2") +@client_cmd.command("start") @click.option("-u", "--api-url", required=False, help="Hostname for fedn api.") @click.option("-p", "--api-port", required=False, help="Port for discovery services (reducer).") @click.option("--token", required=False, help="Set token provided by reducer if enabled") @@ -239,8 +239,6 @@ def client_start_v2_cmd( helper_type: str, init: str, ): - click.echo(click.style("\n*** fedn client start-v2 is experimental ***\n", blink=True, bold=True, fg="red")) - package = "local" if local_package else "remote" config = { diff --git a/fedn/network/clients/README.rst b/fedn/network/clients/README.rst new file mode 100644 index 000000000..af5e2ac97 --- /dev/null +++ b/fedn/network/clients/README.rst @@ -0,0 +1,121 @@ +Creating 
Your Own Client +======================== + +This guide will help you create your own client for the FEDn network. + +Step-by-Step Instructions +------------------------- + +1. **Create a virtual environment**: Start by creating a virtual environment and activating it. + +.. code-block:: bash + + python3 -m venv fedn-env + source fedn-env/bin/activate + + +2. **Install FEDn**: Install the FEDn package by running the following command: + +.. code-block:: bash + + pip install fedn + +3. **Create your client**: Copy and paste the code below into a new file called `client_example.py`. + +.. code-block:: python + + import argparse, uuid + + from fedn.network.clients.fedn_client import FednClient, ConnectToApiResult + + + def get_api_url(api_url: str, api_port: int): + url = f"{api_url}:{api_port}" if api_port else api_url + if not url.endswith("/"): + url += "/" + return url + + def on_train(in_model): + training_metadata = { + "num_examples": 1, + "batch_size": 1, + "epochs": 1, + "lr": 1, + } + + metadata = {"training_metadata": training_metadata} + + # Do your training here, out_model is your result... + out_model = in_model + + return out_model, metadata + + def on_validate(in_model): + # Calculate metrics here... + metrics = { + "test_accuracy": 0.9, + "test_loss": 0.1, + "train_accuracy": 0.8, + "train_loss": 0.2, + } + return metrics + + def on_predict(in_model): + # Do your prediction here... 
+ prediction = { + "prediction": 1, + "confidence": 0.9, + } + return prediction + + + def main(api_url: str, api_port: int, token: str = None): + fedn_client = FednClient(train_callback=on_train, validate_callback=on_validate, predict_callback=on_predict) + + url = get_api_url(api_url, api_port) + + name = input("Enter Client Name: ") + fedn_client.set_name(name) + + client_id = str(uuid.uuid4()) + fedn_client.set_client_id(client_id) + + controller_config = { + "name": name, + "client_id": client_id, + "package": "local", + "preferred_combiner": "", + } + + result, combiner_config = fedn_client.connect_to_api(url, token, controller_config) + + if result != ConnectToApiResult.Assigned: + print("Failed to connect to API, exiting.") + return + + result: bool = fedn_client.init_grpchandler(config=combiner_config, client_name=client_id, token=token) + + if not result: + return + + fedn_client.run() + + if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Client Example") + parser.add_argument("--api-url", type=str, required=True, help="The API URL") + parser.add_argument("--api-port", type=int, required=False, help="The API Port") + parser.add_argument("--token", type=str, required=False, help="The API Token") + + args = parser.parse_args() + main(args.api_url, args.api_port, args.token) + + +4. **Run the client**: Run the client by executing the following command: + +.. code-block:: bash + + python client_example.py --api-url <api-url> --token <token> + +Replace ``<api-url>`` and ``<token>`` with the URL and token of the FEDn API. *Example when running a local FEDn instance: python client_example.py --api-url http://localhost:8092* + +5. **Start training**: Create a session and start training by using either the FEDn CLI or the FEDn UI. 
diff --git a/fedn/network/clients/client_v2.py b/fedn/network/clients/client_v2.py index 7f5ee93cf..6d1f52fb4 100644 --- a/fedn/network/clients/client_v2.py +++ b/fedn/network/clients/client_v2.py @@ -1,15 +1,14 @@ import io import json import os -import threading import time import uuid +from io import BytesIO from typing import Tuple -import fedn.network.grpc.fedn_pb2 as fedn from fedn.common.config import FEDN_CUSTOM_URL_PREFIX from fedn.common.log_config import logger -from fedn.network.clients.client_api import ClientAPI, ConnectToApiResult, GrpcConnectionOptions +from fedn.network.clients.fedn_client import ConnectToApiResult, FednClient, GrpcConnectionOptions from fedn.network.combiner.modelservice import get_tmp_path from fedn.utils.helpers.helpers import get_helper @@ -66,18 +65,18 @@ def __init__( self.fedn_api_url = get_url(self.api_url, self.api_port) - self.client_api: ClientAPI = ClientAPI() + self.fedn_client: FednClient = FednClient() self.helper = None def _connect_to_api(self) -> Tuple[bool, dict]: result = None - while not result or result == ConnectToApiResult.ComputePackgeMissing: - if result == ConnectToApiResult.ComputePackgeMissing: + while not result or result == ConnectToApiResult.ComputePackageMissing: + if result == ConnectToApiResult.ComputePackageMissing: logger.info("Retrying in 3 seconds") time.sleep(3) - result, response = self.client_api.connect_to_api(self.fedn_api_url, self.token, self.client_obj.to_json()) + result, response = self.fedn_client.connect_to_api(self.fedn_api_url, self.token, self.client_obj.to_json()) if result == ConnectToApiResult.Assigned: return True, response @@ -96,41 +95,32 @@ def start(self): return if self.client_obj.package == "remote": - result = self.client_api.init_remote_compute_package(url=self.fedn_api_url, token=self.token, package_checksum=self.package_checksum) + result = self.fedn_client.init_remote_compute_package(url=self.fedn_api_url, token=self.token, package_checksum=self.package_checksum) 
if not result: return else: - result = self.client_api.init_local_compute_package() + result = self.fedn_client.init_local_compute_package() if not result: return self.set_helper(combiner_config) - result: bool = self.client_api.init_grpchandler(config=combiner_config, client_name=self.client_obj.client_id, token=self.token) + result: bool = self.fedn_client.init_grpchandler(config=combiner_config, client_name=self.client_obj.client_id, token=self.token) if not result: return logger.info("-----------------------------") - threading.Thread( - target=self.client_api.send_heartbeats, kwargs={"client_name": self.client_obj.name, "client_id": self.client_obj.client_id}, daemon=True - ).start() + self.fedn_client.set_train_callback(self.on_train) + self.fedn_client.set_validate_callback(self.on_validation) - self.client_api.subscribe("train", self.on_train) - self.client_api.subscribe("validation", self.on_validation) + self.fedn_client.set_name(self.client_obj.name) + self.fedn_client.set_client_id(self.client_obj.client_id) - threading.Thread( - target=self.client_api.listen_to_task_stream, kwargs={"client_name": self.client_obj.name, "client_id": self.client_obj.client_id}, daemon=True - ).start() - - stop_event = threading.Event() - try: - stop_event.wait() - except KeyboardInterrupt: - logger.info("Client stopped by user.") + self.fedn_client.run() def set_helper(self, response: GrpcConnectionOptions = None): helper_type = response.get("helper_type", None) @@ -142,67 +132,43 @@ def set_helper(self, response: GrpcConnectionOptions = None): # Priority: helper_type from constructor > helper_type from response > default helper_type self.helper = get_helper(helper_type_to_use) - def on_train(self, request): - logger.info("Received train request") - self._process_training_request(request) + def on_train(self, in_model): + out_model, meta = self._process_training_request(in_model) + return out_model, meta - def on_validation(self, request): - logger.info("Received 
validation request") - self._process_validation_request(request) + def on_validation(self, in_model): + metrics = self._process_validation_request(in_model) + return metrics - def _process_training_request(self, request) -> Tuple[str, dict]: + def _process_training_request(self, in_model: BytesIO) -> Tuple[BytesIO, dict]: """Process a training (model update) request. - :param model_id: The model id of the model to be updated. - :type model_id: str - :param session_id: The id of the current session - :type session_id: str - :return: The model id of the updated model, or None if the update failed. And a dict with metadata. + :param in_model: The model to be updated. + :type in_model: BytesIO + :return: The updated model, or None if the update failed. And a dict with metadata. :rtype: tuple """ - model_id: str = request.model_id - session_id: str = request.session_id - - self.client_api.send_status( - f"\t Starting processing of training request for model_id {model_id}", sesssion_id=session_id, sender_name=self.client_obj.name - ) - try: meta = {} - tic = time.time() - - model = self.client_api.get_model_from_combiner(id=str(model_id), client_id=self.client_obj.client_id) - - if model is None: - logger.error("Could not retrieve model from combiner. 
Aborting training request.") - return None, None - - meta["fetch_model"] = time.time() - tic inpath = self.helper.get_tmp_path() with open(inpath, "wb") as fh: - fh.write(model.getbuffer()) + fh.write(in_model.getbuffer()) outpath = self.helper.get_tmp_path() tic = time.time() - self.client_api.dispatcher.run_cmd("train {} {}".format(inpath, outpath)) + self.fedn_client.dispatcher.run_cmd("train {} {}".format(inpath, outpath)) meta["exec_training"] = time.time() - tic - tic = time.time() out_model = None with open(outpath, "rb") as fr: out_model = io.BytesIO(fr.read()) - # Stream model update to combiner server - updated_model_id = uuid.uuid4() - self.client_api.send_model_to_combiner(out_model, str(updated_model_id)) - meta["upload_model"] = time.time() - tic - # Read the metadata file with open(outpath + "-metadata", "r") as fh: training_metadata = json.loads(fh.read()) @@ -216,65 +182,27 @@ def _process_training_request(self, request) -> Tuple[str, dict]: except Exception as e: logger.error("Could not process training request due to error: {}".format(e)) - updated_model_id = None + out_model = None meta = {"status": "failed", "error": str(e)} - if meta is not None: - processing_time = time.time() - tic - meta["processing_time"] = processing_time - meta["config"] = request.data - - if model_id is not None: - # Send model update to combiner - - self.client_api.send_model_update( - sender_name=self.client_obj.name, - sender_role=fedn.WORKER, - client_id=self.client_obj.client_id, - model_id=model_id, - model_update_id=str(updated_model_id), - receiver_name=request.sender.name, - receiver_role=request.sender.role, - meta=meta, - ) - - self.client_api.send_status( - "Model update completed.", - log_level=fedn.Status.AUDIT, - type=fedn.StatusType.MODEL_UPDATE, - request=request, - sesssion_id=session_id, - sender_name=self.client_obj.name, - ) - - def _process_validation_request(self, request): + return out_model, meta + + def _process_validation_request(self, 
in_model: BytesIO) -> dict: """Process a validation request. - :param model_id: The model id of the model to be validated. - :type model_id: str - :param session_id: The id of the current session. - :type session_id: str + :param in_model: The model to be validated. + :type in_model: BytesIO :return: The validation metrics, or None if validation failed. :rtype: dict """ - model_id: str = request.model_id - session_id: str = request.session_id - cmd = "validate" - - self.client_api.send_status(f"Processing {cmd} request for model_id {model_id}", sesssion_id=session_id, sender_name=self.client_obj.name) - try: - model = self.client_api.get_model_from_combiner(id=str(model_id), client_id=self.client_obj.client_id) - if model is None: - logger.error("Could not retrieve model from combiner. Aborting validation request.") - return inpath = self.helper.get_tmp_path() with open(inpath, "wb") as fh: - fh.write(model.getbuffer()) + fh.write(in_model.getbuffer()) outpath = get_tmp_path() - self.client_api.dispatcher.run_cmd(f"{cmd} {inpath} {outpath}") + self.fedn_client.dispatcher.run_cmd(f"validate {inpath} {outpath}") with open(outpath, "r") as fh: metrics = json.loads(fh.read()) @@ -284,45 +212,6 @@ def _process_validation_request(self, request): except Exception as e: logger.warning("Validation failed with exception {}".format(e)) + metrics = None - if metrics is not None: - # Send validation - validation = fedn.ModelValidation() - validation.sender.name = self.client_obj.name - validation.sender.role = fedn.WORKER - validation.receiver.name = request.sender.name - validation.receiver.role = request.sender.role - validation.model_id = str(request.model_id) - validation.data = json.dumps(metrics) - validation.timestamp.GetCurrentTime() - validation.correlation_id = request.correlation_id - validation.session_id = request.session_id - - # sender_name: str, sender_role: fedn.Role, model_id: str, model_update_id: str - result: bool = self.client_api.send_model_validation( - 
sender_name=self.client_obj.name, - receiver_name=request.sender.name, - receiver_role=request.sender.role, - model_id=str(request.model_id), - metrics=json.dumps(metrics), - correlation_id=request.correlation_id, - session_id=request.session_id, - ) - - if result: - self.client_api.send_status( - "Model validation completed.", - log_level=fedn.Status.AUDIT, - type=fedn.StatusType.MODEL_VALIDATION, - request=validation, - sesssion_id=request.session_id, - sender_name=self.client_obj.name, - ) - else: - self.client_api.send_status( - "Client {} failed to complete model validation.".format(self.name), - log_level=fedn.Status.WARNING, - request=request, - sesssion_id=request.session_id, - sender_name=self.client_obj.name, - ) + return metrics diff --git a/fedn/network/clients/client_api.py b/fedn/network/clients/fedn_client.py similarity index 56% rename from fedn/network/clients/client_api.py rename to fedn/network/clients/fedn_client.py index 108a5448a..828758131 100644 --- a/fedn/network/clients/client_api.py +++ b/fedn/network/clients/fedn_client.py @@ -1,6 +1,9 @@ import enum +import json import os +import threading import time +import uuid from io import BytesIO from typing import Any, Tuple @@ -28,7 +31,7 @@ def __init__(self, status: str, host: str, fqdn: str, package: str, ip: str, por # Enum for respresenting the result of connecting to the FEDn API class ConnectToApiResult(enum.Enum): Assigned = 0 - ComputePackgeMissing = 1 + ComputePackageMissing = 1 UnAuthorized = 2 UnMatchedConfig = 3 IncorrectUrl = 4 @@ -50,45 +53,26 @@ def get_compute_package_dir_path(): return result -class ClientAPI: - def __init__(self): - self._subscribers = {"train": [], "validation": []} +class FednClient: + def __init__(self, train_callback: callable = None, validate_callback: callable = None, predict_callback: callable = None): + self.train_callback: callable = train_callback + self.validate_callback: callable = validate_callback + self.predict_callback: callable = 
predict_callback + path = get_compute_package_dir_path() self._package_runtime = PackageRuntime(path) self.dispatcher: Dispatcher = None self.grpc_handler: GrpcHandler = None - def subscribe(self, event_type: str, callback: callable): - """Subscribe to a specific event.""" - if event_type in self._subscribers: - self._subscribers[event_type].append(callback) - else: - raise ValueError(f"Unsupported event type: {event_type}") - - def notify_subscribers(self, event_type: str, *args, **kwargs): - """Notify all subscribers about a specific event.""" - if event_type in self._subscribers: - for callback in self._subscribers[event_type]: - callback(*args, **kwargs) - else: - raise ValueError(f"Unsupported event type: {event_type}") + def set_train_callback(self, callback: callable): + self.train_callback = callback - def train(self, *args, **kwargs): - """Function to be triggered from the server via gRPC.""" - # Perform training logic here - logger.info("Training started") + def set_validate_callback(self, callback: callable): + self.validate_callback = callback - # Notify all subscribers about the train event - self.notify_subscribers("train", *args, **kwargs) - - def validate(self, *args, **kwargs): - """Function to be triggered from the server via gRPC.""" - # Perform validation logic here - logger.info("Validation started") - - # Notify all subscribers about the validation event - self.notify_subscribers("validation", *args, **kwargs) + def set_predict_callback(self, callback: callable): + self.predict_callback = callback def connect_to_api(self, url: str, token: str, json: dict) -> Tuple[ConnectToApiResult, Any]: # TODO: Use new API endpoint (v1) @@ -110,7 +94,7 @@ def connect_to_api(self, url: str, token: str, json: dict) -> Tuple[ConnectToApi elif response.status_code == 203: json_response = response.json() logger.info("Connect to FEDn Api - Remote compute package missing.") - return ConnectToApiResult.ComputePackgeMissing, json_response + return 
ConnectToApiResult.ComputePackageMissing, json_response elif response.status_code == 401: logger.warning("Connect to FEDn Api - Unauthorized") return ConnectToApiResult.UnAuthorized, "Unauthorized" @@ -215,12 +199,166 @@ def listen_to_task_stream(self, client_name: str, client_id: str): def _task_stream_callback(self, request): if request.type == fedn.StatusType.MODEL_UPDATE: - self.train(request) + self.update_local_model(request) elif request.type == fedn.StatusType.MODEL_VALIDATION: - self.validate(request) + self.validate_global_model(request) + elif request.type == fedn.StatusType.MODEL_PREDICTION: + self.predict_global_model(request) + + def update_local_model(self, request): + model_id = request.model_id + model_update_id = str(uuid.uuid4()) + + tic = time.time() + in_model = self.get_model_from_combiner(id=model_id, client_id=self.client_id) + + if in_model is None: + logger.error("Could not retrieve model from combiner. Aborting training request.") + return + + fetch_model_time = time.time() - tic + + if not self.train_callback: + logger.error("No train callback set") + return + + self.send_status(f"\t Starting processing of training request for model_id {model_id}", sesssion_id=request.session_id, sender_name=self.name) + + logger.info(f"Running train callback with model ID: {model_id}") + tic = time.time() + out_model, meta = self.train_callback(in_model) + meta["processing_time"] = time.time() - tic + + tic = time.time() + self.send_model_to_combiner(model=out_model, id=model_update_id) + meta["upload_model"] = time.time() - tic + + meta["fetch_model"] = fetch_model_time + meta["config"] = request.data + + update = self.create_update_message(model_id=model_id, model_update_id=model_update_id, meta=meta, request=request) + + self.send_model_update(update) + + self.send_status( + "Model update completed.", + log_level=fedn.Status.AUDIT, + type=fedn.StatusType.MODEL_UPDATE, + request=update, + sesssion_id=request.session_id, + sender_name=self.name, + ) + 
+ def validate_global_model(self, request): + model_id = request.model_id + + self.send_status(f"Processing validate request for model_id {model_id}", sesssion_id=request.session_id, sender_name=self.name) + + in_model = self.get_model_from_combiner(id=model_id, client_id=self.client_id) + + if in_model is None: + logger.error("Could not retrieve model from combiner. Aborting validation request.") + return + + if not self.validate_callback: + logger.error("No validate callback set") + return + + logger.info(f"Running validate callback with model ID: {model_id}") + metrics = self.validate_callback(in_model) + + if metrics is not None: + # Send validation + validation = self.create_validation_message(metrics=metrics, request=request) + + result: bool = self.send_model_validation(validation) + + if result: + self.send_status( + "Model validation completed.", + log_level=fedn.Status.AUDIT, + type=fedn.StatusType.MODEL_VALIDATION, + request=validation, + sesssion_id=request.session_id, + sender_name=self.name, + ) + else: + self.send_status( + "Client {} failed to complete model validation.".format(self.name), + log_level=fedn.Status.WARNING, + request=request, + sesssion_id=request.session_id, + sender_name=self.name, + ) + + def predict_global_model(self, request): + model_id = request.model_id + model = self.get_model_from_combiner(id=model_id, client_id=self.client_id) + + if model is None: + logger.error("Could not retrieve model from combiner. 
Aborting prediction request.") + return + + if not self.predict_callback: + logger.error("No predict callback set") + return + + logger.info(f"Running predict callback with model ID: {model_id}") + prediction = self.predict_callback(model) + + prediction_message = self.create_prediction_message(prediction=prediction, request=request) + + self.send_model_prediction(prediction_message) + + def create_update_message(self, model_id: str, model_update_id: str, meta: dict, request: fedn.TaskRequest): + return self.grpc_handler.create_update_message( + sender_name=self.name, + model_id=model_id, + model_update_id=model_update_id, + receiver_name=request.sender.name, + receiver_role=request.sender.role, + meta=meta, + ) + + def create_validation_message(self, metrics: dict, request: fedn.TaskRequest): + return self.grpc_handler.create_validation_message( + sender_name=self.name, + receiver_name=request.sender.name, + receiver_role=request.sender.role, + model_id=request.model_id, + metrics=json.dumps(metrics), + correlation_id=request.correlation_id, + session_id=request.session_id, + ) + + def create_prediction_message(self, prediction: dict, request: fedn.TaskRequest): + return self.grpc_handler.create_prediction_message( + sender_name=self.name, + receiver_name=request.sender.name, + receiver_role=request.sender.role, + model_id=request.model_id, + prediction_output=json.dumps(prediction), + correlation_id=request.correlation_id, + session_id=request.session_id, + ) + + def set_name(self, name: str): + logger.info(f"Setting client name to: {name}") + self.name = name + + def set_client_id(self, client_id: str): + logger.info(f"Setting client ID to: {client_id}") + self.client_id = client_id + + def run(self): + threading.Thread(target=self.send_heartbeats, kwargs={"client_name": self.name, "client_id": self.client_id}, daemon=True).start() + try: + self.listen_to_task_stream(client_name=self.name, client_id=self.client_id) + except KeyboardInterrupt: + 
logger.info("Client stopped by user.") def get_model_from_combiner(self, id: str, client_id: str, timeout: int = 20) -> BytesIO: - return self.grpc_handler.get_model_from_combiner(id=id, client_name=client_id, timeout=timeout) + return self.grpc_handler.get_model_from_combiner(id=id, client_id=client_id, timeout=timeout) def send_model_to_combiner(self, model: BytesIO, id: str): return self.grpc_handler.send_model_to_combiner(model, id) @@ -228,32 +366,14 @@ def send_model_to_combiner(self, model: BytesIO, id: str): def send_status(self, msg: str, log_level=fedn.Status.INFO, type=None, request=None, sesssion_id: str = None, sender_name: str = None): return self.grpc_handler.send_status(msg, log_level, type, request, sesssion_id, sender_name) - def send_model_update( - self, - sender_name: str, - sender_role: fedn.Role, - client_id: str, - model_id: str, - model_update_id: str, - receiver_name: str, - receiver_role: fedn.Role, - meta: dict, - ) -> bool: - return self.grpc_handler.send_model_update( - sender_name=sender_name, - sender_role=sender_role, - client_id=client_id, - model_id=model_id, - model_update_id=model_update_id, - receiver_name=receiver_name, - receiver_role=receiver_role, - meta=meta, - ) + def send_model_update(self, update: fedn.ModelUpdate) -> bool: + return self.grpc_handler.send_model_update(update) + + def send_model_validation(self, validation: fedn.ModelValidation) -> bool: + return self.grpc_handler.send_model_validation(validation) - def send_model_validation( - self, sender_name: str, receiver_name: str, receiver_role: fedn.Role, model_id: str, metrics: dict, correlation_id: str, session_id: str - ) -> bool: - return self.grpc_handler.send_model_validation(sender_name, receiver_name, receiver_role, model_id, metrics, correlation_id, session_id) + def send_model_prediction(self, prediction: fedn.ModelPrediction) -> bool: + return self.grpc_handler.send_model_prediction(prediction) # Init functions def init_remote_compute_package(self, 
url: str, token: str, package_checksum: str = None) -> bool: diff --git a/fedn/network/clients/grpc_handler.py b/fedn/network/clients/grpc_handler.py index 759161a4c..4b7d9874c 100644 --- a/fedn/network/clients/grpc_handler.py +++ b/fedn/network/clients/grpc_handler.py @@ -277,11 +277,9 @@ def send_model_to_combiner(self, model: BytesIO, id: str): return result - def send_model_update( + def create_update_message( self, sender_name: str, - sender_role: fedn.Role, - client_id: str, model_id: str, model_update_id: str, receiver_name: str, @@ -290,8 +288,8 @@ def send_model_update( ): update = fedn.ModelUpdate() update.sender.name = sender_name - update.sender.role = sender_role - update.sender.client_id = client_id + update.sender.role = fedn.WORKER + update.sender.client_id = self.metadata[0][1] update.receiver.name = receiver_name update.receiver.role = receiver_role update.model_id = model_id @@ -299,22 +297,18 @@ def send_model_update( update.timestamp = str(datetime.now()) update.meta = json.dumps(meta) - try: - logger.info("Sending model update to combiner.") - _ = self.combinerStub.SendModelUpdate(update, metadata=self.metadata) - except grpc.RpcError as e: - return self._handle_grpc_error( - e, "SendModelUpdate", lambda: self.send_model_update(sender_name, sender_role, model_id, model_update_id, receiver_name, receiver_role, meta) - ) - except Exception as e: - logger.error(f"GRPC (SendModelUpdate): An error occurred: {e}") - self._disconnect() - - return True + return update - def send_model_validation( - self, sender_name: str, receiver_name: str, receiver_role: fedn.Role, model_id: str, metrics: str, correlation_id: str, session_id: str - ) -> bool: + def create_validation_message( + self, + sender_name: str, + receiver_name: str, + receiver_role: fedn.Role, + model_id: str, + metrics: str, + correlation_id: str, + session_id: str, + ): validation = fedn.ModelValidation() validation.sender.name = sender_name validation.sender.role = fedn.WORKER @@ -326,6 
+320,44 @@ def send_model_validation( validation.correlation_id = correlation_id validation.session_id = session_id + return validation + + def create_prediction_message( + self, + sender_name: str, + receiver_name: str, + receiver_role: fedn.Role, + model_id: str, + prediction_output: str, + correlation_id: str, + session_id: str, + ): + prediction = fedn.ModelPrediction() + prediction.sender.name = sender_name + prediction.sender.role = fedn.WORKER + prediction.receiver.name = receiver_name + prediction.receiver.role = receiver_role + prediction.model_id = model_id + prediction.data = prediction_output + prediction.timestamp.GetCurrentTime() + prediction.correlation_id = correlation_id + prediction.prediction_id = session_id + + return prediction + + def send_model_update(self, update: fedn.ModelUpdate): + try: + logger.info("Sending model update to combiner.") + _ = self.combinerStub.SendModelUpdate(update, metadata=self.metadata) + except grpc.RpcError as e: + return self._handle_grpc_error(e, "SendModelUpdate", lambda: self.send_model_update(update)) + except Exception as e: + logger.error(f"GRPC (SendModelUpdate): An error occurred: {e}") + self._disconnect() + + return True + + def send_model_validation(self, validation: fedn.ModelValidation) -> bool: try: logger.info("Sending model validation to combiner.") _ = self.combinerStub.SendModelValidation(validation, metadata=self.metadata) @@ -333,7 +365,7 @@ def send_model_validation( return self._handle_grpc_error( e, "SendModelValidation", - lambda: self.send_model_validation(sender_name, receiver_name, receiver_role, model_id, metrics, correlation_id, session_id), + lambda: self.send_model_validation(validation), ) except Exception as e: logger.error(f"GRPC (SendModelValidation): An error occurred: {e}") @@ -341,6 +373,22 @@ def send_model_validation( return True + def send_model_prediction(self, prediction: fedn.ModelPrediction) -> bool: + try: + logger.info("Sending model prediction to combiner.") + _ = 
self.combinerStub.SendModelPrediction(prediction, metadata=self.metadata) + except grpc.RpcError as e: + return self._handle_grpc_error( + e, + "SendModelPrediction", + lambda: self.send_model_prediction(prediction), + ) + except Exception as e: + logger.error(f"GRPC (SendModelPrediction): An error occurred: {e}") + self._disconnect() + + return True + def _handle_grpc_error(self, e, method_name: str, sender_function: Callable): status_code = e.code() if status_code == grpc.StatusCode.UNAVAILABLE: