Skip to content

Commit

Permalink
Feature/SK-946 | Add functionality for user defined server-functions (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
viktorvaladi authored Oct 31, 2024
1 parent 76471af commit 6462978
Show file tree
Hide file tree
Showing 40 changed files with 1,888 additions and 470 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@ config/extra-hosts-reducer.yaml
config/settings-client.yaml
config/settings-reducer.yaml
config/settings-combiner.yaml
config/settings-hooks.yaml

./tmp/*

Expand Down
1 change: 1 addition & 0 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ ARG REQUIREMENTS=""
COPY . /app
COPY config/settings-client.yaml.template /app/config/settings-client.yaml
COPY config/settings-combiner.yaml.template /app/config/settings-combiner.yaml
COPY config/settings-hooks.yaml.template /app/config/settings-hooks.yaml
COPY config/settings-reducer.yaml.template /app/config/settings-reducer.yaml
COPY $REQUIREMENTS /app/config/requirements.txt

Expand Down
8 changes: 8 additions & 0 deletions config/settings-hooks.yaml.template
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
network_id: fedn-network
discover_host: api-server
discover_port: 8092

name: hooks
host: hooks
port: 12081
max_clients: 30
40 changes: 27 additions & 13 deletions docker-compose.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,7 @@ services:
- MINIO_ROOT_PASSWORD=password
command: server /data --console-address minio:9001
healthcheck:
test:
[
"CMD",
"curl",
"-f",
"http://minio:9000/minio/health/live"
]
test: [ "CMD", "curl", "-f", "http://minio:9000/minio/health/live" ]
interval: 30s
timeout: 20s
retries: 3
Expand Down Expand Up @@ -89,6 +83,7 @@ services:
- GET_HOSTS_FROM=dns
- STATESTORE_CONFIG=/app/config/settings-combiner.yaml
- MODELSTORAGE_CONFIG=/app/config/settings-combiner.yaml
- HOOK_SERVICE_HOST=hook:12081
build:
context: .
args:
Expand All @@ -103,17 +98,36 @@ services:
ports:
- 12080:12080
healthcheck:
test:
[
"CMD",
"/bin/grpc_health_probe",
"-addr=localhost:12080"
]
test: [ "CMD", "/bin/grpc_health_probe", "-addr=localhost:12080" ]
interval: 20s
timeout: 10s
retries: 5
depends_on:
- api-server
- hooks
# Hooks
hooks:
container_name: hook
environment:
- GET_HOSTS_FROM=dns
build:
context: .
args:
BASE_IMG: ${BASE_IMG:-python:3.10-slim}
GRPC_HEALTH_PROBE_VERSION: v0.4.24
working_dir: /app
volumes:
- ${HOST_REPO_DIR:-.}/fedn:/app/fedn
entrypoint: [ "sh", "-c" ]
command:
- "/venv/bin/pip install --no-cache-dir -e . && /venv/bin/fedn hooks start"
ports:
- 12081:12081
healthcheck:
test: [ "CMD", "/bin/grpc_health_probe", "-addr=localhost:12081" ]
interval: 20s
timeout: 10s
retries: 5

# Client
client:
Expand Down
4 changes: 4 additions & 0 deletions examples/server-functions/.dockerignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
data
seed.npz
*.tgz
*.tar.gz
6 changes: 6 additions & 0 deletions examples/server-functions/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
data
*.npz
*.tgz
*.tar.gz
.mnist-pytorch
client.yaml
11 changes: 11 additions & 0 deletions examples/server-functions/README.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
FEDn Project: Server functions toy example
-----------------------------

See server_functions.py for details.

README Will be updated after studio update.

To run with server functions:

from server_functions import ServerFunctions
client.start_session(server_functions=ServerFunctions)
97 changes: 97 additions & 0 deletions examples/server-functions/client/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
import os
from math import floor

import torch
import torchvision

dir_path = os.path.dirname(os.path.realpath(__file__))
abs_path = os.path.abspath(dir_path)


def get_data(out_dir="data"):
# Make dir if necessary
if not os.path.exists(out_dir):
os.mkdir(out_dir)

# Only download if not already downloaded
if not os.path.exists(f"{out_dir}/train"):
torchvision.datasets.MNIST(root=f"{out_dir}/train", transform=torchvision.transforms.ToTensor, train=True, download=True)
if not os.path.exists(f"{out_dir}/test"):
torchvision.datasets.MNIST(root=f"{out_dir}/test", transform=torchvision.transforms.ToTensor, train=False, download=True)


def load_data(data_path, is_train=True):
"""Load data from disk.
:param data_path: Path to data file.
:type data_path: str
:param is_train: Whether to load training or test data.
:type is_train: bool
:return: Tuple of data and labels.
:rtype: tuple
"""
if data_path is None:
data_path = os.environ.get("FEDN_DATA_PATH", abs_path + "/data/clients/1/mnist.pt")

data = torch.load(data_path)

if is_train:
X = data["x_train"]
y = data["y_train"]
else:
X = data["x_test"]
y = data["y_test"]

# Normalize
X = X / 255

return X, y


def splitset(dataset, parts):
n = dataset.shape[0]
local_n = floor(n / parts)
result = []
for i in range(parts):
result.append(dataset[i * local_n : (i + 1) * local_n])
return result


def split(out_dir="data"):
n_splits = int(os.environ.get("FEDN_NUM_DATA_SPLITS", 2))

# Make dir
if not os.path.exists(f"{out_dir}/clients"):
os.mkdir(f"{out_dir}/clients")

# Load and convert to dict
train_data = torchvision.datasets.MNIST(root=f"{out_dir}/train", transform=torchvision.transforms.ToTensor, train=True)
test_data = torchvision.datasets.MNIST(root=f"{out_dir}/test", transform=torchvision.transforms.ToTensor, train=False)
data = {
"x_train": splitset(train_data.data, n_splits),
"y_train": splitset(train_data.targets, n_splits),
"x_test": splitset(test_data.data, n_splits),
"y_test": splitset(test_data.targets, n_splits),
}

# Make splits
for i in range(n_splits):
subdir = f"{out_dir}/clients/{str(i+1)}"
if not os.path.exists(subdir):
os.mkdir(subdir)
torch.save(
{
"x_train": data["x_train"][i],
"y_train": data["y_train"][i],
"x_test": data["x_test"][i],
"y_test": data["y_test"][i],
},
f"{subdir}/mnist.pt",
)


if __name__ == "__main__":
# Prepare data if not already done
if not os.path.exists(abs_path + "/data/clients/1"):
get_data()
split()
12 changes: 12 additions & 0 deletions examples/server-functions/client/fedn.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
python_env: python_env.yaml
entry_points:
build:
command: python model.py
startup:
command: python data.py
train:
command: python train.py
validate:
command: python validate.py
predict:
command: python predict.py
76 changes: 76 additions & 0 deletions examples/server-functions/client/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import collections

import torch

from fedn.utils.helpers.helpers import get_helper

HELPER_MODULE = "numpyhelper"
helper = get_helper(HELPER_MODULE)


def compile_model():
"""Compile the pytorch model.
:return: The compiled model.
:rtype: torch.nn.Module
"""

class Net(torch.nn.Module):
def __init__(self):
super(Net, self).__init__()
self.fc1 = torch.nn.Linear(784, 64)
self.fc2 = torch.nn.Linear(64, 32)
self.fc3 = torch.nn.Linear(32, 10)

def forward(self, x):
x = torch.nn.functional.relu(self.fc1(x.reshape(x.size(0), 784)))
x = torch.nn.functional.dropout(x, p=0.5, training=self.training)
x = torch.nn.functional.relu(self.fc2(x))
x = torch.nn.functional.log_softmax(self.fc3(x), dim=1)
return x

return Net()


def save_parameters(model, out_path):
"""Save model paramters to file.
:param model: The model to serialize.
:type model: torch.nn.Module
:param out_path: The path to save to.
:type out_path: str
"""
parameters_np = [val.cpu().numpy() for _, val in model.state_dict().items()]
helper.save(parameters_np, out_path)


def load_parameters(model_path):
"""Load model parameters from file and populate model.
param model_path: The path to load from.
:type model_path: str
:return: The loaded model.
:rtype: torch.nn.Module
"""
model = compile_model()
parameters_np = helper.load(model_path)

params_dict = zip(model.state_dict().keys(), parameters_np)
state_dict = collections.OrderedDict({key: torch.tensor(x) for key, x in params_dict})
model.load_state_dict(state_dict, strict=True)
return model


def init_seed(out_path="seed.npz"):
"""Initialize seed model and save it to file.
:param out_path: The path to save the seed model to.
:type out_path: str
"""
# Init and save
model = compile_model()
save_parameters(model, out_path)


if __name__ == "__main__":
init_seed("../seed.npz")
37 changes: 37 additions & 0 deletions examples/server-functions/client/predict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import os
import sys

import torch
from data import load_data
from model import load_parameters

dir_path = os.path.dirname(os.path.realpath(__file__))
sys.path.append(os.path.abspath(dir_path))


def predict(in_model_path, out_artifact_path, data_path=None):
"""Validate model.
:param in_model_path: The path to the input model.
:type in_model_path: str
:param out_artifact_path: The path to save the predict output to.
:type out_artifact_path: str
:param data_path: The path to the data file.
:type data_path: str
"""
# Load data
x_test, y_test = load_data(data_path, is_train=False)

# Load model
model = load_parameters(in_model_path)
model.eval()

# Predict
with torch.no_grad():
y_pred = model(x_test)
# Save prediction to file/artifact, the artifact will be uploaded to the object store by the client
torch.save(y_pred, out_artifact_path)


if __name__ == "__main__":
predict(sys.argv[1], sys.argv[2])
9 changes: 9 additions & 0 deletions examples/server-functions/client/python_env.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
name: mnist-pytorch
build_dependencies:
- pip
- setuptools
- wheel
dependencies:
- torch==2.3.1
- torchvision==0.18.1
- fedn
Loading

0 comments on commit 6462978

Please sign in to comment.