Feature/SK-1081 | Use stores in Combiner + ModelPredict (#718)

Wrede authored Oct 23, 2024
1 parent b63f3a6 commit dddaebb

Showing 27 changed files with 579 additions and 305 deletions.
@@ -15,7 +15,7 @@ def _eprint(*args, **kwargs):
 def _wait_n_rounds(collection):
     n = 0
     for _ in range(RETRIES):
-        query = {'type': 'INFERENCE'}
+        query = {'type': 'MODEL_PREDICTION'}
         n = collection.count_documents(query)
         if n == N_CLIENTS:
             return n
@@ -32,4 +32,4 @@ def _wait_n_rounds(collection):
     # Wait for successful rounds
     succeded = _wait_n_rounds(client['fedn-test-network']['control']['status'])
     assert(succeded == N_CLIENTS) # check that all rounds succeeded
-    _eprint(f'Succeded inference clients: {succeded}. Test passed.')
+    _eprint(f'Succeded prediction clients: {succeded}. Test passed.')
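The test polls the MongoDB status collection until every client has reported a MODEL_PREDICTION status. A minimal standalone sketch of that polling pattern (assuming pymongo and a local statestore; RETRIES and N_CLIENTS values here are illustrative, the real test defines its own constants):

    import time
    import pymongo

    RETRIES = 18    # illustrative retry budget
    N_CLIENTS = 2   # illustrative client count

    client = pymongo.MongoClient("mongodb://localhost:27017")
    collection = client["fedn-test-network"]["control"]["status"]

    n = 0
    for _ in range(RETRIES):
        # Count clients that have reported a prediction status
        n = collection.count_documents({"type": "MODEL_PREDICTION"})
        if n == N_CLIENTS:
            break
        time.sleep(10)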
@@ -8,12 +8,12 @@ if [ "$#" -lt 1 ]; then
 fi
 example="$1"

->&2 echo "Run inference"
+>&2 echo "Run prediction"
 pushd "examples/$example"
-curl -k -X POST https://localhost:8090/infer
+curl -k -X POST https://localhost:8090/predict

->&2 echo "Checking inference success"
-".$example/bin/python" ../../.ci/tests/examples/inference_test.py
+>&2 echo "Checking prediction success"
+".$example/bin/python" ../../.ci/tests/examples/prediction_test.py

 >&2 echo "Test completed successfully"
 popd
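The same trigger can be issued from Python rather than curl (a sketch; host and port mirror the CI script above, and verify=False corresponds to curl's -k for the self-signed certificate):

    import requests

    # POST to the predict endpoint, as the CI script does with curl
    response = requests.post("https://localhost:8090/predict", verify=False)
    print(response.status_code, response.text)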
4 changes: 2 additions & 2 deletions .github/workflows/integration-tests.yaml
@@ -38,8 +38,8 @@ jobs:
       - name: run ${{ matrix.to_test }}
         run: .ci/tests/examples/run.sh ${{ matrix.to_test }}

-      # - name: run ${{ matrix.to_test }} inference
-      #   run: .ci/tests/examples/run_inference.sh ${{ matrix.to_test }}
+      # - name: run ${{ matrix.to_test }} prediction
+      #   run: .ci/tests/examples/run_prediction.sh ${{ matrix.to_test }}
       #   if: ${{ matrix.os != 'macos-11' && matrix.to_test == 'mnist-keras keras' }} # example available for Keras

       - name: print logs
4 changes: 2 additions & 2 deletions examples/mnist-keras/client/predict.py
@@ -11,13 +11,13 @@


 def predict(in_model_path, out_json_path, data_path=None):
-    # Using test data for inference but another dataset could be loaded
+    # Using test data for prediction but another dataset could be loaded
     x_test, _ = load_data(data_path, is_train=False)

     # Load model
     model = load_parameters(in_model_path)

-    # Infer
+    # Predict
     y_pred = model.predict(x_test)
     y_pred = np.argmax(y_pred, axis=1)
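For orientation, this entry point takes a staged model path and an output path (a sketch; the file paths are illustrative, and load_data/load_parameters come from the example's own helper modules):

    # Hypothetical invocation of the mnist-keras predict entry point
    predict(
        in_model_path="/tmp/model_in",          # model parameters staged by the client
        out_json_path="/tmp/prediction.json",   # where the prediction output is written
    )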
3 changes: 1 addition & 2 deletions fedn/cli/status_cmd.py
@@ -8,8 +8,7 @@
 @main.group("status")
 @click.pass_context
 def status_cmd(ctx):
-    """:param ctx:
-    """
+    """:param ctx:"""
     pass

4 changes: 2 additions & 2 deletions fedn/network/api/v1/__init__.py
@@ -1,12 +1,12 @@
 from fedn.network.api.v1.client_routes import bp as client_bp
 from fedn.network.api.v1.combiner_routes import bp as combiner_bp
 from fedn.network.api.v1.helper_routes import bp as helper_bp
-from fedn.network.api.v1.inference_routes import bp as inference_bp
 from fedn.network.api.v1.model_routes import bp as model_bp
 from fedn.network.api.v1.package_routes import bp as package_bp
+from fedn.network.api.v1.prediction_routes import bp as prediction_bp
 from fedn.network.api.v1.round_routes import bp as round_bp
 from fedn.network.api.v1.session_routes import bp as session_bp
 from fedn.network.api.v1.status_routes import bp as status_bp
 from fedn.network.api.v1.validation_routes import bp as validation_bp

-_routes = [client_bp, combiner_bp, model_bp, package_bp, round_bp, session_bp, status_bp, validation_bp, inference_bp, helper_bp]
+_routes = [client_bp, combiner_bp, model_bp, package_bp, round_bp, session_bp, status_bp, validation_bp, prediction_bp, helper_bp]
34 changes: 0 additions & 34 deletions fedn/network/api/v1/inference_routes.py

This file was deleted.

51 changes: 51 additions & 0 deletions fedn/network/api/v1/prediction_routes.py
@@ -0,0 +1,51 @@
+import threading
+
+from flask import Blueprint, jsonify, request
+
+from fedn.network.api.auth import jwt_auth_required
+from fedn.network.api.shared import control
+from fedn.network.api.v1.shared import api_version, mdb
+from fedn.network.storage.statestore.stores.model_store import ModelStore
+from fedn.network.storage.statestore.stores.prediction_store import PredictionStore
+from fedn.network.storage.statestore.stores.shared import EntityNotFound
+
+bp = Blueprint("prediction", __name__, url_prefix=f"/api/{api_version}/predict")
+
+prediction_store = PredictionStore(mdb, "control.predictions")
+model_store = ModelStore(mdb, "control.model")
+
+
+@bp.route("/start", methods=["POST"])
+@jwt_auth_required(role="admin")
+def start_session():
+    """Start a new prediction session.
+    param: prediction_id: The session id to start.
+    type: prediction_id: str
+    param: rounds: The number of rounds to run.
+    type: rounds: int
+    """
+    try:
+        data = request.json if request.headers["Content-Type"] == "application/json" else request.form.to_dict()
+        prediction_id: str = data.get("prediction_id")
+
+        if not prediction_id or prediction_id == "":
+            return jsonify({"message": "prediction_id is required"}), 400
+
+        if data.get("model_id") is None:
+            count = model_store.count()
+            if count == 0:
+                return jsonify({"message": "No models available"}), 400
+        else:
+            try:
+                model_id = data.get("model_id")
+                _ = model_store.get(model_id)
+            except EntityNotFound:
+                return jsonify({"message": f"Model {model_id} not found"}), 404
+
+        session_config = {"prediction_id": prediction_id}
+
+        threading.Thread(target=control.prediction_session, kwargs={"config": session_config}).start()
+
+        return jsonify({"message": "Prediction session started"}), 200
+    except Exception:
+        return jsonify({"message": "Failed to start prediction session"}), 500
4 changes: 4 additions & 0 deletions fedn/network/api/v1/status_routes.py
@@ -124,8 +124,12 @@ def get_statuses():
     limit, skip, sort_key, sort_order, use_typing = get_typed_list_headers(request.headers)
     kwargs = request.args.to_dict()

+    # print all the typed headers
+    print(f"limit: {limit}, skip: {skip}, sort_key: {sort_key}, sort_order: {sort_order}, use_typing: {use_typing}")
+    print(f"kwargs: {kwargs}")
     statuses = status_store.list(limit, skip, sort_key, sort_order, use_typing=use_typing, **kwargs)

+    print(f"statuses: {statuses}")
     result = [status.__dict__ for status in statuses["result"]] if use_typing else statuses["result"]

     response = {"count": statuses["count"], "result": result}
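The prints above are debug traces left in by this commit. For reference, query parameters on this route arrive as kwargs and filter status_store.list (a sketch; host, token, and the statuses URL prefix are assumptions, not confirmed by this diff):

    import requests

    # Query parameters become kwargs and filter the status store
    resp = requests.get(
        "https://localhost:8092/api/v1/statuses/?type=MODEL_PREDICTION",  # illustrative URL
        headers={"Authorization": "Bearer <admin-jwt>"},  # placeholder token
        verify=False,
    )
    print(resp.json()["count"])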
80 changes: 49 additions & 31 deletions fedn/network/clients/client.py
@@ -176,7 +176,7 @@ def connect(self, combiner_config):
        logger.debug("Client using metadata: {}.".format(self.metadata))
        port = combiner_config["port"]
        secure = False
-        if combiner_config["fqdn"] is not None:
+        if "fqdn" in combiner_config.keys() and combiner_config["fqdn"] is not None:
            host = combiner_config["fqdn"]
            # assuming https if fqdn is used
            port = 443
@@ -417,12 +417,12 @@ def _listen_to_task_stream(self):
                    self.inbox.put(("train", request))
                elif request.type == fedn.StatusType.MODEL_VALIDATION and self.config["validator"]:
                    self.inbox.put(("validate", request))
-                elif request.type == fedn.StatusType.INFERENCE and self.config["validator"]:
-                    logger.info("Received inference request for model_id {}".format(request.model_id))
+                elif request.type == fedn.StatusType.MODEL_PREDICTION and self.config["validator"]:
+                    logger.info("Received prediction request for model_id {}".format(request.model_id))
                    presigned_url = json.loads(request.data)
                    presigned_url = presigned_url["presigned_url"]
-                    logger.info("Inference presigned URL: {}".format(presigned_url))
-                    self.inbox.put(("infer", request))
+                    logger.info("Prediction presigned URL: {}".format(presigned_url))
+                    self.inbox.put(("predict", request))
                else:
                    logger.error("Unknown request type: {}".format(request.type))
@@ -519,25 +519,17 @@ def _process_training_request(self, model_id: str, session_id: str = None):

        return updated_model_id, meta

-    def _process_validation_request(self, model_id: str, is_inference: bool, session_id: str = None):
+    def _process_validation_request(self, model_id: str, session_id: str = None):
        """Process a validation request.

        :param model_id: The model id of the model to be validated.
        :type model_id: str
-        :param is_inference: True if the validation is an inference request, False if it is a validation request.
-        :type is_inference: bool
        :param session_id: The id of the current session.
        :type session_id: str
        :return: The validation metrics, or None if validation failed.
        :rtype: dict
        """
-        # Figure out cmd
-        if is_inference:
-            cmd = "infer"
-        else:
-            cmd = "validate"
-
-        self.send_status(f"Processing {cmd} request for model_id {model_id}", sesssion_id=session_id)
+        self.send_status(f"Processing validation request for model_id {model_id}", sesssion_id=session_id)
        self.state = ClientState.validating
        try:
            model = self.get_model_from_combiner(str(model_id))
@@ -550,7 +542,7 @@ def _process_validation_request(self, model_id: str, is_inference: bool, session_id: str = None):
                fh.write(model.getbuffer())

            outpath = get_tmp_path()
-            self.dispatcher.run_cmd(f"{cmd} {inpath} {outpath}")
+            self.dispatcher.run_cmd(f"validate {inpath} {outpath}")

            with open(outpath, "r") as fh:
                validation = json.loads(fh.read())
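The dispatcher call implies the compute package contract: the validate entry point reads model parameters from inpath and writes a JSON metrics dict to outpath, which the client then loads. A stand-in entry point satisfying that contract (metric names are examples only, not FEDn's):

    import json
    import sys

    def validate(in_model_path, out_json_path):
        # A real package loads the model and evaluates it on local data
        metrics = {"accuracy": 0.0}  # example metric
        with open(out_json_path, "w") as fh:
            fh.write(json.dumps(metrics))

    if __name__ == "__main__":
        validate(sys.argv[1], sys.argv[2])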
@@ -566,22 +558,22 @@ def _process_validation_request(self, model_id: str, is_inference: bool, session_id: str = None):
        self.state = ClientState.idle
        return validation

-    def _process_inference_request(self, model_id: str, session_id: str, presigned_url: str):
-        """Process an inference request.
+    def _process_prediction_request(self, model_id: str, session_id: str, presigned_url: str):
+        """Process a prediction request.

-        :param model_id: The model id of the model to be used for inference.
+        :param model_id: The model id of the model to be used for prediction.
        :type model_id: str
        :param session_id: The id of the current session.
        :type session_id: str
-        :param presigned_url: The presigned URL for the data to be used for inference.
+        :param presigned_url: The presigned URL for the data to be used for prediction.
        :type presigned_url: str
        :return: None
        """
-        self.send_status(f"Processing inference request for model_id {model_id}", sesssion_id=session_id)
+        self.send_status(f"Processing prediction request for model_id {model_id}", sesssion_id=session_id)
        try:
            model = self.get_model_from_combiner(str(model_id))
            if model is None:
-                logger.error("Could not retrieve model from combiner. Aborting inference request.")
+                logger.error("Could not retrieve model from combiner. Aborting prediction request.")
                return
            inpath = self.helper.get_tmp_path()

@@ -591,20 +583,20 @@ def _process_inference_request(self, model_id: str, session_id: str, presigned_url: str):
            outpath = get_tmp_path()
            self.dispatcher.run_cmd(f"predict {inpath} {outpath}")

-            # Upload the inference result to the presigned URL
+            # Upload the prediction result to the presigned URL
            with open(outpath, "rb") as fh:
                response = requests.put(presigned_url, data=fh.read())

            os.unlink(inpath)
            os.unlink(outpath)

            if response.status_code != 200:
-                logger.warning("Inference upload failed with status code {}".format(response.status_code))
+                logger.warning("Prediction upload failed with status code {}".format(response.status_code))
                self.state = ClientState.idle
                return

        except Exception as e:
-            logger.warning("Inference failed with exception {}".format(e))
+            logger.warning("Prediction failed with exception {}".format(e))
            self.state = ClientState.idle
            return
@@ -668,7 +660,7 @@ def process_request(self):

            elif task_type == "validate":
                self.state = ClientState.validating
-                metrics = self._process_validation_request(request.model_id, False, request.session_id)
+                metrics = self._process_validation_request(request.model_id, request.session_id)

                if metrics is not None:
                    # Send validation
@@ -707,21 +699,47 @@ def process_request(self):

                self.state = ClientState.idle
                self.inbox.task_done()
-            elif task_type == "infer":
-                self.state = ClientState.inferencing
+            elif task_type == "predict":
+                self.state = ClientState.predicting
                try:
                    presigned_url = json.loads(request.data)
                except json.JSONDecodeError as e:
-                    logger.error(f"Failed to decode inference request data: {e}")
+                    logger.error(f"Failed to decode prediction request data: {e}")
                    self.state = ClientState.idle
                    continue

                if "presigned_url" not in presigned_url:
-                    logger.error("Inference request missing presigned_url.")
+                    logger.error("Prediction request missing presigned_url.")
                    self.state = ClientState.idle
                    continue
                presigned_url = presigned_url["presigned_url"]
-                _ = self._process_inference_request(request.model_id, request.session_id, presigned_url)
+                # Obs that session_id in request is the prediction_id
+                _ = self._process_prediction_request(request.model_id, request.session_id, presigned_url)
+                prediction = fedn.ModelPrediction()
+                prediction.sender.name = self.name
+                prediction.sender.role = fedn.WORKER
+                prediction.receiver.name = request.sender.name
+                prediction.receiver.name = request.sender.name
+                prediction.receiver.role = request.sender.role
+                prediction.model_id = str(request.model_id)
+                # TODO: Add prediction data
+                prediction.data = ""
+                prediction.timestamp.GetCurrentTime()
+                prediction.correlation_id = request.correlation_id
+                # Obs that session_id in request is the prediction_id
+                prediction.prediction_id = request.session_id
+
+                try:
+                    _ = self.combinerStub.SendModelPrediction(prediction, metadata=self.metadata)
+                    status_type = fedn.StatusType.MODEL_PREDICTION
+                    self.send_status(
+                        "Model prediction completed.", log_level=fedn.Status.AUDIT, type=status_type, request=prediction, sesssion_id=request.session_id
+                    )
+                except grpc.RpcError as e:
+                    status_code = e.code()
+                    logger.error("GRPC error, {}.".format(status_code.name))
+                    logger.debug(e)

                self.state = ClientState.idle
            except queue.Empty:
                pass