diff --git a/.gitignore b/.gitignore
index b595e49c9..e75a5b9e9 100644
--- a/.gitignore
+++ b/.gitignore
@@ -177,6 +177,7 @@ config/extra-hosts-reducer.yaml
 config/settings-client.yaml
 config/settings-reducer.yaml
 config/settings-combiner.yaml
+config/settings-hooks.yaml
 ./tmp/*
diff --git a/Dockerfile b/Dockerfile
index 4f5952a33..b651dbea4 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -11,6 +11,7 @@ ARG REQUIREMENTS=""
 COPY . /app
 COPY config/settings-client.yaml.template /app/config/settings-client.yaml
 COPY config/settings-combiner.yaml.template /app/config/settings-combiner.yaml
+COPY config/settings-hooks.yaml.template /app/config/settings-hooks.yaml
 COPY config/settings-reducer.yaml.template /app/config/settings-reducer.yaml
 COPY $REQUIREMENTS /app/config/requirements.txt
diff --git a/config/settings-hooks.yaml.template b/config/settings-hooks.yaml.template
new file mode 100644
index 000000000..e395b20ce
--- /dev/null
+++ b/config/settings-hooks.yaml.template
@@ -0,0 +1,8 @@
+network_id: fedn-network
+discover_host: api-server
+discover_port: 8092
+
+name: hooks
+host: hooks
+port: 12081
+max_clients: 30
\ No newline at end of file
diff --git a/docker-compose.yaml b/docker-compose.yaml
index 26291748f..e92a018bc 100644
--- a/docker-compose.yaml
+++ b/docker-compose.yaml
@@ -19,13 +19,7 @@ services:
       - MINIO_ROOT_PASSWORD=password
     command: server /data --console-address minio:9001
     healthcheck:
-      test:
-        [
-          "CMD",
-          "curl",
-          "-f",
-          "http://minio:9000/minio/health/live"
-        ]
+      test: [ "CMD", "curl", "-f", "http://minio:9000/minio/health/live" ]
       interval: 30s
       timeout: 20s
       retries: 3
@@ -89,6 +83,7 @@ services:
       - GET_HOSTS_FROM=dns
      - STATESTORE_CONFIG=/app/config/settings-combiner.yaml
       - MODELSTORAGE_CONFIG=/app/config/settings-combiner.yaml
+      - HOOK_SERVICE_HOST=hook:12081
     build:
       context: .
       args:
@@ -103,17 +98,36 @@ services:
     ports:
       - 12080:12080
     healthcheck:
-      test:
-        [
-          "CMD",
-          "/bin/grpc_health_probe",
-          "-addr=localhost:12080"
-        ]
+      test: [ "CMD", "/bin/grpc_health_probe", "-addr=localhost:12080" ]
      interval: 20s
       timeout: 10s
       retries: 5
     depends_on:
       - api-server
+      - hooks
+  # Hooks
+  hooks:
+    container_name: hook
+    environment:
+      - GET_HOSTS_FROM=dns
+    build:
+      context: .
+      args:
+        BASE_IMG: ${BASE_IMG:-python:3.10-slim}
+        GRPC_HEALTH_PROBE_VERSION: v0.4.24
+    working_dir: /app
+    volumes:
+      - ${HOST_REPO_DIR:-.}/fedn:/app/fedn
+    entrypoint: [ "sh", "-c" ]
+    command:
+      - "/venv/bin/pip install --no-cache-dir -e . && /venv/bin/fedn hooks start"
+    ports:
+      - 12081:12081
+    healthcheck:
+      test: [ "CMD", "/bin/grpc_health_probe", "-addr=localhost:12081" ]
+      interval: 20s
+      timeout: 10s
+      retries: 5
   # Client
   client:
diff --git a/examples/server-functions/.dockerignore b/examples/server-functions/.dockerignore
new file mode 100644
index 000000000..8ba9024ad
--- /dev/null
+++ b/examples/server-functions/.dockerignore
@@ -0,0 +1,4 @@
+data
+seed.npz
+*.tgz
+*.tar.gz
\ No newline at end of file
diff --git a/examples/server-functions/.gitignore b/examples/server-functions/.gitignore
new file mode 100644
index 000000000..a9f01054b
--- /dev/null
+++ b/examples/server-functions/.gitignore
@@ -0,0 +1,6 @@
+data
+*.npz
+*.tgz
+*.tar.gz
+.mnist-pytorch
+client.yaml
\ No newline at end of file
diff --git a/examples/server-functions/README.rst b/examples/server-functions/README.rst
new file mode 100644
index 000000000..c594fac28
--- /dev/null
+++ b/examples/server-functions/README.rst
@@ -0,0 +1,11 @@
+FEDn Project: Server functions toy example
+-------------------------------------------
+
+See server_functions.py for details.
+
+The README will be updated after the studio update.
+
+To run with server functions::
+
+    from server_functions import ServerFunctions
+    client.start_session(server_functions=ServerFunctions)
\ No newline at end of file
diff --git a/examples/server-functions/client/data.py b/examples/server-functions/client/data.py
new file mode 100644
index 000000000..b921f3132
--- /dev/null
+++ b/examples/server-functions/client/data.py
@@ -0,0 +1,97 @@
+import os
+from math import floor
+
+import torch
+import torchvision
+
+dir_path = os.path.dirname(os.path.realpath(__file__))
+abs_path = os.path.abspath(dir_path)
+
+
+def get_data(out_dir="data"):
+    # Make dir if necessary
+    if not os.path.exists(out_dir):
+        os.mkdir(out_dir)
+
+    # Only download if not already downloaded
+    if not os.path.exists(f"{out_dir}/train"):
+        torchvision.datasets.MNIST(root=f"{out_dir}/train", transform=torchvision.transforms.ToTensor(), train=True, download=True)
+    if not os.path.exists(f"{out_dir}/test"):
+        torchvision.datasets.MNIST(root=f"{out_dir}/test", transform=torchvision.transforms.ToTensor(), train=False, download=True)
+
+
+def load_data(data_path, is_train=True):
+    """Load data from disk.
+
+    :param data_path: Path to data file.
+    :type data_path: str
+    :param is_train: Whether to load training or test data.
+    :type is_train: bool
+    :return: Tuple of data and labels.
+    :rtype: tuple
+    """
+    if data_path is None:
+        data_path = os.environ.get("FEDN_DATA_PATH", abs_path + "/data/clients/1/mnist.pt")
+
+    data = torch.load(data_path)
+
+    if is_train:
+        X = data["x_train"]
+        y = data["y_train"]
+    else:
+        X = data["x_test"]
+        y = data["y_test"]
+
+    # Normalize
+    X = X / 255
+
+    return X, y
+
+
+def splitset(dataset, parts):
+    n = dataset.shape[0]
+    local_n = floor(n / parts)
+    result = []
+    for i in range(parts):
+        result.append(dataset[i * local_n : (i + 1) * local_n])
+    return result
+
+
+def split(out_dir="data"):
+    n_splits = int(os.environ.get("FEDN_NUM_DATA_SPLITS", 2))
+
+    # Make dir
+    if not os.path.exists(f"{out_dir}/clients"):
+        os.mkdir(f"{out_dir}/clients")
+
+    # Load and convert to dict
+    train_data = torchvision.datasets.MNIST(root=f"{out_dir}/train", transform=torchvision.transforms.ToTensor(), train=True)
+    test_data = torchvision.datasets.MNIST(root=f"{out_dir}/test", transform=torchvision.transforms.ToTensor(), train=False)
+    data = {
+        "x_train": splitset(train_data.data, n_splits),
+        "y_train": splitset(train_data.targets, n_splits),
+        "x_test": splitset(test_data.data, n_splits),
+        "y_test": splitset(test_data.targets, n_splits),
+    }
+
+    # Make splits
+    for i in range(n_splits):
+        subdir = f"{out_dir}/clients/{str(i+1)}"
+        if not os.path.exists(subdir):
+            os.mkdir(subdir)
+        torch.save(
+            {
+                "x_train": data["x_train"][i],
+                "y_train": data["y_train"][i],
+                "x_test": data["x_test"][i],
+                "y_test": data["y_test"][i],
+            },
+            f"{subdir}/mnist.pt",
+        )
+
+
+if __name__ == "__main__":
+    # Prepare data if not already done
+    if not os.path.exists(abs_path + "/data/clients/1"):
+        get_data()
+        split()
diff --git a/examples/server-functions/client/fedn.yaml b/examples/server-functions/client/fedn.yaml
new file mode 100644
index 000000000..30873488b
--- /dev/null
+++ b/examples/server-functions/client/fedn.yaml
@@ -0,0 +1,12 @@
+python_env: python_env.yaml
+entry_points:
+  build:
+    command: python model.py
+  startup:
+    command: python data.py
+  train:
+    command: python train.py
+  validate:
+    command: python validate.py
+  predict:
+    command: python predict.py
\ No newline at end of file
diff --git a/examples/server-functions/client/model.py b/examples/server-functions/client/model.py
new file mode 100644
index 000000000..6ad344770
--- /dev/null
+++ b/examples/server-functions/client/model.py
@@ -0,0 +1,76 @@
+import collections
+
+import torch
+
+from fedn.utils.helpers.helpers import get_helper
+
+HELPER_MODULE = "numpyhelper"
+helper = get_helper(HELPER_MODULE)
+
+
+def compile_model():
+    """Compile the PyTorch model.
+
+    :return: The compiled model.
+    :rtype: torch.nn.Module
+    """
+
+    class Net(torch.nn.Module):
+        def __init__(self):
+            super(Net, self).__init__()
+            self.fc1 = torch.nn.Linear(784, 64)
+            self.fc2 = torch.nn.Linear(64, 32)
+            self.fc3 = torch.nn.Linear(32, 10)
+
+        def forward(self, x):
+            x = torch.nn.functional.relu(self.fc1(x.reshape(x.size(0), 784)))
+            x = torch.nn.functional.dropout(x, p=0.5, training=self.training)
+            x = torch.nn.functional.relu(self.fc2(x))
+            x = torch.nn.functional.log_softmax(self.fc3(x), dim=1)
+            return x
+
+    return Net()
+
+
+def save_parameters(model, out_path):
+    """Save model parameters to file.
+
+    :param model: The model to serialize.
+    :type model: torch.nn.Module
+    :param out_path: The path to save to.
+    :type out_path: str
+    """
+    parameters_np = [val.cpu().numpy() for _, val in model.state_dict().items()]
+    helper.save(parameters_np, out_path)
+
+
+def load_parameters(model_path):
+    """Load model parameters from file and populate model.
+
+    :param model_path: The path to load from.
+    :type model_path: str
+    :return: The loaded model.
+    :rtype: torch.nn.Module
+    """
+    model = compile_model()
+    parameters_np = helper.load(model_path)
+
+    params_dict = zip(model.state_dict().keys(), parameters_np)
+    state_dict = collections.OrderedDict({key: torch.tensor(x) for key, x in params_dict})
+    model.load_state_dict(state_dict, strict=True)
+    return model
+
+
+def init_seed(out_path="seed.npz"):
+    """Initialize seed model and save it to file.
+
+    :param out_path: The path to save the seed model to.
+    :type out_path: str
+    """
+    # Init and save
+    model = compile_model()
+    save_parameters(model, out_path)
+
+
+if __name__ == "__main__":
+    init_seed("../seed.npz")
diff --git a/examples/server-functions/client/predict.py b/examples/server-functions/client/predict.py
new file mode 100644
index 000000000..aaf9f0f50
--- /dev/null
+++ b/examples/server-functions/client/predict.py
@@ -0,0 +1,37 @@
+import os
+import sys
+
+import torch
+from data import load_data
+from model import load_parameters
+
+dir_path = os.path.dirname(os.path.realpath(__file__))
+sys.path.append(os.path.abspath(dir_path))
+
+
+def predict(in_model_path, out_artifact_path, data_path=None):
+    """Run a model prediction.
+
+    :param in_model_path: The path to the input model.
+    :type in_model_path: str
+    :param out_artifact_path: The path to save the prediction output to.
+    :type out_artifact_path: str
+    :param data_path: The path to the data file.
+    :type data_path: str
+    """
+    # Load data
+    x_test, y_test = load_data(data_path, is_train=False)
+
+    # Load model
+    model = load_parameters(in_model_path)
+    model.eval()
+
+    # Predict
+    with torch.no_grad():
+        y_pred = model(x_test)
+    # Save the prediction to file/artifact; the artifact will be uploaded to the object store by the client
+    torch.save(y_pred, out_artifact_path)
+
+
+if __name__ == "__main__":
+    predict(sys.argv[1], sys.argv[2])
diff --git a/examples/server-functions/client/python_env.yaml b/examples/server-functions/client/python_env.yaml
new file mode 100644
index 000000000..afdea926f
--- /dev/null
+++ b/examples/server-functions/client/python_env.yaml
@@ -0,0 +1,9 @@
+name: mnist-pytorch
+build_dependencies:
+  - pip
+  - setuptools
+  - wheel
+dependencies:
+  - torch==2.3.1
+  - torchvision==0.18.1
+  - fedn
diff --git a/examples/server-functions/client/train.py b/examples/server-functions/client/train.py
new file mode 100644
index 000000000..c67d3ec69
--- /dev/null
+++ b/examples/server-functions/client/train.py
@@ -0,0 +1,110 @@
+import json
+import math
+import os
+import sys
+
+import torch
+from model import load_parameters, save_parameters
+
+from data import load_data
+from fedn.utils.helpers.helpers import save_metadata
+
+
+# Swap this to load_metadata from helpers.helpers on release.
+def load_client_settings(filename):
+    """Load client settings from file.
+
+    :param filename: The name of the file to load from.
+    :type filename: str
+    :return: The loaded client settings.
+    :rtype: dict
+    """
+    with open(filename + "-metadata", "r") as infile:
+        metadata = json.load(infile)
+    return metadata
+
+
+dir_path = os.path.dirname(os.path.realpath(__file__))
+sys.path.append(os.path.abspath(dir_path))
+
+
+def validate(model, data_path):
+    """Validate a model."""
+    x_test, y_test = load_data(data_path, is_train=False)
+    model.eval()
+
+    criterion = torch.nn.NLLLoss()
+    with torch.no_grad():
+        test_out = model(x_test)
+        test_loss = criterion(test_out, y_test)
+        test_accuracy = torch.sum(torch.argmax(test_out, dim=1) == y_test) / len(test_out)
+    return test_loss, test_accuracy
+
+
+def train(in_model_path, out_model_path, data_path=None, batch_size=32, epochs=1, lr=0.01):
+    """Complete a model update.
+
+    Load model parameters from in_model_path (managed by the FEDn client),
+    perform a model update, and write updated parameters
+    to out_model_path (picked up by the FEDn client).
+
+    :param in_model_path: The path to the input model.
+    :type in_model_path: str
+    :param out_model_path: The path to save the output model to.
+    :type out_model_path: str
+    :param data_path: The path to the data file.
+    :type data_path: str
+    :param batch_size: The batch size to use.
+    :type batch_size: int
+    :param epochs: The number of epochs to train.
+    :type epochs: int
+    :param lr: The learning rate to use.
+    :type lr: float
+    """
+    # Load data
+    x_train, y_train = load_data(data_path)
+
+    # Load parameters and initialize model
+    model = load_parameters(in_model_path)
+
+    # Client settings distributed via ServerFunctions override the default learning rate
+    client_settings = load_client_settings(in_model_path)
+    lr = client_settings["learning_rate"]
+
+    # Train
+    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
+    n_batches = int(math.ceil(len(x_train) / batch_size))
+    criterion = torch.nn.NLLLoss()
+    for e in range(epochs):  # epoch loop
+        for b in range(n_batches):  # batch loop
+            # Retrieve current batch
+            batch_x = x_train[b * batch_size : (b + 1) * batch_size]
+            batch_y = y_train[b * batch_size : (b + 1) * batch_size]
+            # Train on batch
+            optimizer.zero_grad()
+            outputs = model(batch_x)
+            loss = criterion(outputs, batch_y)
+            loss.backward()
+            optimizer.step()
+            # Log
+            if b % 100 == 0:
+                print(f"Epoch {e}/{epochs-1} | Batch: {b}/{n_batches-1} | Loss: {loss.item()}")
+
+    # Metadata needed for aggregation server side
+    metadata = {
+        # num_examples are mandatory
+        "num_examples": len(x_train),
+        "batch_size": batch_size,
+        "epochs": epochs,
+        "lr": lr,
+    }
+
+    # Save JSON metadata file (mandatory)
+    save_metadata(metadata, out_model_path)
+
+    # Save model update (mandatory)
+    save_parameters(model, out_model_path)
+
+
+if __name__ == "__main__":
+    train(sys.argv[1], sys.argv[2])
diff --git a/examples/server-functions/client/validate.py b/examples/server-functions/client/validate.py
new file mode 100644
index 000000000..09328181f
--- /dev/null
+++ b/examples/server-functions/client/validate.py
@@ -0,0 +1,55 @@
+import os
+import sys
+
+import torch
+from model import load_parameters
+
+from data import load_data
+from fedn.utils.helpers.helpers import save_metrics
+
+dir_path = os.path.dirname(os.path.realpath(__file__))
+sys.path.append(os.path.abspath(dir_path))
+
+
+def validate(in_model_path, out_json_path, data_path=None):
+    """Validate model.
+
+    :param in_model_path: The path to the input model.
+    :type in_model_path: str
+    :param out_json_path: The path to save the output JSON to.
+    :type out_json_path: str
+    :param data_path: The path to the data file.
+    :type data_path: str
+    """
+    # Load data
+    x_train, y_train = load_data(data_path)
+    x_test, y_test = load_data(data_path, is_train=False)
+
+    # Load model
+    model = load_parameters(in_model_path)
+    model.eval()
+
+    # Evaluate
+    criterion = torch.nn.NLLLoss()
+    with torch.no_grad():
+        train_out = model(x_train)
+        training_loss = criterion(train_out, y_train)
+        training_accuracy = torch.sum(torch.argmax(train_out, dim=1) == y_train) / len(train_out)
+        test_out = model(x_test)
+        test_loss = criterion(test_out, y_test)
+        test_accuracy = torch.sum(torch.argmax(test_out, dim=1) == y_test) / len(test_out)
+
+    # JSON schema
+    report = {
+        "training_loss": training_loss.item(),
+        "training_accuracy": training_accuracy.item(),
+        "test_loss": test_loss.item(),
+        "test_accuracy": test_accuracy.item(),
+    }
+
+    # Save JSON
+    save_metrics(report, out_json_path)
+
+
+if __name__ == "__main__":
+    validate(sys.argv[1], sys.argv[2])
diff --git a/examples/server-functions/docker-compose.override.yaml b/examples/server-functions/docker-compose.override.yaml
new file mode 100644
index 000000000..822a696dc
--- /dev/null
+++ b/examples/server-functions/docker-compose.override.yaml
@@ -0,0 +1,35 @@
+# Compose schema version
+version: '3.4'
+
+# Overriding requirements
+
+x-env: &defaults
+  GET_HOSTS_FROM: dns
+  FEDN_PACKAGE_EXTRACT_DIR: package
+  FEDN_NUM_DATA_SPLITS: 2
+
+services:
+
+  client1:
+    extends:
+      file: ${HOST_REPO_DIR:-.}/docker-compose.yaml
+      service: client
+    environment:
+      <<: *defaults
+      FEDN_DATA_PATH: /app/package/client/data/clients/1/mnist.pt
+    deploy:
+      replicas: 1
+    volumes:
+      - ${HOST_REPO_DIR:-.}/fedn:/app/fedn
+
+  client2:
+    extends:
+      file: ${HOST_REPO_DIR:-.}/docker-compose.yaml
+      service: client
+    environment:
+      <<: *defaults
+      FEDN_DATA_PATH: /app/package/client/data/clients/2/mnist.pt
+    deploy:
+      replicas: 1
+    volumes:
+      - ${HOST_REPO_DIR:-.}/fedn:/app/fedn
diff --git a/examples/server-functions/server_functions.py b/examples/server-functions/server_functions.py
new file mode 100644
index 000000000..e07ad83ad
--- /dev/null
+++ b/examples/server-functions/server_functions.py
@@ -0,0 +1,44 @@
+from fedn.common.log_config import logger
+from fedn.network.combiner.hooks.allowed_import import Dict, List, ServerFunctionsBase, Tuple, np, random
+
+# See allowed_import.py for the packages available in this class.
+
+
+class ServerFunctions(ServerFunctionsBase):
+    # A toy example highlighting the functionality of ServerFunctions.
+    def __init__(self) -> None:
+        # You can keep state between the different functions to have them work together.
+        self.round = 0
+        self.lr = 0.1
+
+    # Omit a function to use the default FEDn implementation for it.
+
+    # Called first, at the beginning of a round, to select clients.
+    def client_selection(self, client_ids: List[str]) -> List:
+        # Pick 10 random clients
+        client_ids = random.sample(client_ids, min(len(client_ids), 10))  # noqa: F405
+        return client_ids
+
+    # Called second, before sending the global model.
+    def client_settings(self, global_model: List[np.ndarray]) -> dict:
+        # Decrease the learning rate every 10 rounds
+        if self.round % 10 == 0:
+            self.lr = self.lr * 0.1
+        # See client/train.py for how to load the client settings.
+        self.round += 1
+        return {"learning_rate": self.lr}
+
+    # Called third, to aggregate the client updates.
+    def aggregate(self, previous_global: List[np.ndarray], client_updates: Dict[str, Tuple[List[np.ndarray], dict]]) -> List[np.ndarray]:
+        # Weighted FedAvg implementation.
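+        # Each client update is weighted by its reported number of training
+        # examples, so clients with more data contribute proportionally more.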
+        weighted_sum = [np.zeros_like(param) for param in previous_global]
+        total_weight = 0
+        for client_id, (client_parameters, metadata) in client_updates.items():
+            num_examples = metadata.get("num_examples", 1)
+            total_weight += num_examples
+            for i in range(len(weighted_sum)):
+                weighted_sum[i] += client_parameters[i] * num_examples
+
+        logger.info("Models aggregated")
+        averaged_updates = [weighted / total_weight for weighted in weighted_sum]
+        return averaged_updates
diff --git a/fedn/cli/__init__.py b/fedn/cli/__init__.py
index bcd27dc53..7028dbfa6 100644
--- a/fedn/cli/__init__.py
+++ b/fedn/cli/__init__.py
@@ -1,6 +1,7 @@
 from .client_cmd import client_cmd  # noqa: F401
 from .combiner_cmd import combiner_cmd  # noqa: F401
 from .config_cmd import config_cmd  # noqa: F401
+from .hooks_cmd import hooks_cmd  # noqa: F401
 from .main import main  # noqa: F401
 from .model_cmd import model_cmd  # noqa: F401
 from .package_cmd import package_cmd  # noqa: F401
diff --git a/fedn/cli/hooks_cmd.py b/fedn/cli/hooks_cmd.py
new file mode 100644
index 000000000..1b263fee3
--- /dev/null
+++ b/fedn/cli/hooks_cmd.py
@@ -0,0 +1,20 @@
+import click
+
+from fedn.network.combiner.hooks.hooks import serve
+
+from .main import main
+
+
+@main.group("hooks")
+@click.pass_context
+def hooks_cmd(ctx):
+    """Hooks commands.
+
+    :param ctx: The click context.
+    """
+    pass
+
+
+@hooks_cmd.command("start")
+@click.pass_context
+def start_cmd(ctx):
+    """Start the hooks service.
+
+    :param ctx: The click context.
+    """
+    click.echo("Starting hooks service")
+    serve()
diff --git a/fedn/network/api/client.py b/fedn/network/api/client.py
index e7408ba3a..ab3e2e07d 100644
--- a/fedn/network/api/client.py
+++ b/fedn/network/api/client.py
@@ -1,7 +1,10 @@
+import inspect
 import os
 
 import requests
 
+from fedn.network.combiner.hooks.serverfunctionsbase import ServerFunctionsBase
+
 __all__ = ["APIClient"]
 
@@ -574,6 +577,7 @@ def start_session(
         helper: str = "",
         min_clients: int = 1,
         requested_clients: int = 8,
+        server_functions: ServerFunctionsBase = None,
     ):
         """Start a new session.
 
@@ -617,6 +621,7 @@
                 "helper": helper,
                 "min_clients": min_clients,
                 "requested_clients": requested_clients,
+                "server_functions": None if server_functions is None else inspect.getsource(server_functions),
             },
             verify=self.verify,
             headers=self.headers,
diff --git a/fedn/network/api/interface.py b/fedn/network/api/interface.py
index 493fae8d6..eeaff3e05 100644
--- a/fedn/network/api/interface.py
+++ b/fedn/network/api/interface.py
@@ -906,6 +906,7 @@ def start_session(
         helper="",
         min_clients=1,
         requested_clients=8,
+        server_functions=None,
     ):
         """Start a session.
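+
+        :param server_functions: Stringified ServerFunctions source code (optional), forwarded to the combiners.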
@@ -1008,6 +1009,7 @@
             "task": (""),
             "validate": validate,
             "helper_type": helper,
+            "server_functions": server_functions,
         }
 
         # Start session
diff --git a/fedn/network/clients/client.py b/fedn/network/clients/client.py
index 12f76e2a4..92731ccf3 100644
--- a/fedn/network/clients/client.py
+++ b/fedn/network/clients/client.py
@@ -23,7 +23,7 @@
 from fedn.network.clients.package import PackageRuntime
 from fedn.network.clients.state import ClientState, ClientStateToString
 from fedn.network.combiner.modelservice import get_tmp_path, upload_request_generator
-from fedn.utils.helpers.helpers import get_helper
+from fedn.utils.helpers.helpers import get_helper, load_metadata, save_metadata
 
 CHUNK_SIZE = 1024 * 1024
 VALID_NAME_REGEX = "^[a-zA-Z0-9_-]*$"
@@ -456,7 +456,7 @@ def _listen_to_task_stream(self):
             if not self._connected:
                 return
 
-    def _process_training_request(self, model_id: str, session_id: str = None):
+    def _process_training_request(self, model_id: str, session_id: str = None, client_settings: dict = None):
         """Process a training (model update) request.
 
         :param model_id: The model id of the model to be updated.
@@ -482,6 +482,8 @@ def _process_training_request(self, model_id: str, session_id: str = None):
             with open(inpath, "wb") as fh:
                 fh.write(mdl.getbuffer())
 
+            save_metadata(metadata=client_settings, filename=inpath)
+
             outpath = self.helper.get_tmp_path()
             tic = time.time()
             # TODO: Check return status, fail gracefully
@@ -502,8 +504,7 @@
             meta["upload_model"] = time.time() - tic
 
             # Read the metadata file
-            with open(outpath + "-metadata", "r") as fh:
-                training_metadata = json.loads(fh.read())
+            training_metadata = load_metadata(outpath)
             meta["training_metadata"] = training_metadata
 
             os.unlink(inpath)
@@ -614,7 +615,8 @@ def process_request(self):
                     if task_type == "train":
                         tic = time.time()
                         self.state = ClientState.training
-                        model_id, meta = self._process_training_request(request.model_id, session_id=request.session_id)
+                        client_settings = json.loads(request.data).get("client_settings", {})
+                        model_id, meta = self._process_training_request(request.model_id, session_id=request.session_id, client_settings=client_settings)
 
                         if meta is not None:
                             processing_time = time.time() - tic
@@ -625,6 +627,7 @@
                         # Send model update to combiner
                         update = fedn.ModelUpdate()
                         update.sender.name = self.name
+                        update.sender.client_id = self.id
                         update.sender.role = fedn.WORKER
                         update.receiver.name = request.sender.name
                         update.receiver.role = request.sender.role
diff --git a/fedn/network/combiner/aggregators/aggregatorbase.py b/fedn/network/combiner/aggregators/aggregatorbase.py
index 44d10fca2..6866c6260 100644
--- a/fedn/network/combiner/aggregators/aggregatorbase.py
+++ b/fedn/network/combiner/aggregators/aggregatorbase.py
@@ -1,10 +1,7 @@
 import importlib
-import json
-import queue
-import traceback
 from abc import ABC, abstractmethod
 
-from fedn.common.log_config import logger
+from fedn.network.combiner.updatehandler import UpdateHandler
 
 AGGREGATOR_PLUGIN_PATH = "fedn.network.combiner.aggregators.{}"
 
@@ -12,27 +9,15 @@
 class AggregatorBase(ABC):
     """Abstract class defining an aggregator.
-    :param id: A reference to id of :class: `fedn.network.combiner.Combiner`
-    :type id: str
-    :param storage: Model repository for :class: `fedn.network.combiner.Combiner`
-    :type storage: class: `fedn.common.storage.s3.s3repo.S3ModelRepository`
-    :param server: A handle to the Combiner class :class: `fedn.network.combiner.Combiner`
-    :type server: class: `fedn.network.combiner.Combiner`
-    :param modelservice: A handle to the model service :class: `fedn.network.combiner.modelservice.ModelService`
-    :type modelservice: class: `fedn.network.combiner.modelservice.ModelService`
-    :param control: A handle to the :class: `fedn.network.combiner.roundhandler.RoundHandler`
-    :type control: class: `fedn.network.combiner.roundhandler.RoundHandler`
+    :param update_handler: A handle to the :class: `fedn.network.combiner.updatehandler.UpdateHandler`
+    :type update_handler: class: `fedn.network.combiner.updatehandler.UpdateHandler`
     """
 
     @abstractmethod
-    def __init__(self, storage, server, modelservice, round_handler):
+    def __init__(self, update_handler: UpdateHandler):
        """Initialize the aggregator."""
         self.name = self.__class__.__name__
-        self.storage = storage
-        self.server = server
-        self.modelservice = modelservice
-        self.round_handler = round_handler
-        self.model_updates = queue.Queue()
+        self.update_handler = update_handler
 
     @abstractmethod
     def combine_models(self, nr_expected_models=None, nr_required_models=1, helper=None, timeout=180, delete_models=True, parameters=None):
@@ -55,96 +40,8 @@ def combine_models(self, nr_expected_models=None, nr_required_models=1, helper=N
         """
         pass
 
-    def on_model_update(self, model_update):
-        """Callback when a new client model update is recieved.
-
-        Performs (optional) validation and pre-processing,
-        and then puts the update id on the aggregation queue.
-        Override in subclass as needed.
-
-        :param model_update: fedn.network.grpc.fedn.proto.ModelUpdate
-        :type model_id: str
-        """
-        try:
-            logger.info("AGGREGATOR({}): callback received model update {}".format(self.name, model_update.model_update_id))
-
-            # Validate the update and metadata
-            valid_update = self._validate_model_update(model_update)
-            if valid_update:
-                # Push the model update to the processing queue
-                self.model_updates.put(model_update)
-            else:
-                logger.warning("AGGREGATOR({}): Invalid model update, skipping.".format(self.name))
-        except Exception as e:
-            tb = traceback.format_exc()
-            logger.error("AGGREGATOR({}): failed to receive model update: {}".format(self.name, e))
-            logger.error(tb)
-            pass
-
-    def _validate_model_update(self, model_update):
-        """Validate the model update.
-
-        :param model_update: A ModelUpdate message.
-        :type model_update: object
-        :return: True if the model update is valid, False otherwise.
-        :rtype: bool
-        """
-        try:
-            data = json.loads(model_update.meta)["training_metadata"]
-            _ = data["num_examples"]
-        except KeyError:
-            tb = traceback.format_exc()
-            logger.error("AGGREGATOR({}): Invalid model update, missing metadata.".format(self.name))
-            logger.error(tb)
-            return False
-        return True
-
-    def next_model_update(self):
-        """Get the next model update from the queue.
-
-        :param helper: A helper object.
-        :type helper: object
-        :return: The model update.
-        :rtype: fedn.network.grpc.fedn.proto.ModelUpdate
-        """
-        model_update = self.model_updates.get(block=False)
-        return model_update
-
-    def load_model_update(self, model_update, helper):
-        """Load the memory representation of the model update.
-
-        Load the model update paramters and the
-        associate metadata into memory.
-
-        :param model_update: The model update.
-        :type model_update: fedn.network.grpc.fedn.proto.ModelUpdate
-        :param helper: A helper object.
-        :type helper: fedn.utils.helpers.helperbase.Helper
-        :return: A tuple of (parameters, metadata)
-        :rtype: tuple
-        """
-        model_id = model_update.model_update_id
-        model = self.round_handler.load_model_update(helper, model_id)
-        # Get relevant metadata
-        metadata = json.loads(model_update.meta)
-        if "config" in metadata.keys():
-            # Used in Python client
-            config = json.loads(metadata["config"])
-        else:
-            # Used in C++ client
-            config = json.loads(model_update.config)
-        training_metadata = metadata["training_metadata"]
-        training_metadata["round_id"] = config["round_id"]
-
-        return model, training_metadata
-
-    def get_state(self):
-        """Get the state of the aggregator's queue, including the number of model updates."""
-        state = {"queue_len": self.model_updates.qsize()}
-        return state
-
 
-def get_aggregator(aggregator_module_name, storage, server, modelservice, control):
+def get_aggregator(aggregator_module_name, update_handler):
     """Return an instance of the helper class.
 
     :param helper_module_name: The name of the helper plugin module.
@@ -162,4 +59,4 @@
     """
     aggregator_plugin = AGGREGATOR_PLUGIN_PATH.format(aggregator_module_name)
     aggregator = importlib.import_module(aggregator_plugin)
-    return aggregator.Aggregator(storage, server, modelservice, control)
+    return aggregator.Aggregator(update_handler)
diff --git a/fedn/network/combiner/aggregators/fedavg.py b/fedn/network/combiner/aggregators/fedavg.py
index 0d965cfa0..71dab273a 100644
--- a/fedn/network/combiner/aggregators/fedavg.py
+++ b/fedn/network/combiner/aggregators/fedavg.py
@@ -8,22 +8,13 @@
 class Aggregator(AggregatorBase):
     """Local SGD / Federated Averaging (FedAvg) aggregator. Computes a weighted mean of parameter updates.
 
-    :param id: A reference to id of :class: `fedn.network.combiner.Combiner`
-    :type id: str
-    :param storage: Model repository for :class: `fedn.network.combiner.Combiner`
-    :type storage: class: `fedn.common.storage.s3.s3repo.S3ModelRepository`
-    :param server: A handle to the Combiner class :class: `fedn.network.combiner.Combiner`
-    :type server: class: `fedn.network.combiner.Combiner`
-    :param modelservice: A handle to the model service :class: `fedn.network.combiner.modelservice.ModelService`
-    :type modelservice: class: `fedn.network.combiner.modelservice.ModelService`
-    :param control: A handle to the :class: `fedn.network.combiner.roundhandler.RoundHandler`
-    :type control: class: `fedn.network.combiner.roundhandler.RoundHandler`
-
+    :param update_handler: A handle to the :class: `fedn.network.combiner.updatehandler.UpdateHandler`
+    :type update_handler: class: `fedn.network.combiner.updatehandler.UpdateHandler`
     """
 
-    def __init__(self, storage, server, modelservice, round_handler):
+    def __init__(self, update_handler):
         """Constructor method"""
-        super().__init__(storage, server, modelservice, round_handler)
+        super().__init__(update_handler)
 
         self.name = "fedavg"
 
@@ -52,15 +43,14 @@
 
         logger.info("AGGREGATOR({}): Aggregating model updates...".format(self.name))
 
-        while not self.model_updates.empty():
+        while not self.update_handler.model_updates.empty():
             try:
-                # Get next model from queue
                 logger.info("AGGREGATOR({}): Getting next model update from queue.".format(self.name))
-                model_update = self.next_model_update()
+                model_update = self.update_handler.next_model_update()
 
                 # Load model parameters and metadata
                 logger.info("AGGREGATOR({}): Loading model metadata {}.".format(self.name, model_update.model_update_id))
-                model_next, metadata = self.load_model_update(model_update, helper)
+                model_next, metadata = self.update_handler.load_model_update(model_update, helper)
 
                 logger.info("AGGREGATOR({}): Processing model update {}, metadata: {}  ".format(self.name, model_update.model_update_id, metadata))
@@ -75,14 +65,11 @@
                 nr_aggregated_models += 1
                 # Delete model from storage
                 if delete_models:
-                    self.modelservice.temp_model_storage.delete(model_update.model_update_id)
-                    logger.info("AGGREGATOR({}): Deleted model update {} from storage.".format(self.name, model_update.model_update_id))
-                self.model_updates.task_done()
+                    self.update_handler.delete_model(model_update)
             except Exception as e:
                 tb = traceback.format_exc()
                 logger.error(f"AGGREGATOR({self.name}): Error encoutered while processing model update: {e}")
                 logger.error(tb)
-                self.model_updates.task_done()
 
         data["nr_aggregated_models"] = nr_aggregated_models
 
diff --git a/fedn/network/combiner/aggregators/fedopt.py b/fedn/network/combiner/aggregators/fedopt.py
index 5041e097f..d91fe6d22 100644
--- a/fedn/network/combiner/aggregators/fedopt.py
+++ b/fedn/network/combiner/aggregators/fedopt.py
@@ -6,7 +6,7 @@
 
 class Aggregator(AggregatorBase):
-    """ Federated Optimization (FedOpt) aggregator.
+    """Federated Optimization (FedOpt) aggregator.
 
     Implmentation following: https://arxiv.org/pdf/2003.00295.pdf
 
@@ -16,23 +16,13 @@
     are "adam", "yogi", "adagrad".
 
-
-    :param id: A reference to id of :class: `fedn.network.combiner.Combiner`
-    :type id: str
-    :param storage: Model repository for :class: `fedn.network.combiner.Combiner`
-    :type storage: class: `fedn.common.storage.s3.s3repo.S3ModelRepository`
-    :param server: A handle to the Combiner class :class: `fedn.network.combiner.Combiner`
-    :type server: class: `fedn.network.combiner.Combiner`
-    :param modelservice: A handle to the model service :class: `fedn.network.combiner.modelservice.ModelService`
-    :type modelservice: class: `fedn.network.combiner.modelservice.ModelService`
-    :param control: A handle to the :class: `fedn.network.combiner.roundhandler.RoundHandler`
-    :type control: class: `fedn.network.combiner.roundhandler.RoundHandler`
+    :param update_handler: A handle to the :class: `fedn.network.combiner.updatehandler.UpdateHandler`
+    :type update_handler: class: `fedn.network.combiner.updatehandler.UpdateHandler`
     """
 
-    def __init__(self, storage, server, modelservice, round_handler):
-
-        super().__init__(storage, server, modelservice, round_handler)
+    def __init__(self, update_handler):
+        super().__init__(update_handler)
 
         self.name = "fedopt"
 
         # To store momentum
@@ -103,42 +93,34 @@
         nr_aggregated_models = 0
         total_examples = 0
 
-        logger.info(
-            "AGGREGATOR({}): Aggregating model updates... ".format(self.name))
+        logger.info("AGGREGATOR({}): Aggregating model updates...".format(self.name))
 
-        while not self.model_updates.empty():
+        while not self.update_handler.model_updates.empty():
             try:
-                # Get next model from queue
-                model_update = self.next_model_update()
-
+                logger.info("AGGREGATOR({}): Getting next model update from queue.".format(self.name))
+                model_update = self.update_handler.next_model_update()
                 # Load model paratmeters and metadata
-                model_next, metadata = self.load_model_update(model_update, helper)
+                model_next, metadata = self.update_handler.load_model_update(model_update, helper)
 
-                logger.info(
-                    "AGGREGATOR({}): Processing model update {}".format(self.name, model_update.model_update_id))
+                logger.info("AGGREGATOR({}): Processing model update {}".format(self.name, model_update.model_update_id))
 
                 # Increment total number of examples
                 total_examples += metadata["num_examples"]
 
                 if nr_aggregated_models == 0:
-                    model_old = self.round_handler.load_model_update(helper, model_update.model_id)
+                    model_old = self.update_handler.load_model(helper, model_update.model_id)
                     pseudo_gradient = helper.subtract(model_next, model_old)
                 else:
                     pseudo_gradient_next = helper.subtract(model_next, model_old)
-                    pseudo_gradient = helper.increment_average(
-                        pseudo_gradient, pseudo_gradient_next, metadata["num_examples"], total_examples)
+                    pseudo_gradient = helper.increment_average(pseudo_gradient, pseudo_gradient_next, metadata["num_examples"], total_examples)
 
                 nr_aggregated_models += 1
                 # Delete model from storage
                 if delete_models:
-                    self.modelservice.temp_model_storage.delete(model_update.model_update_id)
-                    logger.info(
-                        "AGGREGATOR({}): Deleted model update {} from storage.".format(self.name, model_update.model_update_id))
-                    self.model_updates.task_done()
+                    self.update_handler.delete_model(model_update.model_update_id)
+                    logger.info("AGGREGATOR({}): Deleted model update {} from storage.".format(self.name, model_update.model_update_id))
             except Exception as e:
-                logger.error(
-                    "AGGREGATOR({}): Error encoutered while processing model update {}, skipping this update.".format(self.name, e))
-                self.model_updates.task_done()
+                logger.error("AGGREGATOR({}): Error encountered while processing model update {}, skipping this update.".format(self.name, e))
 
         if parameters["serveropt"] == "adam":
             model = self.serveropt_adam(helper, pseudo_gradient, model_old, parameters)
@@ -156,7 +138,7 @@
         return model, data
 
     def serveropt_adam(self, helper, pseudo_gradient, model_old, parameters):
-        """ Server side optimization, FedAdam.
+        """Server side optimization, FedAdam.
 
         :param helper: instance of helper class.
         :type helper: Helper
@@ -178,12 +160,12 @@
             self.v = helper.ones(pseudo_gradient, math.pow(tau, 2))
 
         if not self.m:
-            self.m = helper.multiply(pseudo_gradient, [(1.0-beta1)]*len(pseudo_gradient))
+            self.m = helper.multiply(pseudo_gradient, [(1.0 - beta1)] * len(pseudo_gradient))
         else:
-            self.m = helper.add(self.m, pseudo_gradient, beta1, (1.0-beta1))
+            self.m = helper.add(self.m, pseudo_gradient, beta1, (1.0 - beta1))
 
         p = helper.power(pseudo_gradient, 2)
-        self.v = helper.add(self.v, p, beta2, (1.0-beta2))
+        self.v = helper.add(self.v, p, beta2, (1.0 - beta2))
 
         sv = helper.add(helper.sqrt(self.v), helper.ones(self.v, tau))
         t = helper.divide(self.m, sv)
@@ -192,7 +174,7 @@
         return model
 
     def serveropt_yogi(self, helper, pseudo_gradient, model_old, parameters):
-        """ Server side optimization, FedYogi.
+        """Server side optimization, FedYogi.
 
         :param helper: instance of helper class.
         :type helper: Helper
@@ -214,14 +196,14 @@
             self.v = helper.ones(pseudo_gradient, math.pow(tau, 2))
 
         if not self.m:
-            self.m = helper.multiply(pseudo_gradient, [(1.0-beta1)]*len(pseudo_gradient))
+            self.m = helper.multiply(pseudo_gradient, [(1.0 - beta1)] * len(pseudo_gradient))
         else:
-            self.m = helper.add(self.m, pseudo_gradient, beta1, (1.0-beta1))
+            self.m = helper.add(self.m, pseudo_gradient, beta1, (1.0 - beta1))
 
         p = helper.power(pseudo_gradient, 2)
         s = helper.sign(helper.add(self.v, p, 1.0, -1.0))
         s = helper.multiply(s, p)
-        self.v = helper.add(self.v, s, 1.0, -(1.0-beta2))
+        self.v = helper.add(self.v, s, 1.0, -(1.0 - beta2))
 
         sv = helper.add(helper.sqrt(self.v), helper.ones(self.v, tau))
         t = helper.divide(self.m, sv)
@@ -230,7 +212,7 @@
         return model
 
     def serveropt_adagrad(self, helper, pseudo_gradient, model_old, parameters):
-        """ Server side optimization, FedAdam.
+        """Server side optimization, FedAdagrad.
 
         :param helper: instance of helper class.
@@ -251,9 +233,9 @@
             self.v = helper.ones(pseudo_gradient, math.pow(tau, 2))
 
         if not self.m:
-            self.m = helper.multiply(pseudo_gradient, [(1.0-beta1)]*len(pseudo_gradient))
+            self.m = helper.multiply(pseudo_gradient, [(1.0 - beta1)] * len(pseudo_gradient))
         else:
-            self.m = helper.add(self.m, pseudo_gradient, beta1, (1.0-beta1))
+            self.m = helper.add(self.m, pseudo_gradient, beta1, (1.0 - beta1))
 
         p = helper.power(pseudo_gradient, 2)
         self.v = helper.add(self.v, p, 1.0, 1.0)
diff --git a/fedn/network/combiner/aggregators/tests/test_fedavg.py b/fedn/network/combiner/aggregators/tests/test_fedavg.py
index 55e5052b8..583a7f51a 100644
--- a/fedn/network/combiner/aggregators/tests/test_fedavg.py
+++ b/fedn/network/combiner/aggregators/tests/test_fedavg.py
@@ -18,7 +18,7 @@ def test_fedavg_init(self, *args, **kwargs):
     def test_fedavg_combine_models(self, *args, **kwargs):
         """Test the FedAvg aggregator combine_models method with mock classes and methods"""
         aggregator = FedAvg("id", None, None, None, None)
-        aggregator.next_model_update = MagicMock(return_value=(None, None, None))
+        aggregator.update_handler.next_model_update = MagicMock(return_value=[(None, None, None)])
         aggregator.server = MagicMock()
 
         data = {}
diff --git a/fedn/network/combiner/combiner.py b/fedn/network/combiner/combiner.py
index 3f732ecd4..df82b0c54 100644
--- a/fedn/network/combiner/combiner.py
+++ b/fedn/network/combiner/combiner.py
@@ -489,7 +489,6 @@ def SetAggregator(self, control: fedn.ControlRequest, context):
         logger.debug("grpc.Combiner.SetAggregator: Called")
         for parameter in control.parameter:
             aggregator = parameter.value
-
         status = self.round_handler.set_aggregator(aggregator)
 
         response = fedn.ControlResponse()
@@ -497,7 +496,27 @@
             response.message = "Success"
         else:
             response.message = "Failed"
+        return response
+
+    def SetServerFunctions(self, control: fedn.ControlRequest, context):
+        """Set the server function provider.
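+
+        Stores the stringified user code on the round handler, which forwards
+        it to the hooks service when rounds execute.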
+
+        :param control: the control request
+        :type control: :class:`fedn.network.grpc.fedn_pb2.ControlRequest`
+        :param context: the context (unused)
+        :type context: :class:`grpc._server._Context`
+        :return: the control response
+        :rtype: :class:`fedn.network.grpc.fedn_pb2.ControlResponse`
+        """
+        logger.debug("grpc.Combiner.SetServerFunctions: Called")
+        for parameter in control.parameter:
+            server_functions = parameter.value
+
+        self.round_handler.set_server_functions(server_functions)
+        response = fedn.ControlResponse()
+        response.message = "Success"
+        logger.info(f"Set server functions response: {response}")
         return response
 
     def FlushAggregationQueue(self, control: fedn.ControlRequest, context):
@@ -719,7 +738,7 @@ def SendModelUpdate(self, request, context):
         :return: the response
         :rtype: :class:`fedn.network.grpc.fedn_pb2.Response`
         """
-        self.round_handler.aggregator.on_model_update(request)
+        self.round_handler.update_handler.on_model_update(request)
 
         response = fedn.Response()
         response.response = "RECEIVED ModelUpdate {} from client {}".format(response, response.sender.name)
diff --git a/fedn/network/combiner/hooks/__init__.py b/fedn/network/combiner/hooks/__init__.py
new file mode 100644
index 000000000..e2f53ea23
--- /dev/null
+++ b/fedn/network/combiner/hooks/__init__.py
@@ -0,0 +1 @@
+"""The FEDn Hooks package responsible for executing user defined code on the server."""
diff --git a/fedn/network/combiner/hooks/allowed_import.py b/fedn/network/combiner/hooks/allowed_import.py
new file mode 100644
index 000000000..692790b12
--- /dev/null
+++ b/fedn/network/combiner/hooks/allowed_import.py
@@ -0,0 +1,7 @@
+import random  # noqa: F401
+from typing import Dict, List, Tuple  # noqa: F401
+
+import numpy as np  # noqa: F401
+
+from fedn.common.log_config import logger  # noqa: F401
+from fedn.network.combiner.hooks.serverfunctionsbase import ServerFunctionsBase  # noqa: F401
diff --git a/fedn/network/combiner/hooks/hook_client.py b/fedn/network/combiner/hooks/hook_client.py
new file mode 100644
index 000000000..340305881
--- /dev/null
+++ b/fedn/network/combiner/hooks/hook_client.py
@@ -0,0 +1,107 @@
+import json
+import os
+
+import grpc
+
+import fedn.network.grpc.fedn_pb2 as fedn
+import fedn.network.grpc.fedn_pb2_grpc as rpc
+from fedn.common.log_config import logger
+from fedn.network.combiner.modelservice import bytesIO_request_generator, model_as_bytesIO, unpack_model
+from fedn.network.combiner.updatehandler import UpdateHandler
+
+CHUNK_SIZE = 1024 * 1024
+
+
+class CombinerHookInterface:
+    """Combiner-to-server-functions hooks client."""
+
+    def __init__(self):
+        """Initialize CombinerHookInterface client."""
+        self.hook_service_host = os.getenv("HOOK_SERVICE_HOST", "hook:12081")
+        self.channel = grpc.insecure_channel(
+            self.hook_service_host,
+            options=[
+                ("grpc.keepalive_time_ms", 30000),  # 30 seconds ping interval
+                ("grpc.keepalive_timeout_ms", 5000),  # 5 seconds timeout for a response
+                ("grpc.keepalive_permit_without_calls", 1),  # allow keepalives even with no active calls
+                ("grpc.enable_retries", 1),  # automatic retries
+                ("grpc.initial_reconnect_backoff_ms", 1000),  # initial delay before retrying
+                ("grpc.max_reconnect_backoff_ms", 5000),  # maximum delay before retrying
+            ],
+        )
+        self.stub = rpc.FunctionServiceStub(self.channel)
+
+    def provided_functions(self, server_functions: str):
+        """Communicate with the hook container and ask which functions are available.
+
+        :param server_functions: String version of an implementation of the ServerFunctionsBase interface.
+        :type server_functions: str
+        :return: dictionary specifying which functions are implemented.
+        :rtype: dict
+        """
+        request = fedn.ProvidedFunctionsRequest(function_code=server_functions)
+
+        response = self.stub.HandleProvidedFunctions(request)
+        return response.available_functions
+
+    def client_settings(self, global_model) -> dict:
+        """Communicate with the hook container to get a client config.
+
+        :param global_model: The global model that will be distributed to clients.
+        :type global_model: bytes
+        :return: config that will be distributed to clients.
+        :rtype: dict
+        """
+        request_function = fedn.ClientConfigRequest
+        args = {}
+        model = model_as_bytesIO(global_model)
+        response = self.stub.HandleClientConfig(bytesIO_request_generator(mdl=model, request_function=request_function, args=args))
+        return json.loads(response.client_settings)
+
+    def client_selection(self, clients: list) -> list:
+        request = fedn.ClientSelectionRequest(client_ids=json.dumps(clients))
+        response = self.stub.HandleClientSelection(request)
+        return json.loads(response.client_ids)
+
+    def aggregate(self, previous_global, update_handler: UpdateHandler, helper, delete_models: bool):
+        """Aggregation call to the hook functions. Sends models in chunks, then asks for aggregation.
+
+        :param previous_global: The previous global model.
+        :param update_handler: Handler holding the queued client model updates.
+        :param helper: A helper instance used to deserialize the aggregated model.
+        :param delete_models: Whether to delete client models from disk after sending them.
+        :return: The aggregated model and aggregation metadata.
+        :rtype: tuple
+        """
+        data = {}
+        data["time_model_load"] = 0.0
+        data["time_model_aggregation"] = 0.0
+        # send previous global
+        request_function = fedn.StoreModelRequest
+        args = {"id": "global_model"}
+        response = self.stub.HandleStoreModel(bytesIO_request_generator(mdl=previous_global, request_function=request_function, args=args))
+        logger.info(f"Store model response: {response.status}")
+        # send client models and metadata
+        nr_updates = 0
+        while not update_handler.model_updates.empty():
+            logger.info("Getting next model update from queue.")
+            update = update_handler.next_model_update()
+            metadata = json.loads(update.meta)["training_metadata"]
+            model = update_handler.load_model_update_bytesIO(update.model_update_id)
+            # send metadata
+            client_id = update.sender.client_id
+            request = fedn.ClientMetaRequest(metadata=json.dumps(metadata), client_id=client_id)
+            response = self.stub.HandleMetadata(request)
+            # send client model
+            args = {"id": client_id}
+            request_function = fedn.StoreModelRequest
+            response = self.stub.HandleStoreModel(bytesIO_request_generator(mdl=model, request_function=request_function, args=args))
+            logger.info(f"Store model response: {response.status}")
+            nr_updates += 1
+            if delete_models:
+                # delete model from disk
+                update_handler.delete_model(model_update=update)
+        # ask for aggregation
+        request = fedn.AggregationRequest(aggregate="aggregate")
+        response_generator = self.stub.HandleAggregation(request)
+        data["nr_aggregated_models"] = nr_updates
+        model, _ = unpack_model(response_generator, helper)
+        return model, data
diff --git a/fedn/network/combiner/hooks/hooks.py b/fedn/network/combiner/hooks/hooks.py
new file mode 100644
index 000000000..a20ddb27f
--- /dev/null
+++ b/fedn/network/combiner/hooks/hooks.py
@@ -0,0 +1,193 @@
+import json
+from concurrent import futures
+
+import grpc
+
+import fedn.network.grpc.fedn_pb2 as fedn
+import fedn.network.grpc.fedn_pb2_grpc as rpc
+from fedn.common.log_config import logger
+
+# imports for user code
+from fedn.network.combiner.hooks.allowed_import import Dict, List, ServerFunctionsBase, Tuple, np, random  # noqa: F401
+from fedn.network.combiner.modelservice import bytesIO_request_generator, model_as_bytesIO, unpack_model
+from fedn.utils.helpers.plugins.numpyhelper import Helper
+
+CHUNK_SIZE = 1024 * 1024
+VALID_NAME_REGEX = "^[a-zA-Z0-9_-]*$"
+
+
+class FunctionServiceServicer(rpc.FunctionServiceServicer):
+    """Function service running in an environment combined with each combiner.
+
+    Receives requests from the combiner.
+    """
+
+    def __init__(self) -> None:
+        """Initialize long-running Function server."""
+        super().__init__()
+
+        self.helper = Helper()
+        self.server_functions: ServerFunctionsBase = None
+        self.server_functions_code: str = None
+        self.client_updates = {}
+        self.implemented_functions = None
+
+    def HandleClientConfig(self, request_iterator: fedn.ClientConfigRequest, context):
+        """Distribute client configs to clients from user-defined code.
+
+        :param request_iterator: the client config request
+        :type request_iterator: :class:`fedn.network.grpc.fedn_pb2.ClientConfigRequest`
+        :param context: the context (unused)
+        :type context: :class:`grpc._server._Context`
+        :return: the client config response
+        :rtype: :class:`fedn.network.grpc.fedn_pb2.ClientConfigResponse`
+        """
+        logger.info("Received client config request.")
+        model, _ = unpack_model(request_iterator, self.helper)
+        client_settings = self.server_functions.client_settings(global_model=model)
+        logger.info(f"Client config response: {client_settings}")
+        return fedn.ClientConfigResponse(client_settings=json.dumps(client_settings))
+
+    def HandleClientSelection(self, request: fedn.ClientSelectionRequest, context):
+        """Handle client selection from user-defined code.
+
+        :param request: the client selection request
+        :type request: :class:`fedn.network.grpc.fedn_pb2.fedn.ClientSelectionRequest`
+        :param context: the context (unused)
+        :type context: :class:`grpc._server._Context`
+        :return: the client selection response
+        :rtype: :class:`fedn.network.grpc.fedn_pb2.ClientSelectionResponse`
+        """
+        logger.info("Received client selection request.")
+        client_ids = json.loads(request.client_ids)
+        client_ids = self.server_functions.client_selection(client_ids)
+        logger.info(f"Clients selected: {client_ids}")
+        return fedn.ClientSelectionResponse(client_ids=json.dumps(client_ids))
+
+    def HandleMetadata(self, request: fedn.ClientMetaRequest, context):
+        """Store client metadata from a request.
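+
+        Metadata is buffered per client id and consumed by the next aggregation call.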
+
+        :param request: the client meta request
+        :type request: :class:`fedn.network.grpc.fedn_pb2.fedn.ClientMetaRequest`
+        :param context: the context (unused)
+        :type context: :class:`grpc._server._Context`
+        :return: the client meta response
+        :rtype: :class:`fedn.network.grpc.fedn_pb2.ClientMetaResponse`
+        """
+        logger.info("Received metadata")
+        client_id = request.client_id
+        metadata = json.loads(request.metadata)
+        self.client_updates[client_id] = self.client_updates.get(client_id, []) + [metadata]
+        return fedn.ClientMetaResponse(status="Metadata stored")
+
+    def HandleStoreModel(self, request_iterator, context):
+        model, final_request = unpack_model(request_iterator, self.helper)
+        client_id = final_request.id
+        if client_id == "global_model":
+            logger.info("Received previous global model")
+            self.previous_global = model
+        else:
+            logger.info("Received client model")
+            self.client_updates[client_id] = [model] + self.client_updates.get(client_id, [])
+        return fedn.StoreModelResponse(status=f"Received model originating from {client_id}")
+
+    def HandleAggregation(self, request, context):
+        """Aggregate the stored models based on user-defined code when specified in the request.
+
+        :param request: the aggregation request
+        :type request: :class:`fedn.network.grpc.fedn_pb2.fedn.AggregationRequest`
+        :param context: the context (unused)
+        :type context: :class:`grpc._server._Context`
+        :return: the aggregation response (aggregated model or None)
+        :rtype: :class:`fedn.network.grpc.fedn_pb2.AggregationResponse`
+        """
+        logger.info(f"Received aggregation request: {request.aggregate}")
+        aggregated_model = self.server_functions.aggregate(self.previous_global, self.client_updates)
+        model_bytesIO = model_as_bytesIO(aggregated_model, self.helper)
+        request_function = fedn.AggregationResponse
+        self.client_updates = {}
+        logger.info("Returning aggregate model.")
+        response_generator = bytesIO_request_generator(mdl=model_bytesIO, request_function=request_function, args={})
+        for response in response_generator:
+            yield response
+
+    def HandleProvidedFunctions(self, request: fedn.ProvidedFunctionsRequest, context):
+        """Handle the 'provided_functions' request and send back which functions are available.
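+
+        Each hook is probed once with dummy arguments and the result cached in
+        self.implemented_functions, so user code is only inspected on the first request.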
+
+        :param request: the provided function request
+        :type request: :class:`fedn.network.grpc.fedn_pb2.fedn.ProvidedFunctionsRequest`
+        :param context: the context (unused)
+        :type context: :class:`grpc._server._Context`
+        :return: dict with str -> bool for which functions are available
+        :rtype: :class:`fedn.network.grpc.fedn_pb2.ProvidedFunctionsResponse`
+        """
+        logger.info("Received provided functions request.")
+        if self.implemented_functions is not None:
+            return fedn.ProvidedFunctionsResponse(available_functions=self.implemented_functions)
+        server_functions_code = request.function_code
+        self.server_functions_code = server_functions_code
+        self.implemented_functions = {}
+        self._instantiate_server_functions_code()
+        # If a call raises or returns something other than None, we assume the function is implemented.
+        # check if aggregate is available
+        try:
+            ret = self.server_functions.aggregate(0, 0)
+            if ret is None:
+                self.implemented_functions["aggregate"] = False
+            else:
+                self.implemented_functions["aggregate"] = True
+        except Exception:
+            self.implemented_functions["aggregate"] = True
+        # check if client_settings is available
+        try:
+            ret = self.server_functions.client_settings(0)
+            if ret is None:
+                self.implemented_functions["client_settings"] = False
+            else:
+                self.implemented_functions["client_settings"] = True
+        except Exception:
+            self.implemented_functions["client_settings"] = True
+        # check if client_selection is available
+        try:
+            ret = self.server_functions.client_selection(0)
+            if ret is None:
+                self.implemented_functions["client_selection"] = False
+            else:
+                self.implemented_functions["client_selection"] = True
+        except Exception:
+            self.implemented_functions["client_selection"] = True
+        logger.info(f"Provided functions: {self.implemented_functions}")
+        return fedn.ProvidedFunctionsResponse(available_functions=self.implemented_functions)
+
+    def _instantiate_server_functions_code(self):
+        # This will create a new user-defined instance of the ServerFunctions class.
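+        # The user-supplied source is executed with exec(); it is expected to use
+        # only the names exposed through the allowed_import module.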
+        try:
+            namespace = {}
+            exec(self.server_functions_code, globals(), namespace)  # noqa: S102
+            exec("server_functions = ServerFunctions()", globals(), namespace)  # noqa: S102
+            self.server_functions = namespace.get("server_functions")
+        except Exception as e:
+            logger.error(f"Exec failed with error: {str(e)}")
+
+
+def serve():
+    """Start the hooks service."""
+    # Keepalive settings: these detect if the client is alive
+    KEEPALIVE_TIME_MS = 60 * 1000  # send keepalive ping every 60 seconds
+    KEEPALIVE_TIMEOUT_MS = 20 * 1000  # wait 20 seconds for keepalive ping ack before considering connection dead
+    MAX_CONNECTION_IDLE_MS = 5 * 60 * 1000  # max idle time before the server terminates the connection (5 minutes)
+    MAX_MESSAGE_LENGTH = 1 * 1024 * 1024 * 1024  # 1 GB in bytes
+    server = grpc.server(
+        futures.ThreadPoolExecutor(max_workers=100),  # Increase based on expected load
+        options=[
+            ("grpc.keepalive_time_ms", KEEPALIVE_TIME_MS),
+            ("grpc.keepalive_timeout_ms", KEEPALIVE_TIMEOUT_MS),
+            ("grpc.max_connection_idle_ms", MAX_CONNECTION_IDLE_MS),
+            ("grpc.max_send_message_length", MAX_MESSAGE_LENGTH),
+            ("grpc.max_receive_message_length", -1),
+        ],
+    )
+    rpc.add_FunctionServiceServicer_to_server(FunctionServiceServicer(), server)
+    server.add_insecure_port("[::]:12081")
+    server.start()
+    server.wait_for_termination()
diff --git a/fedn/network/combiner/hooks/serverfunctionsbase.py b/fedn/network/combiner/hooks/serverfunctionsbase.py
new file mode 100644
index 000000000..6b0666164
--- /dev/null
+++ b/fedn/network/combiner/hooks/serverfunctionsbase.py
@@ -0,0 +1,71 @@
+from abc import ABC
+from typing import Dict, List, Tuple
+
+import numpy as np
+
+
+class ServerFunctionsBase(ABC):
+    """Base class that defines the structure for the Server Functions. Override these functions
+    to add to the server workflow.
+    """
+
+    def __init__(self) -> None:
+        """Initialize the ServerFunctionsBase class. This method can be overridden
+        by subclasses if initialization logic is required.
+        """
+        pass
+
+    def aggregate(self, previous_global: List[np.ndarray], client_updates: Dict[str, Tuple[List[np.ndarray], Dict]]) -> List[np.ndarray]:
+        """Aggregates a list of parameters from clients.
+
+        Args:
+        ----
+            previous_global (list[np.ndarray]): A list of parameters representing the global
+            model from the previous round.
+
+            client_updates (Dict[str, Tuple[List[np.ndarray], Dict]]): A dictionary where each key is a client ID,
+            pointing to a tuple whose first element is the client parameters and whose second
+            element is the client's metadata.
+
+        Returns:
+        -------
+            list[np.ndarray]: A list of numpy arrays representing the aggregated
+            parameters across all clients.
+
+        """
+        pass
+
+    def client_settings(self, global_model: List[np.ndarray]) -> Dict:
+        """Returns settings related to the model, which get distributed to the clients.
+        The dictionary may only contain primitive types.
+
+        Args:
+        ----
+            global_model (list[np.ndarray]): A list of parameters representing the global
+            model for the upcoming round.
+
+        Returns:
+        -------
+            dict: A dictionary containing settings, supporting only primitive Python types.
+
+        """
+        pass
+
+    def client_selection(self, client_ids: List[str]) -> List:
+        """Returns a list of client ids specifying which clients to use for the next training request.
+
+        Args:
+        ----
+            client_ids (list[str]): A list of client_ids for all connected clients.
+
+        Returns:
+        -------
+            list[str]: A list of client ids for the clients that should be chosen for the next training round.
diff --git a/fedn/network/combiner/interfaces.py b/fedn/network/combiner/interfaces.py
index 20da29d23..426e7d9f6 100644
--- a/fedn/network/combiner/interfaces.py
+++ b/fedn/network/combiner/interfaces.py
@@ -196,6 +196,28 @@ def set_aggregator(self, aggregator):
         else:
             raise
 
+    def set_server_functions(self, server_functions):
+        """Set the server functions code.
+
+        :param server_functions: Stringified server functions code.
+        :type server_functions: str
+        """
+        channel = Channel(self.address, self.port, self.certificate).get_channel()
+        control = rpc.ControlStub(channel)
+
+        request = fedn.ControlRequest()
+        p = request.parameter.add()
+        p.key = "server_functions"
+        p.value = server_functions
+
+        try:
+            control.SetServerFunctions(request)
+        except grpc.RpcError as e:
+            if e.code() == grpc.StatusCode.UNAVAILABLE:
+                raise CombinerUnavailableError
+            else:
+                raise
+
     def submit(self, config: RoundConfig):
         """Submit a compute plan to the combiner.
 
diff --git a/fedn/network/combiner/modelservice.py b/fedn/network/combiner/modelservice.py
index b5e7bff73..89901dabb 100644
--- a/fedn/network/combiner/modelservice.py
+++ b/fedn/network/combiner/modelservice.py
@@ -2,6 +2,8 @@
 import tempfile
 from io import BytesIO
 
+import numpy as np
+
 import fedn.network.grpc.fedn_pb2 as fedn
 import fedn.network.grpc.fedn_pb2_grpc as rpc
 from fedn.common.log_config import logger
@@ -29,7 +31,36 @@ def upload_request_generator(mdl, id):
             break
 
 
-def model_as_bytesIO(model):
+def bytesIO_request_generator(mdl, request_function, args):
+    """Generator function for model upload requests.
+
+    :param mdl: The model update object.
+    :type mdl: BytesIO
+    :param request_function: Function for sending requests.
+    :type request_function: Function
+    :param args: Request arguments, excluding the data argument.
+    :type args: dict
+    :return: Yields gRPC requests for streaming.
+    :rtype: grpc request generator.
+    """
+    while True:
+        b = mdl.read(CHUNK_SIZE)
+        if b:
+            result = request_function(data=b, **args)
+        else:
+            result = request_function(data=None, **args)
+        yield result
+        if not b:
+            break
+
+
+def model_as_bytesIO(model, helper=None):
+    if isinstance(model, list):
+        bt = BytesIO()
+        model_dict = {str(i): w for i, w in enumerate(model)}
+        np.savez_compressed(bt, **model_dict)
+        bt.seek(0)
+        return bt
     if not isinstance(model, BytesIO):
         bt = BytesIO()
@@ -44,6 +75,31 @@
     return bt
 
 
+def unpack_model(request_iterator, helper):
+    """Unpack an incoming model sent in chunks from a request iterator.
+
+    :param request_iterator: A streaming iterator from a gRPC service.
+    :param helper: The helper object for the model.
+    :return: The reconstructed model parameters.
+    """
+    model_buffer = BytesIO()
+    try:
+        for request in request_iterator:
+            if request.data:
+                model_buffer.write(request.data)
+    except MemoryError as e:
+        logger.error(f"Memory error occurred when loading model, reach out to the FEDn team if you need a solution to this. {e}")
+        raise
+    except Exception as e:
+        logger.error(f"Exception occurred during model loading: {e}")
+        raise
+
+    model_buffer.seek(0)
+
+    model_bytes = model_buffer.getvalue()
+
+    return load_model_from_bytes(model_bytes, helper), request
+
+
 def get_tmp_path():
     """Return a temporary output path compatible with save_model, load_model."""
     fd, path = tempfile.mkstemp()
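A quick round-trip shows what model_as_bytesIO does with a parameter list: the layers are written into an in-memory compressed npz archive keyed "0", "1", ..., which is what the chunked streaming above then transports. A small sketch, reading the archive back with numpy directly for illustration:

from io import BytesIO

import numpy as np

from fedn.network.combiner.modelservice import model_as_bytesIO

model = [np.ones((2, 2)), np.zeros(3)]
packed = model_as_bytesIO(model)  # BytesIO, already seeked to position 0

archive = np.load(packed)
restored = [archive[str(i)] for i in range(len(archive.files))]
assert all(np.array_equal(a, b) for a, b in zip(model, restored))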
@@ -51,10 +107,10 @@ def get_tmp_path():
     return path
 
 
-def load_model_from_BytesIO(model_bytesio, helper):
-    """Load a model from a BytesIO object.
-    :param model_bytesio: A BytesIO object containing the model.
-    :type model_bytesio: :class:`io.BytesIO`
+def load_model_from_bytes(model_bytes, helper):
+    """Load a model from a bytes object.
+    :param model_bytes: A bytes object containing the model.
+    :type model_bytes: :class:`bytes`
     :param helper: The helper object for the model.
     :type helper: :class:`fedn.utils.helperbase.HelperBase`
     :return: The model object.
@@ -62,7 +118,7 @@ def load_model_from_BytesIO(model_bytesio, helper):
     """
     path = get_tmp_path()
     with open(path, "wb") as fh:
-        fh.write(model_bytesio)
+        fh.write(model_bytes)
         fh.flush()
     model = helper.load(path)
     os.unlink(path)
diff --git a/fedn/network/combiner/roundhandler.py b/fedn/network/combiner/roundhandler.py
index a4808f7a5..5eb5387d8 100644
--- a/fedn/network/combiner/roundhandler.py
+++ b/fedn/network/combiner/roundhandler.py
@@ -1,15 +1,18 @@
 import ast
+import inspect
 import queue
 import random
-import sys
 import time
 import uuid
 from typing import TypedDict
 
 from fedn.common.log_config import logger
 from fedn.network.combiner.aggregators.aggregatorbase import get_aggregator
-from fedn.network.combiner.modelservice import load_model_from_BytesIO, serialize_model_to_BytesIO
+from fedn.network.combiner.hooks.hook_client import CombinerHookInterface
+from fedn.network.combiner.hooks.serverfunctionsbase import ServerFunctions
+from fedn.network.combiner.modelservice import serialize_model_to_BytesIO
 from fedn.network.combiner.shared import modelservice, repository
+from fedn.network.combiner.updatehandler import UpdateHandler
 from fedn.utils.helpers.helpers import get_helper
 from fedn.utils.parameters import Parameters
 
@@ -48,6 +51,8 @@ class RoundConfig(TypedDict):
     :type helper_type: str
     :param aggregator: The aggregator type.
     :type aggregator: str
+    :param client_settings: Settings that are distributed to the clients.
+    :type client_settings: dict
     """
 
     _job_id: str
@@ -65,10 +70,7 @@ class RoundConfig(TypedDict):
     session_id: str
     helper_type: str
     aggregator: str
-
-
-class ModelUpdateError(Exception):
-    pass
+    client_settings: dict
 
 
 class RoundHandler:
@@ -93,9 +95,15 @@ def __init__(self, server):
         self.storage = repository
         self.server = server
         self.modelservice = modelservice
+        self.server_functions = inspect.getsource(ServerFunctions)
+        self.update_handler = UpdateHandler(modelservice=modelservice)
+        self.hook_interface = CombinerHookInterface()
 
     def set_aggregator(self, aggregator):
-        self.aggregator = get_aggregator(aggregator, self.storage, self.server, self.modelservice, self)
+        self.aggregator = get_aggregator(aggregator, self.update_handler)
+
+    def set_server_functions(self, server_functions: str):
+        self.server_functions = server_functions
 
     def push_round_config(self, round_config: RoundConfig) -> str:
         """Add a round_config (job description) to the inbox.
@@ -113,76 +121,7 @@
             raise
         return round_config["_job_id"]
 
-    def load_model_update(self, helper, model_id):
-        """Load model update with id model_id into its memory representation.
- - :param helper: An instance of :class: `fedn.utils.helpers.helpers.HelperBase` - :type helper: class: `fedn.utils.helpers.helpers.HelperBase` - :param model_id: The ID of the model update, UUID in str format - :type model_id: str - """ - model_str = self.load_model_update_str(model_id) - if model_str: - try: - model = load_model_from_BytesIO(model_str.getbuffer(), helper) - except IOError: - logger.warning("AGGREGATOR({}): Failed to load model!".format(self.name)) - else: - raise ModelUpdateError("Failed to load model.") - - return model - - def load_model_update_str(self, model_id, retry=3): - """Load model update object and return it as BytesIO. - - :param model_id: The ID of the model - :type model_id: str - :param retry: number of times retrying load model update, defaults to 3 - :type retry: int, optional - :return: Updated model - :rtype: class: `io.BytesIO` - """ - # Try reading model update from local disk/combiner memory - model_str = self.modelservice.temp_model_storage.get(model_id) - # And if we cannot access that, try downloading from the server - if model_str is None: - model_str = self.modelservice.get_model(model_id) - # TODO: use retrying library - tries = 0 - while tries < retry: - tries += 1 - if not model_str or sys.getsizeof(model_str) == 80: - logger.warning("Model download failed. retrying") - time.sleep(1) - model_str = self.modelservice.get_model(model_id) - - return model_str - - def waitforit(self, config, buffer_size=100, polling_interval=0.1): - """Defines the policy for how long the server should wait before starting to aggregate models. - - The policy is as follows: - 1. Wait a maximum of time_window time until the round times out. - 2. Terminate if a preset number of model updates (buffer_size) are in the queue. - - :param config: The round config object - :type config: dict - :param buffer_size: The number of model updates to wait for before starting aggregation, defaults to 100 - :type buffer_size: int, optional - :param polling_interval: The polling interval, defaults to 0.1 - :type polling_interval: float, optional - """ - time_window = float(config["round_timeout"]) - - tt = 0.0 - while tt < time_window: - if self.aggregator.model_updates.qsize() >= buffer_size: - break - - time.sleep(polling_interval) - tt += polling_interval - - def _training_round(self, config, clients): + def _training_round(self, config, clients, provided_functions): """Send model update requests to clients and aggregate results. :param config: The round config object (passed to the client). @@ -202,6 +141,10 @@ def _training_round(self, config, clients): session_id = config["session_id"] model_id = config["model_id"] + if provided_functions["client_settings"]: + global_model_bytes = self.modelservice.temp_model_storage.get(model_id) + client_settings = self.hook_interface.client_settings(global_model_bytes) + config["client_settings"] = client_settings # Request model updates from all active clients. self.server.request_model_update(session_id=session_id, model_id=model_id, config=config, clients=clients) @@ -212,12 +155,10 @@ def _training_round(self, config, clients): buffer_size = int(config["buffer_size"]) # Wait / block until the round termination policy has been met. 
-        self.waitforit(config, buffer_size=buffer_size)
-
+        self.update_handler.waitforit(config, buffer_size=buffer_size)
         tic = time.time()
         model = None
         data = None
-
         try:
             helper = get_helper(config["helper_type"])
             logger.info("Config delete_models_storage: {}".format(config["delete_models_storage"]))
@@ -231,10 +172,14 @@
                 parameters = Parameters(dict_parameters)
             else:
                 parameters = None
-
-            model, data = self.aggregator.combine_models(helper=helper, delete_models=delete_models, parameters=parameters)
+            if provided_functions["aggregate"]:
+                previous_model_bytes = self.modelservice.temp_model_storage.get(model_id)
+                model, data = self.hook_interface.aggregate(previous_model_bytes, self.update_handler, helper, delete_models=delete_models)
+            else:
+                model, data = self.aggregator.combine_models(helper=helper, delete_models=delete_models, parameters=parameters)
         except Exception as e:
             logger.warning("AGGREGATION FAILED AT COMBINER! {}".format(e))
+            raise
 
         meta["time_combination"] = time.time() - tic
         meta["aggregation_time"] = data
@@ -379,8 +324,13 @@ def execute_training_round(self, config):
         # Download model to update and set in temp storage.
         self.stage_model(config["model_id"])
 
-        clients = self._assign_round_clients(self.server.max_clients)
-        model, meta = self._training_round(config, clients)
+        provided_functions = self.hook_interface.provided_functions(self.server_functions)
+
+        if provided_functions["client_selection"]:
+            clients = self.hook_interface.client_selection(clients=self.server.get_active_trainers())
+        else:
+            clients = self._assign_round_clients(self.server.max_clients)
+        model, meta = self._training_round(config, clients, provided_functions)
 
         data["data"] = meta
 
         if model is None:
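For reference, a round config carrying the new client_settings field might look as follows. This is a hedged sketch: every value is illustrative, IDs are elided, and only a subset of the RoundConfig keys is shown.

# Illustrative only: a partial RoundConfig dict as pushed to the combiner inbox.
round_config = {
    "task": "training",
    "round_id": "1",
    "round_timeout": "60",
    "model_id": "<global model UUID>",
    "helper_type": "numpyhelper",
    "aggregator": "fedavg",
    # Filled in by _training_round when the client_settings hook is provided:
    "client_settings": {"learning_rate": 0.01},
}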
diff --git a/fedn/network/combiner/updatehandler.py b/fedn/network/combiner/updatehandler.py
new file mode 100644
index 000000000..517595d13
--- /dev/null
+++ b/fedn/network/combiner/updatehandler.py
@@ -0,0 +1,210 @@
+import json
+import queue
+import sys
+import time
+import traceback
+
+from fedn.common.log_config import logger
+from fedn.network.combiner.modelservice import ModelService, load_model_from_bytes
+
+
+class ModelUpdateError(Exception):
+    pass
+
+
+class UpdateHandler:
+    """Update handler.
+
+    Responsible for receiving, loading and supplying client model updates.
+
+    :param modelservice: A handle to the model service :class:`fedn.network.combiner.modelservice.ModelService`
+    :type modelservice: :class:`fedn.network.combiner.modelservice.ModelService`
+    """
+
+    def __init__(self, modelservice: ModelService) -> None:
+        self.model_updates = queue.Queue()
+        self.modelservice = modelservice
+
+        self.model_id_to_model_data = {}
+
+    def delete_model(self, model_update):
+        self.modelservice.temp_model_storage.delete(model_update.model_update_id)
+        logger.info("UPDATE HANDLER: Deleted model update {} from storage.".format(model_update.model_update_id))
+
+    def next_model_update(self):
+        """Get the next model update from the queue.
+
+        :return: The model update.
+        :rtype: fedn.network.grpc.fedn.proto.ModelUpdate
+        """
+        model_update = self.model_updates.get(block=False)
+        return model_update
+
+    def on_model_update(self, model_update):
+        """Callback when a new client model update is received.
+
+        Performs (optional) validation and pre-processing,
+        and then puts the update id on the aggregation queue.
+        Override in subclass as needed.
+
+        :param model_update: fedn.network.grpc.fedn.proto.ModelUpdate
+        :type model_update: object
+        """
+        try:
+            logger.info("UPDATE HANDLER: callback received model update {}".format(model_update.model_update_id))
+
+            # Validate the update and metadata
+            valid_update = self._validate_model_update(model_update)
+            if valid_update:
+                # Push the model update to the processing queue
+                self.model_updates.put(model_update)
+            else:
+                logger.warning("UPDATE HANDLER: Invalid model update, skipping.")
+        except Exception as e:
+            tb = traceback.format_exc()
+            logger.error("UPDATE HANDLER: failed to receive model update: {}".format(e))
+            logger.error(tb)
+            pass
+
+    def _validate_model_update(self, model_update):
+        """Validate the model update.
+
+        :param model_update: A ModelUpdate message.
+        :type model_update: object
+        :return: True if the model update is valid, False otherwise.
+        :rtype: bool
+        """
+        try:
+            data = json.loads(model_update.meta)["training_metadata"]
+            _ = data["num_examples"]
+        except KeyError:
+            tb = traceback.format_exc()
+            logger.error("UPDATE HANDLER: Invalid model update, missing metadata.")
+            logger.error(tb)
+            return False
+        return True
+
+    def load_model_update(self, model_update, helper):
+        """Load the memory representation of the model update.
+
+        Load the model update parameters and the
+        associated metadata into memory.
+
+        :param model_update: The model update.
+        :type model_update: fedn.network.grpc.fedn.proto.ModelUpdate
+        :param helper: A helper object.
+        :type helper: fedn.utils.helpers.helperbase.Helper
+        :return: A tuple of (parameters, metadata)
+        :rtype: tuple
+        """
+        model_id = model_update.model_update_id
+        model = self.load_model(helper, model_id)
+        # Get relevant metadata
+        metadata = json.loads(model_update.meta)
+        if "config" in metadata.keys():
+            # Used in Python client
+            config = json.loads(metadata["config"])
+        else:
+            # Used in C++ client
+            config = json.loads(model_update.config)
+        training_metadata = metadata["training_metadata"]
+        training_metadata["round_id"] = config["round_id"]
+
+        return model, training_metadata
+
+    def load_model_update_byte(self, model_update):
+        """Load the memory representation of the model update.
+
+        Load the model update parameters and the
+        associated metadata into memory.
+
+        :param model_update: The model update.
+        :type model_update: fedn.network.grpc.fedn.proto.ModelUpdate
+        :return: A tuple of parameters (bytes) and metadata
+        :rtype: tuple
+        """
+        model_id = model_update.model_update_id
+        model = self.load_model_update_bytesIO(model_id).getbuffer()
+        # Get relevant metadata
+        metadata = json.loads(model_update.meta)
+        if "config" in metadata.keys():
+            # Used in Python client
+            config = json.loads(metadata["config"])
+        else:
+            # Used in C++ client
+            config = json.loads(model_update.config)
+        training_metadata = metadata["training_metadata"]
+        training_metadata["round_id"] = config["round_id"]
+
+        return model, training_metadata
+
+    def load_model(self, helper, model_id):
+        """Load model update with id model_id into its memory representation.
+
+        :param helper: An instance of :class:`fedn.utils.helpers.helpers.HelperBase`
+        :type helper: :class:`fedn.utils.helpers.helpers.HelperBase`
+        :param model_id: The ID of the model update, UUID in str format
+        :type model_id: str
+        """
+        model_bytesIO = self.load_model_update_bytesIO(model_id)
+        if model_bytesIO:
+            try:
+                model = load_model_from_bytes(model_bytesIO.getbuffer(), helper)
+            except IOError:
+                logger.warning("UPDATE HANDLER: Failed to load model!")
+        else:
+            raise ModelUpdateError("Failed to load model.")
+
+        return model
+
+    def load_model_update_bytesIO(self, model_id, retry=3):
+        """Load model update object and return it as BytesIO.
+
+        :param model_id: The ID of the model
+        :type model_id: str
+        :param retry: Number of times to retry loading the model update, defaults to 3
+        :type retry: int, optional
+        :return: Updated model
+        :rtype: class: `io.BytesIO`
+        """
+        # Try reading model update from local disk/combiner memory
+        model_str = self.modelservice.temp_model_storage.get(model_id)
+        # And if we cannot access that, try downloading from the server
+        if model_str is None:
+            model_str = self.modelservice.get_model(model_id)
+        # TODO: use retrying library
+        tries = 0
+        while tries < retry:
+            tries += 1
+            if not model_str or sys.getsizeof(model_str) == 80:
+                logger.warning("Model download failed, retrying.")
+                time.sleep(1)
+                model_str = self.modelservice.get_model(model_id)
+
+        return model_str
+
+    def waitforit(self, config, buffer_size=100, polling_interval=0.1):
+        """Defines the policy for how long the server should wait before starting to aggregate models.
+
+        The policy is as follows:
+        1. Wait a maximum of time_window time until the round times out.
+        2. Terminate if a preset number of model updates (buffer_size) are in the queue.
+
+        :param config: The round config object
+        :type config: dict
+        :param buffer_size: The number of model updates to wait for before starting aggregation, defaults to 100
+        :type buffer_size: int, optional
+        :param polling_interval: The polling interval, defaults to 0.1
+        :type polling_interval: float, optional
+        """
+        time_window = float(config["round_timeout"])
+
+        tt = 0.0
+        while tt < time_window:
+            if self.model_updates.qsize() >= buffer_size:
+                break
+
+            time.sleep(polling_interval)
+            tt += polling_interval
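The queue-facing methods are easiest to see in a consumer loop. A hypothetical sketch of how an aggregator built on UpdateHandler might drain pending updates; the function name and loop are assumptions, not code from this patch:

import queue

def drain_updates(update_handler, helper):
    # Hypothetical consumer: pop queued updates until the queue is empty.
    updates = []
    while True:
        try:
            model_update = update_handler.next_model_update()
        except queue.Empty:
            break
        parameters, metadata = update_handler.load_model_update(model_update, helper)
        updates.append((parameters, metadata["num_examples"]))
        update_handler.delete_model(model_update)  # free combiner-side storage
    return updates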
diff --git a/fedn/network/controller/control.py b/fedn/network/controller/control.py
index 129dd1af0..4fcfcd734 100644
--- a/fedn/network/controller/control.py
+++ b/fedn/network/controller/control.py
@@ -7,7 +7,7 @@
 from fedn.common.log_config import logger
 from fedn.network.combiner.interfaces import CombinerUnavailableError
-from fedn.network.combiner.modelservice import load_model_from_BytesIO
+from fedn.network.combiner.modelservice import load_model_from_bytes
 from fedn.network.combiner.roundhandler import RoundConfig
 from fedn.network.controller.controlbase import ControlBase
 from fedn.network.state import ReducerState
@@ -129,6 +129,8 @@ def start_session(self, session_id: str, rounds: int, round_timeout: int) -> Non
 
         for combiner in self.network.get_combiners():
             combiner.set_aggregator(aggregator)
+            if session_config["server_functions"] is not None:
+                combiner.set_server_functions(session_config["server_functions"])
 
         self.set_session_status(session_id, "Started")
 
@@ -184,6 +186,8 @@ def session(self, config: RoundConfig) -> None:
 
         for combiner in self.network.get_combiners():
             combiner.set_aggregator(config["aggregator"])
+            if config["server_functions"] is not None:
+                combiner.set_server_functions(config["server_functions"])
 
         self.set_session_status(config["session_id"], "Started")
 
         # Execute the rounds in this session
@@ -422,14 +426,14 @@ def reduce(self, combiners):
             try:
                 tic = time.time()
                 helper = self.get_helper()
-                model_next = load_model_from_BytesIO(data, helper)
+                model_next = load_model_from_bytes(data, helper)
                 meta["time_load_model"] += time.time() - tic
                 tic = time.time()
                 model = helper.increment_average(model, model_next, 1.0, i)
                 meta["time_aggregate_model"] += time.time() - tic
             except Exception:
                 tic = time.time()
-                model = load_model_from_BytesIO(data, helper)
+                model = load_model_from_bytes(data, helper)
                 meta["time_aggregate_model"] += time.time() - tic
             i = i + 1
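The value handed to set_server_functions is plain source text. Assuming user code is serialized the same way as the RoundHandler default (which stringifies the no-op ServerFunctions with inspect.getsource), the controller-side flow looks roughly like this sketch:

import inspect

# ServerFunctions is the user's subclass of ServerFunctionsBase (hypothetical import).
server_functions_code = inspect.getsource(ServerFunctions)

# Inside the controller, mirroring how set_aggregator is propagated:
for combiner in self.network.get_combiners():
    combiner.set_server_functions(server_functions_code)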
diff --git a/fedn/network/grpc/fedn.proto b/fedn/network/grpc/fedn.proto
index 558b7e67d..fb7c312f4 100644
--- a/fedn/network/grpc/fedn.proto
+++ b/fedn/network/grpc/fedn.proto
@@ -206,6 +206,7 @@ service Control {
     rpc Stop(ControlRequest) returns (ControlResponse);
     rpc FlushAggregationQueue(ControlRequest) returns (ControlResponse);
     rpc SetAggregator(ControlRequest) returns (ControlResponse);
+    rpc SetServerFunctions(ControlRequest) returns (ControlResponse);
 }
 
 service Reducer {
@@ -256,3 +257,61 @@ service Combiner {
 }
 
 
+message ProvidedFunctionsRequest {
+    string function_code = 1;
+}
+
+message ProvidedFunctionsResponse {
+    map<string, bool> available_functions = 1;
+}
+
+message ClientConfigRequest {
+    bytes data = 1;
+}
+
+message ClientConfigResponse {
+    string client_settings = 1;
+}
+
+message ClientSelectionRequest {
+    string client_ids = 1;
+}
+
+message ClientSelectionResponse {
+    string client_ids = 1;
+}
+
+message ClientMetaRequest {
+    string metadata = 1;
+    string client_id = 2;
+}
+
+message ClientMetaResponse {
+    string status = 1;
+}
+
+message StoreModelRequest {
+    bytes data = 1;
+    string id = 2;
+}
+
+message StoreModelResponse {
+    string status = 1;
+}
+
+message AggregationRequest {
+    string aggregate = 1;
+}
+
+message AggregationResponse {
+    bytes data = 1;
+}
+
+service FunctionService {
+    rpc HandleProvidedFunctions(ProvidedFunctionsRequest) returns (ProvidedFunctionsResponse);
+    rpc HandleClientConfig (stream ClientConfigRequest) returns (ClientConfigResponse);
+    rpc HandleClientSelection (ClientSelectionRequest) returns (ClientSelectionResponse);
+    rpc HandleMetadata (ClientMetaRequest) returns (ClientMetaResponse);
+    rpc HandleStoreModel (stream StoreModelRequest) returns (StoreModelResponse);
+    rpc HandleAggregation (AggregationRequest) returns (stream AggregationResponse);
+}
\ No newline at end of file
diff --git a/fedn/network/grpc/fedn_pb2.py b/fedn/network/grpc/fedn_pb2.py
index cb637baba..09ecdec36 100644
--- a/fedn/network/grpc/fedn_pb2.py
+++ b/fedn/network/grpc/fedn_pb2.py
@@ -1,6 +1,6 @@
 # -*- coding: utf-8 -*-
 # Generated by the protocol buffer compiler.  DO NOT EDIT!
-# source: network/grpc/fedn.proto +# source: fedn/network/grpc/fedn.proto # Protobuf Python Version: 4.25.0 """Generated protocol buffer code.""" from google.protobuf import descriptor as _descriptor @@ -15,79 +15,109 @@ from google.protobuf import timestamp_pb2 as google_dot_protobuf_dot_timestamp__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x17network/grpc/fedn.proto\x12\x04\x66\x65\x64n\x1a\x1fgoogle/protobuf/timestamp.proto\":\n\x08Response\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.fedn.Client\x12\x10\n\x08response\x18\x02 \x01(\t\"\xbc\x02\n\x06Status\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.fedn.Client\x12\x0e\n\x06status\x18\x02 \x01(\t\x12(\n\tlog_level\x18\x03 \x01(\x0e\x32\x15.fedn.Status.LogLevel\x12\x0c\n\x04\x64\x61ta\x18\x04 \x01(\t\x12\x16\n\x0e\x63orrelation_id\x18\x05 \x01(\t\x12-\n\ttimestamp\x18\x06 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12\x1e\n\x04type\x18\x07 \x01(\x0e\x32\x10.fedn.StatusType\x12\r\n\x05\x65xtra\x18\x08 \x01(\t\x12\x12\n\nsession_id\x18\t \x01(\t\"B\n\x08LogLevel\x12\x08\n\x04INFO\x10\x00\x12\t\n\x05\x44\x45\x42UG\x10\x01\x12\x0b\n\x07WARNING\x10\x02\x12\t\n\x05\x45RROR\x10\x03\x12\t\n\x05\x41UDIT\x10\x04\"\xd8\x01\n\x0bTaskRequest\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.fedn.Client\x12\x1e\n\x08receiver\x18\x02 \x01(\x0b\x32\x0c.fedn.Client\x12\x10\n\x08model_id\x18\x03 \x01(\t\x12\x0c\n\x04\x64\x61ta\x18\x04 \x01(\t\x12\x16\n\x0e\x63orrelation_id\x18\x05 \x01(\t\x12\x11\n\ttimestamp\x18\x06 \x01(\t\x12\x0c\n\x04meta\x18\x07 \x01(\t\x12\x12\n\nsession_id\x18\x08 \x01(\t\x12\x1e\n\x04type\x18\t \x01(\x0e\x32\x10.fedn.StatusType\"\xbf\x01\n\x0bModelUpdate\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.fedn.Client\x12\x1e\n\x08receiver\x18\x02 \x01(\x0b\x32\x0c.fedn.Client\x12\x10\n\x08model_id\x18\x03 \x01(\t\x12\x17\n\x0fmodel_update_id\x18\x04 \x01(\t\x12\x16\n\x0e\x63orrelation_id\x18\x05 \x01(\t\x12\x11\n\ttimestamp\x18\x06 \x01(\t\x12\x0c\n\x04meta\x18\x07 \x01(\t\x12\x0e\n\x06\x63onfig\x18\x08 \x01(\t\"\xd8\x01\n\x0fModelValidation\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.fedn.Client\x12\x1e\n\x08receiver\x18\x02 \x01(\x0b\x32\x0c.fedn.Client\x12\x10\n\x08model_id\x18\x03 \x01(\t\x12\x0c\n\x04\x64\x61ta\x18\x04 \x01(\t\x12\x16\n\x0e\x63orrelation_id\x18\x05 \x01(\t\x12-\n\ttimestamp\x18\x06 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12\x0c\n\x04meta\x18\x07 \x01(\t\x12\x12\n\nsession_id\x18\x08 \x01(\t\"\xdb\x01\n\x0fModelPrediction\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.fedn.Client\x12\x1e\n\x08receiver\x18\x02 \x01(\x0b\x32\x0c.fedn.Client\x12\x10\n\x08model_id\x18\x03 \x01(\t\x12\x0c\n\x04\x64\x61ta\x18\x04 \x01(\t\x12\x16\n\x0e\x63orrelation_id\x18\x05 \x01(\t\x12-\n\ttimestamp\x18\x06 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12\x0c\n\x04meta\x18\x07 \x01(\t\x12\x15\n\rprediction_id\x18\x08 \x01(\t\"\x89\x01\n\x0cModelRequest\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.fedn.Client\x12\x1e\n\x08receiver\x18\x02 \x01(\x0b\x32\x0c.fedn.Client\x12\x0c\n\x04\x64\x61ta\x18\x03 \x01(\x0c\x12\n\n\x02id\x18\x04 \x01(\t\x12!\n\x06status\x18\x05 \x01(\x0e\x32\x11.fedn.ModelStatus\"]\n\rModelResponse\x12\x0c\n\x04\x64\x61ta\x18\x01 \x01(\x0c\x12\n\n\x02id\x18\x02 \x01(\t\x12!\n\x06status\x18\x03 \x01(\x0e\x32\x11.fedn.ModelStatus\x12\x0f\n\x07message\x18\x04 \x01(\t\"U\n\x15GetGlobalModelRequest\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.fedn.Client\x12\x1e\n\x08receiver\x18\x02 \x01(\x0b\x32\x0c.fedn.Client\"h\n\x16GetGlobalModelResponse\x12\x1c\n\x06sender\x18\x01 
\x01(\x0b\x32\x0c.fedn.Client\x12\x1e\n\x08receiver\x18\x02 \x01(\x0b\x32\x0c.fedn.Client\x12\x10\n\x08model_id\x18\x03 \x01(\t\")\n\tHeartbeat\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.fedn.Client\"W\n\x16\x43lientAvailableMessage\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.fedn.Client\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\t\x12\x11\n\ttimestamp\x18\x03 \x01(\t\"P\n\x12ListClientsRequest\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.fedn.Client\x12\x1c\n\x07\x63hannel\x18\x02 \x01(\x0e\x32\x0b.fedn.Queue\"*\n\nClientList\x12\x1c\n\x06\x63lient\x18\x01 \x03(\x0b\x32\x0c.fedn.Client\"C\n\x06\x43lient\x12\x18\n\x04role\x18\x01 \x01(\x0e\x32\n.fedn.Role\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\x11\n\tclient_id\x18\x03 \x01(\t\"m\n\x0fReassignRequest\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.fedn.Client\x12\x1e\n\x08receiver\x18\x02 \x01(\x0b\x32\x0c.fedn.Client\x12\x0e\n\x06server\x18\x03 \x01(\t\x12\x0c\n\x04port\x18\x04 \x01(\r\"c\n\x10ReconnectRequest\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.fedn.Client\x12\x1e\n\x08receiver\x18\x02 \x01(\x0b\x32\x0c.fedn.Client\x12\x11\n\treconnect\x18\x03 \x01(\r\"\'\n\tParameter\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t\"T\n\x0e\x43ontrolRequest\x12\x1e\n\x07\x63ommand\x18\x01 \x01(\x0e\x32\r.fedn.Command\x12\"\n\tparameter\x18\x02 \x03(\x0b\x32\x0f.fedn.Parameter\"F\n\x0f\x43ontrolResponse\x12\x0f\n\x07message\x18\x01 \x01(\t\x12\"\n\tparameter\x18\x02 \x03(\x0b\x32\x0f.fedn.Parameter\"\x13\n\x11\x43onnectionRequest\"<\n\x12\x43onnectionResponse\x12&\n\x06status\x18\x01 \x01(\x0e\x32\x16.fedn.ConnectionStatus*\x8b\x01\n\nStatusType\x12\x07\n\x03LOG\x10\x00\x12\x18\n\x14MODEL_UPDATE_REQUEST\x10\x01\x12\x10\n\x0cMODEL_UPDATE\x10\x02\x12\x1c\n\x18MODEL_VALIDATION_REQUEST\x10\x03\x12\x14\n\x10MODEL_VALIDATION\x10\x04\x12\x14\n\x10MODEL_PREDICTION\x10\x05*$\n\x05Queue\x12\x0b\n\x07\x44\x45\x46\x41ULT\x10\x00\x12\x0e\n\nTASK_QUEUE\x10\x01*S\n\x0bModelStatus\x12\x06\n\x02OK\x10\x00\x12\x0f\n\x0bIN_PROGRESS\x10\x01\x12\x12\n\x0eIN_PROGRESS_OK\x10\x02\x12\n\n\x06\x46\x41ILED\x10\x03\x12\x0b\n\x07UNKNOWN\x10\x04*8\n\x04Role\x12\n\n\x06WORKER\x10\x00\x12\x0c\n\x08\x43OMBINER\x10\x01\x12\x0b\n\x07REDUCER\x10\x02\x12\t\n\x05OTHER\x10\x03*J\n\x07\x43ommand\x12\x08\n\x04IDLE\x10\x00\x12\t\n\x05START\x10\x01\x12\t\n\x05PAUSE\x10\x02\x12\x08\n\x04STOP\x10\x03\x12\t\n\x05RESET\x10\x04\x12\n\n\x06REPORT\x10\x05*I\n\x10\x43onnectionStatus\x12\x11\n\rNOT_ACCEPTING\x10\x00\x12\r\n\tACCEPTING\x10\x01\x12\x13\n\x0fTRY_AGAIN_LATER\x10\x02\x32z\n\x0cModelService\x12\x33\n\x06Upload\x12\x12.fedn.ModelRequest\x1a\x13.fedn.ModelResponse(\x01\x12\x35\n\x08\x44ownload\x12\x12.fedn.ModelRequest\x1a\x13.fedn.ModelResponse0\x01\x32\xf8\x01\n\x07\x43ontrol\x12\x34\n\x05Start\x12\x14.fedn.ControlRequest\x1a\x15.fedn.ControlResponse\x12\x33\n\x04Stop\x12\x14.fedn.ControlRequest\x1a\x15.fedn.ControlResponse\x12\x44\n\x15\x46lushAggregationQueue\x12\x14.fedn.ControlRequest\x1a\x15.fedn.ControlResponse\x12<\n\rSetAggregator\x12\x14.fedn.ControlRequest\x1a\x15.fedn.ControlResponse2V\n\x07Reducer\x12K\n\x0eGetGlobalModel\x12\x1b.fedn.GetGlobalModelRequest\x1a\x1c.fedn.GetGlobalModelResponse2\xab\x03\n\tConnector\x12\x44\n\x14\x41llianceStatusStream\x12\x1c.fedn.ClientAvailableMessage\x1a\x0c.fedn.Status0\x01\x12*\n\nSendStatus\x12\x0c.fedn.Status\x1a\x0e.fedn.Response\x12?\n\x11ListActiveClients\x12\x18.fedn.ListClientsRequest\x1a\x10.fedn.ClientList\x12\x45\n\x10\x41\x63\x63\x65ptingClients\x12\x17.fedn.ConnectionRequest\x1a\x18.fedn.ConnectionResponse\x12\x
30\n\rSendHeartbeat\x12\x0f.fedn.Heartbeat\x1a\x0e.fedn.Response\x12\x37\n\x0eReassignClient\x12\x15.fedn.ReassignRequest\x1a\x0e.fedn.Response\x12\x39\n\x0fReconnectClient\x12\x16.fedn.ReconnectRequest\x1a\x0e.fedn.Response2\xfd\x01\n\x08\x43ombiner\x12?\n\nTaskStream\x12\x1c.fedn.ClientAvailableMessage\x1a\x11.fedn.TaskRequest0\x01\x12\x34\n\x0fSendModelUpdate\x12\x11.fedn.ModelUpdate\x1a\x0e.fedn.Response\x12<\n\x13SendModelValidation\x12\x15.fedn.ModelValidation\x1a\x0e.fedn.Response\x12<\n\x13SendModelPrediction\x12\x15.fedn.ModelPrediction\x1a\x0e.fedn.Responseb\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1c\x66\x65\x64n/network/grpc/fedn.proto\x12\x04\x66\x65\x64n\x1a\x1fgoogle/protobuf/timestamp.proto\":\n\x08Response\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.fedn.Client\x12\x10\n\x08response\x18\x02 \x01(\t\"\xbc\x02\n\x06Status\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.fedn.Client\x12\x0e\n\x06status\x18\x02 \x01(\t\x12(\n\tlog_level\x18\x03 \x01(\x0e\x32\x15.fedn.Status.LogLevel\x12\x0c\n\x04\x64\x61ta\x18\x04 \x01(\t\x12\x16\n\x0e\x63orrelation_id\x18\x05 \x01(\t\x12-\n\ttimestamp\x18\x06 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12\x1e\n\x04type\x18\x07 \x01(\x0e\x32\x10.fedn.StatusType\x12\r\n\x05\x65xtra\x18\x08 \x01(\t\x12\x12\n\nsession_id\x18\t \x01(\t\"B\n\x08LogLevel\x12\x08\n\x04INFO\x10\x00\x12\t\n\x05\x44\x45\x42UG\x10\x01\x12\x0b\n\x07WARNING\x10\x02\x12\t\n\x05\x45RROR\x10\x03\x12\t\n\x05\x41UDIT\x10\x04\"\xd8\x01\n\x0bTaskRequest\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.fedn.Client\x12\x1e\n\x08receiver\x18\x02 \x01(\x0b\x32\x0c.fedn.Client\x12\x10\n\x08model_id\x18\x03 \x01(\t\x12\x0c\n\x04\x64\x61ta\x18\x04 \x01(\t\x12\x16\n\x0e\x63orrelation_id\x18\x05 \x01(\t\x12\x11\n\ttimestamp\x18\x06 \x01(\t\x12\x0c\n\x04meta\x18\x07 \x01(\t\x12\x12\n\nsession_id\x18\x08 \x01(\t\x12\x1e\n\x04type\x18\t \x01(\x0e\x32\x10.fedn.StatusType\"\xbf\x01\n\x0bModelUpdate\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.fedn.Client\x12\x1e\n\x08receiver\x18\x02 \x01(\x0b\x32\x0c.fedn.Client\x12\x10\n\x08model_id\x18\x03 \x01(\t\x12\x17\n\x0fmodel_update_id\x18\x04 \x01(\t\x12\x16\n\x0e\x63orrelation_id\x18\x05 \x01(\t\x12\x11\n\ttimestamp\x18\x06 \x01(\t\x12\x0c\n\x04meta\x18\x07 \x01(\t\x12\x0e\n\x06\x63onfig\x18\x08 \x01(\t\"\xd8\x01\n\x0fModelValidation\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.fedn.Client\x12\x1e\n\x08receiver\x18\x02 \x01(\x0b\x32\x0c.fedn.Client\x12\x10\n\x08model_id\x18\x03 \x01(\t\x12\x0c\n\x04\x64\x61ta\x18\x04 \x01(\t\x12\x16\n\x0e\x63orrelation_id\x18\x05 \x01(\t\x12-\n\ttimestamp\x18\x06 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12\x0c\n\x04meta\x18\x07 \x01(\t\x12\x12\n\nsession_id\x18\x08 \x01(\t\"\xdb\x01\n\x0fModelPrediction\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.fedn.Client\x12\x1e\n\x08receiver\x18\x02 \x01(\x0b\x32\x0c.fedn.Client\x12\x10\n\x08model_id\x18\x03 \x01(\t\x12\x0c\n\x04\x64\x61ta\x18\x04 \x01(\t\x12\x16\n\x0e\x63orrelation_id\x18\x05 \x01(\t\x12-\n\ttimestamp\x18\x06 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12\x0c\n\x04meta\x18\x07 \x01(\t\x12\x15\n\rprediction_id\x18\x08 \x01(\t\"\x89\x01\n\x0cModelRequest\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.fedn.Client\x12\x1e\n\x08receiver\x18\x02 \x01(\x0b\x32\x0c.fedn.Client\x12\x0c\n\x04\x64\x61ta\x18\x03 \x01(\x0c\x12\n\n\x02id\x18\x04 \x01(\t\x12!\n\x06status\x18\x05 \x01(\x0e\x32\x11.fedn.ModelStatus\"]\n\rModelResponse\x12\x0c\n\x04\x64\x61ta\x18\x01 \x01(\x0c\x12\n\n\x02id\x18\x02 \x01(\t\x12!\n\x06status\x18\x03 
\x01(\x0e\x32\x11.fedn.ModelStatus\x12\x0f\n\x07message\x18\x04 \x01(\t\"U\n\x15GetGlobalModelRequest\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.fedn.Client\x12\x1e\n\x08receiver\x18\x02 \x01(\x0b\x32\x0c.fedn.Client\"h\n\x16GetGlobalModelResponse\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.fedn.Client\x12\x1e\n\x08receiver\x18\x02 \x01(\x0b\x32\x0c.fedn.Client\x12\x10\n\x08model_id\x18\x03 \x01(\t\")\n\tHeartbeat\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.fedn.Client\"W\n\x16\x43lientAvailableMessage\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.fedn.Client\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\t\x12\x11\n\ttimestamp\x18\x03 \x01(\t\"P\n\x12ListClientsRequest\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.fedn.Client\x12\x1c\n\x07\x63hannel\x18\x02 \x01(\x0e\x32\x0b.fedn.Queue\"*\n\nClientList\x12\x1c\n\x06\x63lient\x18\x01 \x03(\x0b\x32\x0c.fedn.Client\"C\n\x06\x43lient\x12\x18\n\x04role\x18\x01 \x01(\x0e\x32\n.fedn.Role\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\x11\n\tclient_id\x18\x03 \x01(\t\"m\n\x0fReassignRequest\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.fedn.Client\x12\x1e\n\x08receiver\x18\x02 \x01(\x0b\x32\x0c.fedn.Client\x12\x0e\n\x06server\x18\x03 \x01(\t\x12\x0c\n\x04port\x18\x04 \x01(\r\"c\n\x10ReconnectRequest\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.fedn.Client\x12\x1e\n\x08receiver\x18\x02 \x01(\x0b\x32\x0c.fedn.Client\x12\x11\n\treconnect\x18\x03 \x01(\r\"\'\n\tParameter\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t\"T\n\x0e\x43ontrolRequest\x12\x1e\n\x07\x63ommand\x18\x01 \x01(\x0e\x32\r.fedn.Command\x12\"\n\tparameter\x18\x02 \x03(\x0b\x32\x0f.fedn.Parameter\"F\n\x0f\x43ontrolResponse\x12\x0f\n\x07message\x18\x01 \x01(\t\x12\"\n\tparameter\x18\x02 \x03(\x0b\x32\x0f.fedn.Parameter\"\x13\n\x11\x43onnectionRequest\"<\n\x12\x43onnectionResponse\x12&\n\x06status\x18\x01 \x01(\x0e\x32\x16.fedn.ConnectionStatus\"1\n\x18ProvidedFunctionsRequest\x12\x15\n\rfunction_code\x18\x01 \x01(\t\"\xac\x01\n\x19ProvidedFunctionsResponse\x12T\n\x13\x61vailable_functions\x18\x01 \x03(\x0b\x32\x37.fedn.ProvidedFunctionsResponse.AvailableFunctionsEntry\x1a\x39\n\x17\x41vailableFunctionsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\x08:\x02\x38\x01\"#\n\x13\x43lientConfigRequest\x12\x0c\n\x04\x64\x61ta\x18\x01 \x01(\x0c\"/\n\x14\x43lientConfigResponse\x12\x17\n\x0f\x63lient_settings\x18\x01 \x01(\t\",\n\x16\x43lientSelectionRequest\x12\x12\n\nclient_ids\x18\x01 \x01(\t\"-\n\x17\x43lientSelectionResponse\x12\x12\n\nclient_ids\x18\x01 \x01(\t\"8\n\x11\x43lientMetaRequest\x12\x10\n\x08metadata\x18\x01 \x01(\t\x12\x11\n\tclient_id\x18\x02 \x01(\t\"$\n\x12\x43lientMetaResponse\x12\x0e\n\x06status\x18\x01 \x01(\t\"-\n\x11StoreModelRequest\x12\x0c\n\x04\x64\x61ta\x18\x01 \x01(\x0c\x12\n\n\x02id\x18\x02 \x01(\t\"$\n\x12StoreModelResponse\x12\x0e\n\x06status\x18\x01 \x01(\t\"\'\n\x12\x41ggregationRequest\x12\x11\n\taggregate\x18\x01 \x01(\t\"#\n\x13\x41ggregationResponse\x12\x0c\n\x04\x64\x61ta\x18\x01 
\x01(\x0c*\x8b\x01\n\nStatusType\x12\x07\n\x03LOG\x10\x00\x12\x18\n\x14MODEL_UPDATE_REQUEST\x10\x01\x12\x10\n\x0cMODEL_UPDATE\x10\x02\x12\x1c\n\x18MODEL_VALIDATION_REQUEST\x10\x03\x12\x14\n\x10MODEL_VALIDATION\x10\x04\x12\x14\n\x10MODEL_PREDICTION\x10\x05*$\n\x05Queue\x12\x0b\n\x07\x44\x45\x46\x41ULT\x10\x00\x12\x0e\n\nTASK_QUEUE\x10\x01*S\n\x0bModelStatus\x12\x06\n\x02OK\x10\x00\x12\x0f\n\x0bIN_PROGRESS\x10\x01\x12\x12\n\x0eIN_PROGRESS_OK\x10\x02\x12\n\n\x06\x46\x41ILED\x10\x03\x12\x0b\n\x07UNKNOWN\x10\x04*8\n\x04Role\x12\n\n\x06WORKER\x10\x00\x12\x0c\n\x08\x43OMBINER\x10\x01\x12\x0b\n\x07REDUCER\x10\x02\x12\t\n\x05OTHER\x10\x03*J\n\x07\x43ommand\x12\x08\n\x04IDLE\x10\x00\x12\t\n\x05START\x10\x01\x12\t\n\x05PAUSE\x10\x02\x12\x08\n\x04STOP\x10\x03\x12\t\n\x05RESET\x10\x04\x12\n\n\x06REPORT\x10\x05*I\n\x10\x43onnectionStatus\x12\x11\n\rNOT_ACCEPTING\x10\x00\x12\r\n\tACCEPTING\x10\x01\x12\x13\n\x0fTRY_AGAIN_LATER\x10\x02\x32z\n\x0cModelService\x12\x33\n\x06Upload\x12\x12.fedn.ModelRequest\x1a\x13.fedn.ModelResponse(\x01\x12\x35\n\x08\x44ownload\x12\x12.fedn.ModelRequest\x1a\x13.fedn.ModelResponse0\x01\x32\xbb\x02\n\x07\x43ontrol\x12\x34\n\x05Start\x12\x14.fedn.ControlRequest\x1a\x15.fedn.ControlResponse\x12\x33\n\x04Stop\x12\x14.fedn.ControlRequest\x1a\x15.fedn.ControlResponse\x12\x44\n\x15\x46lushAggregationQueue\x12\x14.fedn.ControlRequest\x1a\x15.fedn.ControlResponse\x12<\n\rSetAggregator\x12\x14.fedn.ControlRequest\x1a\x15.fedn.ControlResponse\x12\x41\n\x12SetServerFunctions\x12\x14.fedn.ControlRequest\x1a\x15.fedn.ControlResponse2V\n\x07Reducer\x12K\n\x0eGetGlobalModel\x12\x1b.fedn.GetGlobalModelRequest\x1a\x1c.fedn.GetGlobalModelResponse2\xab\x03\n\tConnector\x12\x44\n\x14\x41llianceStatusStream\x12\x1c.fedn.ClientAvailableMessage\x1a\x0c.fedn.Status0\x01\x12*\n\nSendStatus\x12\x0c.fedn.Status\x1a\x0e.fedn.Response\x12?\n\x11ListActiveClients\x12\x18.fedn.ListClientsRequest\x1a\x10.fedn.ClientList\x12\x45\n\x10\x41\x63\x63\x65ptingClients\x12\x17.fedn.ConnectionRequest\x1a\x18.fedn.ConnectionResponse\x12\x30\n\rSendHeartbeat\x12\x0f.fedn.Heartbeat\x1a\x0e.fedn.Response\x12\x37\n\x0eReassignClient\x12\x15.fedn.ReassignRequest\x1a\x0e.fedn.Response\x12\x39\n\x0fReconnectClient\x12\x16.fedn.ReconnectRequest\x1a\x0e.fedn.Response2\xfd\x01\n\x08\x43ombiner\x12?\n\nTaskStream\x12\x1c.fedn.ClientAvailableMessage\x1a\x11.fedn.TaskRequest0\x01\x12\x34\n\x0fSendModelUpdate\x12\x11.fedn.ModelUpdate\x1a\x0e.fedn.Response\x12<\n\x13SendModelValidation\x12\x15.fedn.ModelValidation\x1a\x0e.fedn.Response\x12<\n\x13SendModelPrediction\x12\x15.fedn.ModelPrediction\x1a\x0e.fedn.Response2\xec\x03\n\x0f\x46unctionService\x12Z\n\x17HandleProvidedFunctions\x12\x1e.fedn.ProvidedFunctionsRequest\x1a\x1f.fedn.ProvidedFunctionsResponse\x12M\n\x12HandleClientConfig\x12\x19.fedn.ClientConfigRequest\x1a\x1a.fedn.ClientConfigResponse(\x01\x12T\n\x15HandleClientSelection\x12\x1c.fedn.ClientSelectionRequest\x1a\x1d.fedn.ClientSelectionResponse\x12\x43\n\x0eHandleMetadata\x12\x17.fedn.ClientMetaRequest\x1a\x18.fedn.ClientMetaResponse\x12G\n\x10HandleStoreModel\x12\x17.fedn.StoreModelRequest\x1a\x18.fedn.StoreModelResponse(\x01\x12J\n\x11HandleAggregation\x12\x18.fedn.AggregationRequest\x1a\x19.fedn.AggregationResponse0\x01\x62\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) -_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'network.grpc.fedn_pb2', _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'fedn.network.grpc.fedn_pb2', _globals) if 
_descriptor._USE_C_DESCRIPTORS == False: DESCRIPTOR._options = None - _globals['_STATUSTYPE']._serialized_start=2549 - _globals['_STATUSTYPE']._serialized_end=2688 - _globals['_QUEUE']._serialized_start=2690 - _globals['_QUEUE']._serialized_end=2726 - _globals['_MODELSTATUS']._serialized_start=2728 - _globals['_MODELSTATUS']._serialized_end=2811 - _globals['_ROLE']._serialized_start=2813 - _globals['_ROLE']._serialized_end=2869 - _globals['_COMMAND']._serialized_start=2871 - _globals['_COMMAND']._serialized_end=2945 - _globals['_CONNECTIONSTATUS']._serialized_start=2947 - _globals['_CONNECTIONSTATUS']._serialized_end=3020 - _globals['_RESPONSE']._serialized_start=66 - _globals['_RESPONSE']._serialized_end=124 - _globals['_STATUS']._serialized_start=127 - _globals['_STATUS']._serialized_end=443 - _globals['_STATUS_LOGLEVEL']._serialized_start=377 - _globals['_STATUS_LOGLEVEL']._serialized_end=443 - _globals['_TASKREQUEST']._serialized_start=446 - _globals['_TASKREQUEST']._serialized_end=662 - _globals['_MODELUPDATE']._serialized_start=665 - _globals['_MODELUPDATE']._serialized_end=856 - _globals['_MODELVALIDATION']._serialized_start=859 - _globals['_MODELVALIDATION']._serialized_end=1075 - _globals['_MODELPREDICTION']._serialized_start=1078 - _globals['_MODELPREDICTION']._serialized_end=1297 - _globals['_MODELREQUEST']._serialized_start=1300 - _globals['_MODELREQUEST']._serialized_end=1437 - _globals['_MODELRESPONSE']._serialized_start=1439 - _globals['_MODELRESPONSE']._serialized_end=1532 - _globals['_GETGLOBALMODELREQUEST']._serialized_start=1534 - _globals['_GETGLOBALMODELREQUEST']._serialized_end=1619 - _globals['_GETGLOBALMODELRESPONSE']._serialized_start=1621 - _globals['_GETGLOBALMODELRESPONSE']._serialized_end=1725 - _globals['_HEARTBEAT']._serialized_start=1727 - _globals['_HEARTBEAT']._serialized_end=1768 - _globals['_CLIENTAVAILABLEMESSAGE']._serialized_start=1770 - _globals['_CLIENTAVAILABLEMESSAGE']._serialized_end=1857 - _globals['_LISTCLIENTSREQUEST']._serialized_start=1859 - _globals['_LISTCLIENTSREQUEST']._serialized_end=1939 - _globals['_CLIENTLIST']._serialized_start=1941 - _globals['_CLIENTLIST']._serialized_end=1983 - _globals['_CLIENT']._serialized_start=1985 - _globals['_CLIENT']._serialized_end=2052 - _globals['_REASSIGNREQUEST']._serialized_start=2054 - _globals['_REASSIGNREQUEST']._serialized_end=2163 - _globals['_RECONNECTREQUEST']._serialized_start=2165 - _globals['_RECONNECTREQUEST']._serialized_end=2264 - _globals['_PARAMETER']._serialized_start=2266 - _globals['_PARAMETER']._serialized_end=2305 - _globals['_CONTROLREQUEST']._serialized_start=2307 - _globals['_CONTROLREQUEST']._serialized_end=2391 - _globals['_CONTROLRESPONSE']._serialized_start=2393 - _globals['_CONTROLRESPONSE']._serialized_end=2463 - _globals['_CONNECTIONREQUEST']._serialized_start=2465 - _globals['_CONNECTIONREQUEST']._serialized_end=2484 - _globals['_CONNECTIONRESPONSE']._serialized_start=2486 - _globals['_CONNECTIONRESPONSE']._serialized_end=2546 - _globals['_MODELSERVICE']._serialized_start=3022 - _globals['_MODELSERVICE']._serialized_end=3144 - _globals['_CONTROL']._serialized_start=3147 - _globals['_CONTROL']._serialized_end=3395 - _globals['_REDUCER']._serialized_start=3397 - _globals['_REDUCER']._serialized_end=3483 - _globals['_CONNECTOR']._serialized_start=3486 - _globals['_CONNECTOR']._serialized_end=3913 - _globals['_COMBINER']._serialized_start=3916 - _globals['_COMBINER']._serialized_end=4169 + _globals['_PROVIDEDFUNCTIONSRESPONSE_AVAILABLEFUNCTIONSENTRY']._options = None + 
_globals['_PROVIDEDFUNCTIONSRESPONSE_AVAILABLEFUNCTIONSENTRY']._serialized_options = b'8\001' + _globals['_STATUSTYPE']._serialized_start=3218 + _globals['_STATUSTYPE']._serialized_end=3357 + _globals['_QUEUE']._serialized_start=3359 + _globals['_QUEUE']._serialized_end=3395 + _globals['_MODELSTATUS']._serialized_start=3397 + _globals['_MODELSTATUS']._serialized_end=3480 + _globals['_ROLE']._serialized_start=3482 + _globals['_ROLE']._serialized_end=3538 + _globals['_COMMAND']._serialized_start=3540 + _globals['_COMMAND']._serialized_end=3614 + _globals['_CONNECTIONSTATUS']._serialized_start=3616 + _globals['_CONNECTIONSTATUS']._serialized_end=3689 + _globals['_RESPONSE']._serialized_start=71 + _globals['_RESPONSE']._serialized_end=129 + _globals['_STATUS']._serialized_start=132 + _globals['_STATUS']._serialized_end=448 + _globals['_STATUS_LOGLEVEL']._serialized_start=382 + _globals['_STATUS_LOGLEVEL']._serialized_end=448 + _globals['_TASKREQUEST']._serialized_start=451 + _globals['_TASKREQUEST']._serialized_end=667 + _globals['_MODELUPDATE']._serialized_start=670 + _globals['_MODELUPDATE']._serialized_end=861 + _globals['_MODELVALIDATION']._serialized_start=864 + _globals['_MODELVALIDATION']._serialized_end=1080 + _globals['_MODELPREDICTION']._serialized_start=1083 + _globals['_MODELPREDICTION']._serialized_end=1302 + _globals['_MODELREQUEST']._serialized_start=1305 + _globals['_MODELREQUEST']._serialized_end=1442 + _globals['_MODELRESPONSE']._serialized_start=1444 + _globals['_MODELRESPONSE']._serialized_end=1537 + _globals['_GETGLOBALMODELREQUEST']._serialized_start=1539 + _globals['_GETGLOBALMODELREQUEST']._serialized_end=1624 + _globals['_GETGLOBALMODELRESPONSE']._serialized_start=1626 + _globals['_GETGLOBALMODELRESPONSE']._serialized_end=1730 + _globals['_HEARTBEAT']._serialized_start=1732 + _globals['_HEARTBEAT']._serialized_end=1773 + _globals['_CLIENTAVAILABLEMESSAGE']._serialized_start=1775 + _globals['_CLIENTAVAILABLEMESSAGE']._serialized_end=1862 + _globals['_LISTCLIENTSREQUEST']._serialized_start=1864 + _globals['_LISTCLIENTSREQUEST']._serialized_end=1944 + _globals['_CLIENTLIST']._serialized_start=1946 + _globals['_CLIENTLIST']._serialized_end=1988 + _globals['_CLIENT']._serialized_start=1990 + _globals['_CLIENT']._serialized_end=2057 + _globals['_REASSIGNREQUEST']._serialized_start=2059 + _globals['_REASSIGNREQUEST']._serialized_end=2168 + _globals['_RECONNECTREQUEST']._serialized_start=2170 + _globals['_RECONNECTREQUEST']._serialized_end=2269 + _globals['_PARAMETER']._serialized_start=2271 + _globals['_PARAMETER']._serialized_end=2310 + _globals['_CONTROLREQUEST']._serialized_start=2312 + _globals['_CONTROLREQUEST']._serialized_end=2396 + _globals['_CONTROLRESPONSE']._serialized_start=2398 + _globals['_CONTROLRESPONSE']._serialized_end=2468 + _globals['_CONNECTIONREQUEST']._serialized_start=2470 + _globals['_CONNECTIONREQUEST']._serialized_end=2489 + _globals['_CONNECTIONRESPONSE']._serialized_start=2491 + _globals['_CONNECTIONRESPONSE']._serialized_end=2551 + _globals['_PROVIDEDFUNCTIONSREQUEST']._serialized_start=2553 + _globals['_PROVIDEDFUNCTIONSREQUEST']._serialized_end=2602 + _globals['_PROVIDEDFUNCTIONSRESPONSE']._serialized_start=2605 + _globals['_PROVIDEDFUNCTIONSRESPONSE']._serialized_end=2777 + _globals['_PROVIDEDFUNCTIONSRESPONSE_AVAILABLEFUNCTIONSENTRY']._serialized_start=2720 + _globals['_PROVIDEDFUNCTIONSRESPONSE_AVAILABLEFUNCTIONSENTRY']._serialized_end=2777 + _globals['_CLIENTCONFIGREQUEST']._serialized_start=2779 + 
_globals['_CLIENTCONFIGREQUEST']._serialized_end=2814 + _globals['_CLIENTCONFIGRESPONSE']._serialized_start=2816 + _globals['_CLIENTCONFIGRESPONSE']._serialized_end=2863 + _globals['_CLIENTSELECTIONREQUEST']._serialized_start=2865 + _globals['_CLIENTSELECTIONREQUEST']._serialized_end=2909 + _globals['_CLIENTSELECTIONRESPONSE']._serialized_start=2911 + _globals['_CLIENTSELECTIONRESPONSE']._serialized_end=2956 + _globals['_CLIENTMETAREQUEST']._serialized_start=2958 + _globals['_CLIENTMETAREQUEST']._serialized_end=3014 + _globals['_CLIENTMETARESPONSE']._serialized_start=3016 + _globals['_CLIENTMETARESPONSE']._serialized_end=3052 + _globals['_STOREMODELREQUEST']._serialized_start=3054 + _globals['_STOREMODELREQUEST']._serialized_end=3099 + _globals['_STOREMODELRESPONSE']._serialized_start=3101 + _globals['_STOREMODELRESPONSE']._serialized_end=3137 + _globals['_AGGREGATIONREQUEST']._serialized_start=3139 + _globals['_AGGREGATIONREQUEST']._serialized_end=3178 + _globals['_AGGREGATIONRESPONSE']._serialized_start=3180 + _globals['_AGGREGATIONRESPONSE']._serialized_end=3215 + _globals['_MODELSERVICE']._serialized_start=3691 + _globals['_MODELSERVICE']._serialized_end=3813 + _globals['_CONTROL']._serialized_start=3816 + _globals['_CONTROL']._serialized_end=4131 + _globals['_REDUCER']._serialized_start=4133 + _globals['_REDUCER']._serialized_end=4219 + _globals['_CONNECTOR']._serialized_start=4222 + _globals['_CONNECTOR']._serialized_end=4649 + _globals['_COMBINER']._serialized_start=4652 + _globals['_COMBINER']._serialized_end=4905 + _globals['_FUNCTIONSERVICE']._serialized_start=4908 + _globals['_FUNCTIONSERVICE']._serialized_end=5400 # @@protoc_insertion_point(module_scope) diff --git a/fedn/network/grpc/fedn_pb2_grpc.py b/fedn/network/grpc/fedn_pb2_grpc.py index 32ac134d7..b67e7e095 100644 --- a/fedn/network/grpc/fedn_pb2_grpc.py +++ b/fedn/network/grpc/fedn_pb2_grpc.py @@ -2,7 +2,7 @@ """Client and server classes corresponding to protobuf-defined services.""" import grpc -from ..grpc import fedn_pb2 as network_dot_grpc_dot_fedn__pb2 +from fedn.network.grpc import fedn_pb2 as fedn_dot_network_dot_grpc_dot_fedn__pb2 class ModelServiceStub(object): @@ -16,13 +16,13 @@ def __init__(self, channel): """ self.Upload = channel.stream_unary( '/fedn.ModelService/Upload', - request_serializer=network_dot_grpc_dot_fedn__pb2.ModelRequest.SerializeToString, - response_deserializer=network_dot_grpc_dot_fedn__pb2.ModelResponse.FromString, + request_serializer=fedn_dot_network_dot_grpc_dot_fedn__pb2.ModelRequest.SerializeToString, + response_deserializer=fedn_dot_network_dot_grpc_dot_fedn__pb2.ModelResponse.FromString, ) self.Download = channel.unary_stream( '/fedn.ModelService/Download', - request_serializer=network_dot_grpc_dot_fedn__pb2.ModelRequest.SerializeToString, - response_deserializer=network_dot_grpc_dot_fedn__pb2.ModelResponse.FromString, + request_serializer=fedn_dot_network_dot_grpc_dot_fedn__pb2.ModelRequest.SerializeToString, + response_deserializer=fedn_dot_network_dot_grpc_dot_fedn__pb2.ModelResponse.FromString, ) @@ -46,13 +46,13 @@ def add_ModelServiceServicer_to_server(servicer, server): rpc_method_handlers = { 'Upload': grpc.stream_unary_rpc_method_handler( servicer.Upload, - request_deserializer=network_dot_grpc_dot_fedn__pb2.ModelRequest.FromString, - response_serializer=network_dot_grpc_dot_fedn__pb2.ModelResponse.SerializeToString, + request_deserializer=fedn_dot_network_dot_grpc_dot_fedn__pb2.ModelRequest.FromString, + 
response_serializer=fedn_dot_network_dot_grpc_dot_fedn__pb2.ModelResponse.SerializeToString,
             ),
             'Download': grpc.unary_stream_rpc_method_handler(
                     servicer.Download,
-                    request_deserializer=network_dot_grpc_dot_fedn__pb2.ModelRequest.FromString,
-                    response_serializer=network_dot_grpc_dot_fedn__pb2.ModelResponse.SerializeToString,
+                    request_deserializer=fedn_dot_network_dot_grpc_dot_fedn__pb2.ModelRequest.FromString,
+                    response_serializer=fedn_dot_network_dot_grpc_dot_fedn__pb2.ModelResponse.SerializeToString,
             ),
     }
     generic_handler = grpc.method_handlers_generic_handler(
@@ -76,8 +76,8 @@ def Upload(request_iterator,
             timeout=None,
             metadata=None):
         return grpc.experimental.stream_unary(request_iterator, target, '/fedn.ModelService/Upload',
-            network_dot_grpc_dot_fedn__pb2.ModelRequest.SerializeToString,
-            network_dot_grpc_dot_fedn__pb2.ModelResponse.FromString,
+            fedn_dot_network_dot_grpc_dot_fedn__pb2.ModelRequest.SerializeToString,
+            fedn_dot_network_dot_grpc_dot_fedn__pb2.ModelResponse.FromString,
             options, channel_credentials,
             insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
 
@@ -93,8 +93,8 @@ def Download(request,
             timeout=None,
             metadata=None):
         return grpc.experimental.unary_stream(request, target, '/fedn.ModelService/Download',
-            network_dot_grpc_dot_fedn__pb2.ModelRequest.SerializeToString,
-            network_dot_grpc_dot_fedn__pb2.ModelResponse.FromString,
+            fedn_dot_network_dot_grpc_dot_fedn__pb2.ModelRequest.SerializeToString,
+            fedn_dot_network_dot_grpc_dot_fedn__pb2.ModelResponse.FromString,
             options, channel_credentials,
             insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
 
@@ -110,23 +110,28 @@ def __init__(self, channel):
         """
         self.Start = channel.unary_unary(
                 '/fedn.Control/Start',
-                request_serializer=network_dot_grpc_dot_fedn__pb2.ControlRequest.SerializeToString,
-                response_deserializer=network_dot_grpc_dot_fedn__pb2.ControlResponse.FromString,
+                request_serializer=fedn_dot_network_dot_grpc_dot_fedn__pb2.ControlRequest.SerializeToString,
+                response_deserializer=fedn_dot_network_dot_grpc_dot_fedn__pb2.ControlResponse.FromString,
                 )
         self.Stop = channel.unary_unary(
                 '/fedn.Control/Stop',
-                request_serializer=network_dot_grpc_dot_fedn__pb2.ControlRequest.SerializeToString,
-                response_deserializer=network_dot_grpc_dot_fedn__pb2.ControlResponse.FromString,
+                request_serializer=fedn_dot_network_dot_grpc_dot_fedn__pb2.ControlRequest.SerializeToString,
+                response_deserializer=fedn_dot_network_dot_grpc_dot_fedn__pb2.ControlResponse.FromString,
                 )
         self.FlushAggregationQueue = channel.unary_unary(
                 '/fedn.Control/FlushAggregationQueue',
-                request_serializer=network_dot_grpc_dot_fedn__pb2.ControlRequest.SerializeToString,
-                response_deserializer=network_dot_grpc_dot_fedn__pb2.ControlResponse.FromString,
+                request_serializer=fedn_dot_network_dot_grpc_dot_fedn__pb2.ControlRequest.SerializeToString,
+                response_deserializer=fedn_dot_network_dot_grpc_dot_fedn__pb2.ControlResponse.FromString,
                 )
         self.SetAggregator = channel.unary_unary(
                 '/fedn.Control/SetAggregator',
-                request_serializer=network_dot_grpc_dot_fedn__pb2.ControlRequest.SerializeToString,
-                response_deserializer=network_dot_grpc_dot_fedn__pb2.ControlResponse.FromString,
+                request_serializer=fedn_dot_network_dot_grpc_dot_fedn__pb2.ControlRequest.SerializeToString,
+                response_deserializer=fedn_dot_network_dot_grpc_dot_fedn__pb2.ControlResponse.FromString,
+                )
+        self.SetServerFunctions = channel.unary_unary(
+                '/fedn.Control/SetServerFunctions',
+                request_serializer=fedn_dot_network_dot_grpc_dot_fedn__pb2.ControlRequest.SerializeToString,
+                response_deserializer=fedn_dot_network_dot_grpc_dot_fedn__pb2.ControlResponse.FromString,
                 )
 
 
@@ -157,28 +162,39 @@ def SetAggregator(self, request, context):
         context.set_details('Method not implemented!')
         raise NotImplementedError('Method not implemented!')
 
+    def SetServerFunctions(self, request, context):
+        """Missing associated documentation comment in .proto file."""
+        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+        context.set_details('Method not implemented!')
+        raise NotImplementedError('Method not implemented!')
+
 
 def add_ControlServicer_to_server(servicer, server):
     rpc_method_handlers = {
             'Start': grpc.unary_unary_rpc_method_handler(
                     servicer.Start,
-                    request_deserializer=network_dot_grpc_dot_fedn__pb2.ControlRequest.FromString,
-                    response_serializer=network_dot_grpc_dot_fedn__pb2.ControlResponse.SerializeToString,
+                    request_deserializer=fedn_dot_network_dot_grpc_dot_fedn__pb2.ControlRequest.FromString,
+                    response_serializer=fedn_dot_network_dot_grpc_dot_fedn__pb2.ControlResponse.SerializeToString,
             ),
             'Stop': grpc.unary_unary_rpc_method_handler(
                     servicer.Stop,
-                    request_deserializer=network_dot_grpc_dot_fedn__pb2.ControlRequest.FromString,
-                    response_serializer=network_dot_grpc_dot_fedn__pb2.ControlResponse.SerializeToString,
+                    request_deserializer=fedn_dot_network_dot_grpc_dot_fedn__pb2.ControlRequest.FromString,
+                    response_serializer=fedn_dot_network_dot_grpc_dot_fedn__pb2.ControlResponse.SerializeToString,
             ),
             'FlushAggregationQueue': grpc.unary_unary_rpc_method_handler(
                     servicer.FlushAggregationQueue,
-                    request_deserializer=network_dot_grpc_dot_fedn__pb2.ControlRequest.FromString,
-                    response_serializer=network_dot_grpc_dot_fedn__pb2.ControlResponse.SerializeToString,
+                    request_deserializer=fedn_dot_network_dot_grpc_dot_fedn__pb2.ControlRequest.FromString,
+                    response_serializer=fedn_dot_network_dot_grpc_dot_fedn__pb2.ControlResponse.SerializeToString,
             ),
             'SetAggregator': grpc.unary_unary_rpc_method_handler(
                     servicer.SetAggregator,
-                    request_deserializer=network_dot_grpc_dot_fedn__pb2.ControlRequest.FromString,
-                    response_serializer=network_dot_grpc_dot_fedn__pb2.ControlResponse.SerializeToString,
+                    request_deserializer=fedn_dot_network_dot_grpc_dot_fedn__pb2.ControlRequest.FromString,
+                    response_serializer=fedn_dot_network_dot_grpc_dot_fedn__pb2.ControlResponse.SerializeToString,
+            ),
+            'SetServerFunctions': grpc.unary_unary_rpc_method_handler(
+                    servicer.SetServerFunctions,
+                    request_deserializer=fedn_dot_network_dot_grpc_dot_fedn__pb2.ControlRequest.FromString,
+                    response_serializer=fedn_dot_network_dot_grpc_dot_fedn__pb2.ControlResponse.SerializeToString,
             ),
     }
     generic_handler = grpc.method_handlers_generic_handler(
@@ -202,8 +218,8 @@ def Start(request,
             timeout=None,
             metadata=None):
         return grpc.experimental.unary_unary(request, target, '/fedn.Control/Start',
-            network_dot_grpc_dot_fedn__pb2.ControlRequest.SerializeToString,
-            network_dot_grpc_dot_fedn__pb2.ControlResponse.FromString,
+            fedn_dot_network_dot_grpc_dot_fedn__pb2.ControlRequest.SerializeToString,
+            fedn_dot_network_dot_grpc_dot_fedn__pb2.ControlResponse.FromString,
             options, channel_credentials,
             insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
 
@@ -219,8 +235,8 @@ def Stop(request,
             timeout=None,
             metadata=None):
         return grpc.experimental.unary_unary(request, target, '/fedn.Control/Stop',
-            network_dot_grpc_dot_fedn__pb2.ControlRequest.SerializeToString,
-            network_dot_grpc_dot_fedn__pb2.ControlResponse.FromString,
+            fedn_dot_network_dot_grpc_dot_fedn__pb2.ControlRequest.SerializeToString,
+            fedn_dot_network_dot_grpc_dot_fedn__pb2.ControlResponse.FromString,
             options, channel_credentials,
             insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
 
@@ -236,8 +252,8 @@ def FlushAggregationQueue(request,
             timeout=None,
             metadata=None):
         return grpc.experimental.unary_unary(request, target, '/fedn.Control/FlushAggregationQueue',
-            network_dot_grpc_dot_fedn__pb2.ControlRequest.SerializeToString,
-            network_dot_grpc_dot_fedn__pb2.ControlResponse.FromString,
+            fedn_dot_network_dot_grpc_dot_fedn__pb2.ControlRequest.SerializeToString,
+            fedn_dot_network_dot_grpc_dot_fedn__pb2.ControlResponse.FromString,
             options, channel_credentials,
             insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
 
@@ -253,8 +269,25 @@ def SetAggregator(request,
             timeout=None,
             metadata=None):
         return grpc.experimental.unary_unary(request, target, '/fedn.Control/SetAggregator',
-            network_dot_grpc_dot_fedn__pb2.ControlRequest.SerializeToString,
-            network_dot_grpc_dot_fedn__pb2.ControlResponse.FromString,
+            fedn_dot_network_dot_grpc_dot_fedn__pb2.ControlRequest.SerializeToString,
+            fedn_dot_network_dot_grpc_dot_fedn__pb2.ControlResponse.FromString,
+            options, channel_credentials,
+            insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
+
+    @staticmethod
+    def SetServerFunctions(request,
+            target,
+            options=(),
+            channel_credentials=None,
+            call_credentials=None,
+            insecure=False,
+            compression=None,
+            wait_for_ready=None,
+            timeout=None,
+            metadata=None):
+        return grpc.experimental.unary_unary(request, target, '/fedn.Control/SetServerFunctions',
+            fedn_dot_network_dot_grpc_dot_fedn__pb2.ControlRequest.SerializeToString,
+            fedn_dot_network_dot_grpc_dot_fedn__pb2.ControlResponse.FromString,
             options, channel_credentials,
             insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
 
@@ -270,8 +303,8 @@ def __init__(self, channel):
         """
         self.GetGlobalModel = channel.unary_unary(
                 '/fedn.Reducer/GetGlobalModel',
-                request_serializer=network_dot_grpc_dot_fedn__pb2.GetGlobalModelRequest.SerializeToString,
-                response_deserializer=network_dot_grpc_dot_fedn__pb2.GetGlobalModelResponse.FromString,
+                request_serializer=fedn_dot_network_dot_grpc_dot_fedn__pb2.GetGlobalModelRequest.SerializeToString,
+                response_deserializer=fedn_dot_network_dot_grpc_dot_fedn__pb2.GetGlobalModelResponse.FromString,
                 )
 
 
@@ -289,8 +322,8 @@ def add_ReducerServicer_to_server(servicer, server):
     rpc_method_handlers = {
             'GetGlobalModel': grpc.unary_unary_rpc_method_handler(
                     servicer.GetGlobalModel,
-                    request_deserializer=network_dot_grpc_dot_fedn__pb2.GetGlobalModelRequest.FromString,
-                    response_serializer=network_dot_grpc_dot_fedn__pb2.GetGlobalModelResponse.SerializeToString,
+                    request_deserializer=fedn_dot_network_dot_grpc_dot_fedn__pb2.GetGlobalModelRequest.FromString,
+                    response_serializer=fedn_dot_network_dot_grpc_dot_fedn__pb2.GetGlobalModelResponse.SerializeToString,
             ),
     }
     generic_handler = grpc.method_handlers_generic_handler(
@@ -314,8 +347,8 @@ def GetGlobalModel(request,
             timeout=None,
             metadata=None):
         return grpc.experimental.unary_unary(request, target, '/fedn.Reducer/GetGlobalModel',
-            network_dot_grpc_dot_fedn__pb2.GetGlobalModelRequest.SerializeToString,
-            network_dot_grpc_dot_fedn__pb2.GetGlobalModelResponse.FromString,
+            fedn_dot_network_dot_grpc_dot_fedn__pb2.GetGlobalModelRequest.SerializeToString,
+            fedn_dot_network_dot_grpc_dot_fedn__pb2.GetGlobalModelResponse.FromString,
             options, channel_credentials,
             insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
 
@@ -331,38 +364,38 @@ def __init__(self, channel):
         """
         self.AllianceStatusStream = channel.unary_stream(
                 '/fedn.Connector/AllianceStatusStream',
-                request_serializer=network_dot_grpc_dot_fedn__pb2.ClientAvailableMessage.SerializeToString,
-                response_deserializer=network_dot_grpc_dot_fedn__pb2.Status.FromString,
+                request_serializer=fedn_dot_network_dot_grpc_dot_fedn__pb2.ClientAvailableMessage.SerializeToString,
+                response_deserializer=fedn_dot_network_dot_grpc_dot_fedn__pb2.Status.FromString,
                 )
         self.SendStatus = channel.unary_unary(
                 '/fedn.Connector/SendStatus',
-                request_serializer=network_dot_grpc_dot_fedn__pb2.Status.SerializeToString,
-                response_deserializer=network_dot_grpc_dot_fedn__pb2.Response.FromString,
+                request_serializer=fedn_dot_network_dot_grpc_dot_fedn__pb2.Status.SerializeToString,
+                response_deserializer=fedn_dot_network_dot_grpc_dot_fedn__pb2.Response.FromString,
                 )
         self.ListActiveClients = channel.unary_unary(
                 '/fedn.Connector/ListActiveClients',
-                request_serializer=network_dot_grpc_dot_fedn__pb2.ListClientsRequest.SerializeToString,
-                response_deserializer=network_dot_grpc_dot_fedn__pb2.ClientList.FromString,
+                request_serializer=fedn_dot_network_dot_grpc_dot_fedn__pb2.ListClientsRequest.SerializeToString,
+                response_deserializer=fedn_dot_network_dot_grpc_dot_fedn__pb2.ClientList.FromString,
                 )
         self.AcceptingClients = channel.unary_unary(
                 '/fedn.Connector/AcceptingClients',
-                request_serializer=network_dot_grpc_dot_fedn__pb2.ConnectionRequest.SerializeToString,
-                response_deserializer=network_dot_grpc_dot_fedn__pb2.ConnectionResponse.FromString,
+                request_serializer=fedn_dot_network_dot_grpc_dot_fedn__pb2.ConnectionRequest.SerializeToString,
+                response_deserializer=fedn_dot_network_dot_grpc_dot_fedn__pb2.ConnectionResponse.FromString,
                 )
         self.SendHeartbeat = channel.unary_unary(
                 '/fedn.Connector/SendHeartbeat',
-                request_serializer=network_dot_grpc_dot_fedn__pb2.Heartbeat.SerializeToString,
-                response_deserializer=network_dot_grpc_dot_fedn__pb2.Response.FromString,
+                request_serializer=fedn_dot_network_dot_grpc_dot_fedn__pb2.Heartbeat.SerializeToString,
+                response_deserializer=fedn_dot_network_dot_grpc_dot_fedn__pb2.Response.FromString,
                 )
         self.ReassignClient = channel.unary_unary(
                 '/fedn.Connector/ReassignClient',
-                request_serializer=network_dot_grpc_dot_fedn__pb2.ReassignRequest.SerializeToString,
-                response_deserializer=network_dot_grpc_dot_fedn__pb2.Response.FromString,
+                request_serializer=fedn_dot_network_dot_grpc_dot_fedn__pb2.ReassignRequest.SerializeToString,
+                response_deserializer=fedn_dot_network_dot_grpc_dot_fedn__pb2.Response.FromString,
                 )
         self.ReconnectClient = channel.unary_unary(
                 '/fedn.Connector/ReconnectClient',
-                request_serializer=network_dot_grpc_dot_fedn__pb2.ReconnectRequest.SerializeToString,
-                response_deserializer=network_dot_grpc_dot_fedn__pb2.Response.FromString,
+                request_serializer=fedn_dot_network_dot_grpc_dot_fedn__pb2.ReconnectRequest.SerializeToString,
+                response_deserializer=fedn_dot_network_dot_grpc_dot_fedn__pb2.Response.FromString,
                 )
 
 
@@ -421,38 +454,38 @@ def add_ConnectorServicer_to_server(servicer, server):
     rpc_method_handlers = {
             'AllianceStatusStream': grpc.unary_stream_rpc_method_handler(
                     servicer.AllianceStatusStream,
-                    request_deserializer=network_dot_grpc_dot_fedn__pb2.ClientAvailableMessage.FromString,
-                    response_serializer=network_dot_grpc_dot_fedn__pb2.Status.SerializeToString,
+                    request_deserializer=fedn_dot_network_dot_grpc_dot_fedn__pb2.ClientAvailableMessage.FromString,
+                    response_serializer=fedn_dot_network_dot_grpc_dot_fedn__pb2.Status.SerializeToString,
             ),
             'SendStatus': grpc.unary_unary_rpc_method_handler(
                     servicer.SendStatus,
-                    request_deserializer=network_dot_grpc_dot_fedn__pb2.Status.FromString,
-                    response_serializer=network_dot_grpc_dot_fedn__pb2.Response.SerializeToString,
+                    request_deserializer=fedn_dot_network_dot_grpc_dot_fedn__pb2.Status.FromString,
+                    response_serializer=fedn_dot_network_dot_grpc_dot_fedn__pb2.Response.SerializeToString,
            ),
            'ListActiveClients': grpc.unary_unary_rpc_method_handler(
                     servicer.ListActiveClients,
-                    request_deserializer=network_dot_grpc_dot_fedn__pb2.ListClientsRequest.FromString,
-                    response_serializer=network_dot_grpc_dot_fedn__pb2.ClientList.SerializeToString,
+                    request_deserializer=fedn_dot_network_dot_grpc_dot_fedn__pb2.ListClientsRequest.FromString,
+                    response_serializer=fedn_dot_network_dot_grpc_dot_fedn__pb2.ClientList.SerializeToString,
            ),
            'AcceptingClients': grpc.unary_unary_rpc_method_handler(
                     servicer.AcceptingClients,
-                    request_deserializer=network_dot_grpc_dot_fedn__pb2.ConnectionRequest.FromString,
-                    response_serializer=network_dot_grpc_dot_fedn__pb2.ConnectionResponse.SerializeToString,
+                    request_deserializer=fedn_dot_network_dot_grpc_dot_fedn__pb2.ConnectionRequest.FromString,
+                    response_serializer=fedn_dot_network_dot_grpc_dot_fedn__pb2.ConnectionResponse.SerializeToString,
            ),
            'SendHeartbeat': grpc.unary_unary_rpc_method_handler(
                     servicer.SendHeartbeat,
-                    request_deserializer=network_dot_grpc_dot_fedn__pb2.Heartbeat.FromString,
-                    response_serializer=network_dot_grpc_dot_fedn__pb2.Response.SerializeToString,
+                    request_deserializer=fedn_dot_network_dot_grpc_dot_fedn__pb2.Heartbeat.FromString,
+                    response_serializer=fedn_dot_network_dot_grpc_dot_fedn__pb2.Response.SerializeToString,
            ),
            'ReassignClient': grpc.unary_unary_rpc_method_handler(
                     servicer.ReassignClient,
-                    request_deserializer=network_dot_grpc_dot_fedn__pb2.ReassignRequest.FromString,
-                    response_serializer=network_dot_grpc_dot_fedn__pb2.Response.SerializeToString,
+                    request_deserializer=fedn_dot_network_dot_grpc_dot_fedn__pb2.ReassignRequest.FromString,
+                    response_serializer=fedn_dot_network_dot_grpc_dot_fedn__pb2.Response.SerializeToString,
            ),
            'ReconnectClient': grpc.unary_unary_rpc_method_handler(
                     servicer.ReconnectClient,
-                    request_deserializer=network_dot_grpc_dot_fedn__pb2.ReconnectRequest.FromString,
-                    response_serializer=network_dot_grpc_dot_fedn__pb2.Response.SerializeToString,
+                    request_deserializer=fedn_dot_network_dot_grpc_dot_fedn__pb2.ReconnectRequest.FromString,
+                    response_serializer=fedn_dot_network_dot_grpc_dot_fedn__pb2.Response.SerializeToString,
            ),
     }
     generic_handler = grpc.method_handlers_generic_handler(
@@ -476,8 +509,8 @@ def AllianceStatusStream(request,
             timeout=None,
             metadata=None):
         return grpc.experimental.unary_stream(request, target, '/fedn.Connector/AllianceStatusStream',
-            network_dot_grpc_dot_fedn__pb2.ClientAvailableMessage.SerializeToString,
-            network_dot_grpc_dot_fedn__pb2.Status.FromString,
+            fedn_dot_network_dot_grpc_dot_fedn__pb2.ClientAvailableMessage.SerializeToString,
+            fedn_dot_network_dot_grpc_dot_fedn__pb2.Status.FromString,
             options, channel_credentials,
             insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
 
@@ -493,8 +526,8 @@ def SendStatus(request,
             timeout=None,
             metadata=None):
         return grpc.experimental.unary_unary(request, target, '/fedn.Connector/SendStatus',
-            network_dot_grpc_dot_fedn__pb2.Status.SerializeToString,
-            network_dot_grpc_dot_fedn__pb2.Response.FromString,
+            fedn_dot_network_dot_grpc_dot_fedn__pb2.Status.SerializeToString,
+            fedn_dot_network_dot_grpc_dot_fedn__pb2.Response.FromString,
             options, channel_credentials,
             insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
 
@@ -510,8 +543,8 @@ def ListActiveClients(request,
             timeout=None,
             metadata=None):
         return grpc.experimental.unary_unary(request, target, '/fedn.Connector/ListActiveClients',
-            network_dot_grpc_dot_fedn__pb2.ListClientsRequest.SerializeToString,
-            network_dot_grpc_dot_fedn__pb2.ClientList.FromString,
+            fedn_dot_network_dot_grpc_dot_fedn__pb2.ListClientsRequest.SerializeToString,
+            fedn_dot_network_dot_grpc_dot_fedn__pb2.ClientList.FromString,
             options, channel_credentials,
             insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
 
@@ -527,8 +560,8 @@ def AcceptingClients(request,
             timeout=None,
             metadata=None):
         return grpc.experimental.unary_unary(request, target, '/fedn.Connector/AcceptingClients',
-            network_dot_grpc_dot_fedn__pb2.ConnectionRequest.SerializeToString,
-            network_dot_grpc_dot_fedn__pb2.ConnectionResponse.FromString,
+            fedn_dot_network_dot_grpc_dot_fedn__pb2.ConnectionRequest.SerializeToString,
+            fedn_dot_network_dot_grpc_dot_fedn__pb2.ConnectionResponse.FromString,
             options, channel_credentials,
             insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
 
@@ -544,8 +577,8 @@ def SendHeartbeat(request,
             timeout=None,
             metadata=None):
         return grpc.experimental.unary_unary(request, target, '/fedn.Connector/SendHeartbeat',
-            network_dot_grpc_dot_fedn__pb2.Heartbeat.SerializeToString,
-            network_dot_grpc_dot_fedn__pb2.Response.FromString,
+            fedn_dot_network_dot_grpc_dot_fedn__pb2.Heartbeat.SerializeToString,
+            fedn_dot_network_dot_grpc_dot_fedn__pb2.Response.FromString,
             options, channel_credentials,
             insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
 
@@ -561,8 +594,8 @@ def ReassignClient(request,
             timeout=None,
             metadata=None):
         return grpc.experimental.unary_unary(request, target, '/fedn.Connector/ReassignClient',
-            network_dot_grpc_dot_fedn__pb2.ReassignRequest.SerializeToString,
-            network_dot_grpc_dot_fedn__pb2.Response.FromString,
+            fedn_dot_network_dot_grpc_dot_fedn__pb2.ReassignRequest.SerializeToString,
+            fedn_dot_network_dot_grpc_dot_fedn__pb2.Response.FromString,
             options, channel_credentials,
             insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
 
@@ -578,8 +611,8 @@ def ReconnectClient(request,
             timeout=None,
             metadata=None):
         return grpc.experimental.unary_unary(request, target, '/fedn.Connector/ReconnectClient',
-            network_dot_grpc_dot_fedn__pb2.ReconnectRequest.SerializeToString,
-            network_dot_grpc_dot_fedn__pb2.Response.FromString,
+            fedn_dot_network_dot_grpc_dot_fedn__pb2.ReconnectRequest.SerializeToString,
+            fedn_dot_network_dot_grpc_dot_fedn__pb2.Response.FromString,
             options, channel_credentials,
             insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
 
@@ -595,23 +628,23 @@ def __init__(self, channel):
         """
         self.TaskStream = channel.unary_stream(
                 '/fedn.Combiner/TaskStream',
-                request_serializer=network_dot_grpc_dot_fedn__pb2.ClientAvailableMessage.SerializeToString,
-                response_deserializer=network_dot_grpc_dot_fedn__pb2.TaskRequest.FromString,
+                request_serializer=fedn_dot_network_dot_grpc_dot_fedn__pb2.ClientAvailableMessage.SerializeToString,
+                response_deserializer=fedn_dot_network_dot_grpc_dot_fedn__pb2.TaskRequest.FromString,
                 )
         self.SendModelUpdate = channel.unary_unary(
                 '/fedn.Combiner/SendModelUpdate',
-                request_serializer=network_dot_grpc_dot_fedn__pb2.ModelUpdate.SerializeToString,
-                response_deserializer=network_dot_grpc_dot_fedn__pb2.Response.FromString,
+                request_serializer=fedn_dot_network_dot_grpc_dot_fedn__pb2.ModelUpdate.SerializeToString,
+                response_deserializer=fedn_dot_network_dot_grpc_dot_fedn__pb2.Response.FromString,
                 )
         self.SendModelValidation = channel.unary_unary(
                 '/fedn.Combiner/SendModelValidation',
-                request_serializer=network_dot_grpc_dot_fedn__pb2.ModelValidation.SerializeToString,
-                response_deserializer=network_dot_grpc_dot_fedn__pb2.Response.FromString,
+                request_serializer=fedn_dot_network_dot_grpc_dot_fedn__pb2.ModelValidation.SerializeToString,
+                response_deserializer=fedn_dot_network_dot_grpc_dot_fedn__pb2.Response.FromString,
                 )
         self.SendModelPrediction = channel.unary_unary(
                 '/fedn.Combiner/SendModelPrediction',
-                request_serializer=network_dot_grpc_dot_fedn__pb2.ModelPrediction.SerializeToString,
-                response_deserializer=network_dot_grpc_dot_fedn__pb2.Response.FromString,
+                request_serializer=fedn_dot_network_dot_grpc_dot_fedn__pb2.ModelPrediction.SerializeToString,
+                response_deserializer=fedn_dot_network_dot_grpc_dot_fedn__pb2.Response.FromString,
                 )
 
 
@@ -648,23 +681,23 @@ def add_CombinerServicer_to_server(servicer, server):
     rpc_method_handlers = {
             'TaskStream': grpc.unary_stream_rpc_method_handler(
                     servicer.TaskStream,
-                    request_deserializer=network_dot_grpc_dot_fedn__pb2.ClientAvailableMessage.FromString,
-                    response_serializer=network_dot_grpc_dot_fedn__pb2.TaskRequest.SerializeToString,
+                    request_deserializer=fedn_dot_network_dot_grpc_dot_fedn__pb2.ClientAvailableMessage.FromString,
+                    response_serializer=fedn_dot_network_dot_grpc_dot_fedn__pb2.TaskRequest.SerializeToString,
             ),
             'SendModelUpdate': grpc.unary_unary_rpc_method_handler(
                     servicer.SendModelUpdate,
-                    request_deserializer=network_dot_grpc_dot_fedn__pb2.ModelUpdate.FromString,
-                    response_serializer=network_dot_grpc_dot_fedn__pb2.Response.SerializeToString,
+                    request_deserializer=fedn_dot_network_dot_grpc_dot_fedn__pb2.ModelUpdate.FromString,
+                    response_serializer=fedn_dot_network_dot_grpc_dot_fedn__pb2.Response.SerializeToString,
             ),
             'SendModelValidation': grpc.unary_unary_rpc_method_handler(
                     servicer.SendModelValidation,
-                    request_deserializer=network_dot_grpc_dot_fedn__pb2.ModelValidation.FromString,
-                    response_serializer=network_dot_grpc_dot_fedn__pb2.Response.SerializeToString,
+                    request_deserializer=fedn_dot_network_dot_grpc_dot_fedn__pb2.ModelValidation.FromString,
+                    response_serializer=fedn_dot_network_dot_grpc_dot_fedn__pb2.Response.SerializeToString,
             ),
             'SendModelPrediction': grpc.unary_unary_rpc_method_handler(
                     servicer.SendModelPrediction,
-                    request_deserializer=network_dot_grpc_dot_fedn__pb2.ModelPrediction.FromString,
-                    response_serializer=network_dot_grpc_dot_fedn__pb2.Response.SerializeToString,
+                    request_deserializer=fedn_dot_network_dot_grpc_dot_fedn__pb2.ModelPrediction.FromString,
+                    response_serializer=fedn_dot_network_dot_grpc_dot_fedn__pb2.Response.SerializeToString,
             ),
     }
     generic_handler = grpc.method_handlers_generic_handler(
@@ -688,8 +721,8 @@ def TaskStream(request,
             timeout=None,
             metadata=None):
         return grpc.experimental.unary_stream(request, target, '/fedn.Combiner/TaskStream',
-            network_dot_grpc_dot_fedn__pb2.ClientAvailableMessage.SerializeToString,
-            network_dot_grpc_dot_fedn__pb2.TaskRequest.FromString,
+            fedn_dot_network_dot_grpc_dot_fedn__pb2.ClientAvailableMessage.SerializeToString,
+            fedn_dot_network_dot_grpc_dot_fedn__pb2.TaskRequest.FromString,
             options, channel_credentials,
             insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
 
@@ -705,8 +738,8 @@ def SendModelUpdate(request,
             timeout=None,
             metadata=None):
         return grpc.experimental.unary_unary(request, target, '/fedn.Combiner/SendModelUpdate',
-            network_dot_grpc_dot_fedn__pb2.ModelUpdate.SerializeToString,
-            network_dot_grpc_dot_fedn__pb2.Response.FromString,
+            fedn_dot_network_dot_grpc_dot_fedn__pb2.ModelUpdate.SerializeToString,
+            fedn_dot_network_dot_grpc_dot_fedn__pb2.Response.FromString,
             options, channel_credentials,
             insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
 
@@ -722,8 +755,8 @@ def SendModelValidation(request,
             timeout=None,
             metadata=None):
         return grpc.experimental.unary_unary(request, target, '/fedn.Combiner/SendModelValidation',
-            network_dot_grpc_dot_fedn__pb2.ModelValidation.SerializeToString,
-            network_dot_grpc_dot_fedn__pb2.Response.FromString,
+            fedn_dot_network_dot_grpc_dot_fedn__pb2.ModelValidation.SerializeToString,
+            fedn_dot_network_dot_grpc_dot_fedn__pb2.Response.FromString,
             options, channel_credentials,
             insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
 
@@ -739,7 +772,233 @@ def SendModelPrediction(request,
             timeout=None,
             metadata=None):
         return grpc.experimental.unary_unary(request, target, '/fedn.Combiner/SendModelPrediction',
-            network_dot_grpc_dot_fedn__pb2.ModelPrediction.SerializeToString,
-            network_dot_grpc_dot_fedn__pb2.Response.FromString,
+            fedn_dot_network_dot_grpc_dot_fedn__pb2.ModelPrediction.SerializeToString,
+            fedn_dot_network_dot_grpc_dot_fedn__pb2.Response.FromString,
+            options, channel_credentials,
+            insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
+
+
+class FunctionServiceStub(object):
+    """Missing associated documentation comment in .proto file."""
+
+    def __init__(self, channel):
+        """Constructor.
+
+        Args:
+            channel: A grpc.Channel.
+        """
+        self.HandleProvidedFunctions = channel.unary_unary(
+                '/fedn.FunctionService/HandleProvidedFunctions',
+                request_serializer=fedn_dot_network_dot_grpc_dot_fedn__pb2.ProvidedFunctionsRequest.SerializeToString,
+                response_deserializer=fedn_dot_network_dot_grpc_dot_fedn__pb2.ProvidedFunctionsResponse.FromString,
+                )
+        self.HandleClientConfig = channel.stream_unary(
+                '/fedn.FunctionService/HandleClientConfig',
+                request_serializer=fedn_dot_network_dot_grpc_dot_fedn__pb2.ClientConfigRequest.SerializeToString,
+                response_deserializer=fedn_dot_network_dot_grpc_dot_fedn__pb2.ClientConfigResponse.FromString,
+                )
+        self.HandleClientSelection = channel.unary_unary(
+                '/fedn.FunctionService/HandleClientSelection',
+                request_serializer=fedn_dot_network_dot_grpc_dot_fedn__pb2.ClientSelectionRequest.SerializeToString,
+                response_deserializer=fedn_dot_network_dot_grpc_dot_fedn__pb2.ClientSelectionResponse.FromString,
+                )
+        self.HandleMetadata = channel.unary_unary(
+                '/fedn.FunctionService/HandleMetadata',
+                request_serializer=fedn_dot_network_dot_grpc_dot_fedn__pb2.ClientMetaRequest.SerializeToString,
+                response_deserializer=fedn_dot_network_dot_grpc_dot_fedn__pb2.ClientMetaResponse.FromString,
+                )
+        self.HandleStoreModel = channel.stream_unary(
+                '/fedn.FunctionService/HandleStoreModel',
+                request_serializer=fedn_dot_network_dot_grpc_dot_fedn__pb2.StoreModelRequest.SerializeToString,
+                response_deserializer=fedn_dot_network_dot_grpc_dot_fedn__pb2.StoreModelResponse.FromString,
+                )
+        self.HandleAggregation = channel.unary_stream(
+                '/fedn.FunctionService/HandleAggregation',
+                request_serializer=fedn_dot_network_dot_grpc_dot_fedn__pb2.AggregationRequest.SerializeToString,
+                response_deserializer=fedn_dot_network_dot_grpc_dot_fedn__pb2.AggregationResponse.FromString,
+                )
+
+
+class FunctionServiceServicer(object):
+    """Missing associated documentation comment in .proto file."""
+
+    def HandleProvidedFunctions(self, request, context):
+        """Missing associated documentation comment in .proto file."""
+        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+        context.set_details('Method not implemented!')
+        raise NotImplementedError('Method not implemented!')
+
+    def HandleClientConfig(self, request_iterator, context):
+        """Missing associated documentation comment in .proto file."""
+        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+        context.set_details('Method not implemented!')
+        raise NotImplementedError('Method not implemented!')
+
+    def HandleClientSelection(self, request, context):
+        """Missing associated documentation comment in .proto file."""
+        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+        context.set_details('Method not implemented!')
+        raise NotImplementedError('Method not implemented!')
+
+    def HandleMetadata(self, request, context):
+        """Missing associated documentation comment in .proto file."""
+        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+        context.set_details('Method not implemented!')
+        raise NotImplementedError('Method not implemented!')
+
+    def HandleStoreModel(self, request_iterator, context):
+        """Missing associated documentation comment in .proto file."""
+        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+        context.set_details('Method not implemented!')
+        raise NotImplementedError('Method not implemented!')
+
+    def HandleAggregation(self, request, context):
+        """Missing associated documentation comment in .proto file."""
+        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+        context.set_details('Method not implemented!')
+        raise NotImplementedError('Method not implemented!')
+
+
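[Editor's note: the generated FunctionServiceServicer above is the extension point for the hooks process: an implementation subclasses it, overrides the Handle* RPCs it supports, and registers itself with a grpc.server via add_FunctionServiceServicer_to_server (added just below). A minimal sketch of that wiring, with the caveat that the module path fedn.network.grpc, the port, and the empty ProvidedFunctionsResponse are illustrative assumptions, not the actual FEDn hooks implementation:

    from concurrent import futures

    import grpc

    from fedn.network.grpc import fedn_pb2, fedn_pb2_grpc  # assumed module path


    class HooksServicer(fedn_pb2_grpc.FunctionServiceServicer):
        """Sketch: overrides one RPC; the others stay UNIMPLEMENTED."""

        def HandleProvidedFunctions(self, request, context):
            # Advertise which server functions this process implements.
            # A default-constructed message is valid protobuf; the real
            # field layout is declared in fedn.proto.
            return fedn_pb2.ProvidedFunctionsResponse()


    def serve(port=12081):  # port is an assumption
        server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
        fedn_pb2_grpc.add_FunctionServiceServicer_to_server(HooksServicer(), server)
        server.add_insecure_port(f"[::]:{port}")
        server.start()
        server.wait_for_termination()
]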
+def add_FunctionServiceServicer_to_server(servicer, server):
+    rpc_method_handlers = {
+            'HandleProvidedFunctions': grpc.unary_unary_rpc_method_handler(
+                    servicer.HandleProvidedFunctions,
+                    request_deserializer=fedn_dot_network_dot_grpc_dot_fedn__pb2.ProvidedFunctionsRequest.FromString,
+                    response_serializer=fedn_dot_network_dot_grpc_dot_fedn__pb2.ProvidedFunctionsResponse.SerializeToString,
+            ),
+            'HandleClientConfig': grpc.stream_unary_rpc_method_handler(
+                    servicer.HandleClientConfig,
+                    request_deserializer=fedn_dot_network_dot_grpc_dot_fedn__pb2.ClientConfigRequest.FromString,
+                    response_serializer=fedn_dot_network_dot_grpc_dot_fedn__pb2.ClientConfigResponse.SerializeToString,
+            ),
+            'HandleClientSelection': grpc.unary_unary_rpc_method_handler(
+                    servicer.HandleClientSelection,
+                    request_deserializer=fedn_dot_network_dot_grpc_dot_fedn__pb2.ClientSelectionRequest.FromString,
+                    response_serializer=fedn_dot_network_dot_grpc_dot_fedn__pb2.ClientSelectionResponse.SerializeToString,
+            ),
+            'HandleMetadata': grpc.unary_unary_rpc_method_handler(
+                    servicer.HandleMetadata,
+                    request_deserializer=fedn_dot_network_dot_grpc_dot_fedn__pb2.ClientMetaRequest.FromString,
+                    response_serializer=fedn_dot_network_dot_grpc_dot_fedn__pb2.ClientMetaResponse.SerializeToString,
+            ),
+            'HandleStoreModel': grpc.stream_unary_rpc_method_handler(
+                    servicer.HandleStoreModel,
+                    request_deserializer=fedn_dot_network_dot_grpc_dot_fedn__pb2.StoreModelRequest.FromString,
+                    response_serializer=fedn_dot_network_dot_grpc_dot_fedn__pb2.StoreModelResponse.SerializeToString,
+            ),
+            'HandleAggregation': grpc.unary_stream_rpc_method_handler(
+                    servicer.HandleAggregation,
+                    request_deserializer=fedn_dot_network_dot_grpc_dot_fedn__pb2.AggregationRequest.FromString,
+                    response_serializer=fedn_dot_network_dot_grpc_dot_fedn__pb2.AggregationResponse.SerializeToString,
+            ),
+    }
+    generic_handler = grpc.method_handlers_generic_handler(
+            'fedn.FunctionService', rpc_method_handlers)
+    server.add_generic_rpc_handlers((generic_handler,))
+
+
+ # This class is part of an EXPERIMENTAL API.
+class FunctionService(object):
+    """Missing associated documentation comment in .proto file."""
+
+    @staticmethod
+    def HandleProvidedFunctions(request,
+            target,
+            options=(),
+            channel_credentials=None,
+            call_credentials=None,
+            insecure=False,
+            compression=None,
+            wait_for_ready=None,
+            timeout=None,
+            metadata=None):
+        return grpc.experimental.unary_unary(request, target, '/fedn.FunctionService/HandleProvidedFunctions',
+            fedn_dot_network_dot_grpc_dot_fedn__pb2.ProvidedFunctionsRequest.SerializeToString,
+            fedn_dot_network_dot_grpc_dot_fedn__pb2.ProvidedFunctionsResponse.FromString,
+            options, channel_credentials,
+            insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
+
+    @staticmethod
+    def HandleClientConfig(request_iterator,
+            target,
+            options=(),
+            channel_credentials=None,
+            call_credentials=None,
+            insecure=False,
+            compression=None,
+            wait_for_ready=None,
+            timeout=None,
+            metadata=None):
+        return grpc.experimental.stream_unary(request_iterator, target, '/fedn.FunctionService/HandleClientConfig',
+            fedn_dot_network_dot_grpc_dot_fedn__pb2.ClientConfigRequest.SerializeToString,
+            fedn_dot_network_dot_grpc_dot_fedn__pb2.ClientConfigResponse.FromString,
+            options, channel_credentials,
+            insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
+
+    @staticmethod
+    def HandleClientSelection(request,
+            target,
+            options=(),
+            channel_credentials=None,
+            call_credentials=None,
+            insecure=False,
+            compression=None,
+            wait_for_ready=None,
+            timeout=None,
+            metadata=None):
+        return grpc.experimental.unary_unary(request, target, '/fedn.FunctionService/HandleClientSelection',
+            fedn_dot_network_dot_grpc_dot_fedn__pb2.ClientSelectionRequest.SerializeToString,
+            fedn_dot_network_dot_grpc_dot_fedn__pb2.ClientSelectionResponse.FromString,
+            options, channel_credentials,
+            insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
+
+    @staticmethod
+    def HandleMetadata(request,
+            target,
+            options=(),
+            channel_credentials=None,
+            call_credentials=None,
+            insecure=False,
+            compression=None,
+            wait_for_ready=None,
+            timeout=None,
+            metadata=None):
+        return grpc.experimental.unary_unary(request, target, '/fedn.FunctionService/HandleMetadata',
+            fedn_dot_network_dot_grpc_dot_fedn__pb2.ClientMetaRequest.SerializeToString,
+            fedn_dot_network_dot_grpc_dot_fedn__pb2.ClientMetaResponse.FromString,
+            options, channel_credentials,
+            insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
+
+    @staticmethod
+    def HandleStoreModel(request_iterator,
+            target,
+            options=(),
+            channel_credentials=None,
+            call_credentials=None,
+            insecure=False,
+            compression=None,
+            wait_for_ready=None,
+            timeout=None,
+            metadata=None):
+        return grpc.experimental.stream_unary(request_iterator, target, '/fedn.FunctionService/HandleStoreModel',
+            fedn_dot_network_dot_grpc_dot_fedn__pb2.StoreModelRequest.SerializeToString,
+            fedn_dot_network_dot_grpc_dot_fedn__pb2.StoreModelResponse.FromString,
+            options, channel_credentials,
+            insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
+
+    @staticmethod
+    def HandleAggregation(request,
+            target,
+            options=(),
+            channel_credentials=None,
+            call_credentials=None,
+            insecure=False,
+            compression=None,
+            wait_for_ready=None,
+            timeout=None,
+            metadata=None):
+        return grpc.experimental.unary_stream(request, target, '/fedn.FunctionService/HandleAggregation',
+            fedn_dot_network_dot_grpc_dot_fedn__pb2.AggregationRequest.SerializeToString,
+            fedn_dot_network_dot_grpc_dot_fedn__pb2.AggregationResponse.FromString,
             options, channel_credentials,
             insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
diff --git a/fedn/utils/helpers/helpers.py b/fedn/utils/helpers/helpers.py
index 7fbf83a81..1e5e81d75 100644
--- a/fedn/utils/helpers/helpers.py
+++ b/fedn/utils/helpers/helpers.py
@@ -29,6 +29,19 @@ def save_metadata(metadata, filename):
         json.dump(metadata, outfile)
 
 
+def load_metadata(filename):
+    """Load metadata from file.
+
+    :param filename: Path prefix of the metadata file ("-metadata" is appended, mirroring save_metadata).
+    :type filename: str
+    :return: The loaded metadata.
+    :rtype: dict
+    """
+    with open(filename + "-metadata", "r") as infile:
+        metadata = json.load(infile)
+    return metadata
+
+
 def save_metrics(metrics, filename):
     """Save metrics to file.