Feature/SK-1081 | Use stores in Combiner + ModelPredict #718
```diff
@@ -1,11 +1,11 @@
 from fedn.network.api.v1.client_routes import bp as client_bp
 from fedn.network.api.v1.combiner_routes import bp as combiner_bp
-from fedn.network.api.v1.inference_routes import bp as inference_bp
+from fedn.network.api.v1.prediction_routes import bp as prediction_bp
 from fedn.network.api.v1.model_routes import bp as model_bp
 from fedn.network.api.v1.package_routes import bp as package_bp
 from fedn.network.api.v1.round_routes import bp as round_bp
 from fedn.network.api.v1.session_routes import bp as session_bp
 from fedn.network.api.v1.status_routes import bp as status_bp
 from fedn.network.api.v1.validation_routes import bp as validation_bp
 
-_routes = [client_bp, combiner_bp, model_bp, package_bp, round_bp, session_bp, status_bp, validation_bp, inference_bp]
+_routes = [client_bp, combiner_bp, model_bp, package_bp, round_bp, session_bp, status_bp, validation_bp, prediction_bp]
```
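The effect of this one-line swap is that the API app now mounts the prediction blueprint instead of the inference one. For context, a minimal sketch of how a blueprint list like `_routes` is typically consumed (the factory below is an illustration, not code from this PR):

```python
from flask import Flask

from fedn.network.api.v1 import _routes  # the list edited in this diff


def create_app() -> Flask:
    # Illustrative only: register every versioned blueprint on the app,
    # so prediction_bp's /api/v1/predict routes become reachable.
    app = Flask(__name__)
    for bp in _routes:  # now ends with prediction_bp instead of inference_bp
        app.register_blueprint(bp)
    return app
```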
This file was deleted (the old inference_routes module, judging by the import removed above).
@@ -0,0 +1,37 @@
```python
import threading

from flask import Blueprint, jsonify, request

from fedn.network.api.auth import jwt_auth_required
from fedn.network.api.shared import control
from fedn.network.api.v1.shared import api_version, mdb
from fedn.network.storage.statestore.stores.prediction_store import PredictionStore

bp = Blueprint("prediction", __name__, url_prefix=f"/api/{api_version}/predict")

prediction_store = PredictionStore(mdb, "control.predictions")


@bp.route("/start", methods=["POST"])
@jwt_auth_required(role="admin")
def start_session():
    """Start a new prediction session.
    param: prediction_id: The session id to start.
    type: prediction_id: str
    param: rounds: The number of rounds to run.
    type: rounds: int
    """
    try:
        data = request.json if request.headers["Content-Type"] == "application/json" else request.form.to_dict()
        prediction_id: str = data.get("prediction_id")

        if not prediction_id or prediction_id == "":
            return jsonify({"message": "Session ID is required"}), 400

        session_config = {"prediction_id": prediction_id}

        threading.Thread(target=control.prediction_session, kwargs={"config": session_config}).start()

        return jsonify({"message": "Prediction session started"}), 200
    except Exception:
        return jsonify({"message": "Failed to start prediction session"}), 500
```
```diff
@@ -176,7 +176,7 @@ def connect(self, combiner_config):
         logger.debug("Client using metadata: {}.".format(self.metadata))
         port = combiner_config["port"]
         secure = False
-        if combiner_config["fqdn"] is not None:
+        if "fqdn" in combiner_config.keys() and combiner_config["fqdn"] is not None:
             host = combiner_config["fqdn"]
             # assuming https if fqdn is used
             port = 443
```
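The old check indexed `fqdn` directly and would raise `KeyError` for combiner configs without that key; the added membership test guards against that. `dict.get` expresses the same guard more compactly (a side note, not part of the diff):

```python
combiner_config = {"host": "combiner.example.com", "port": 12080}  # no "fqdn" key

# Old code: combiner_config["fqdn"] -> KeyError: 'fqdn'
# New code short-circuits on the membership test, so the lookup is safe:
if "fqdn" in combiner_config.keys() and combiner_config["fqdn"] is not None:
    host = combiner_config["fqdn"]

# Equivalent and more idiomatic:
if combiner_config.get("fqdn") is not None:
    host = combiner_config["fqdn"]
```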
```diff
@@ -418,11 +418,11 @@ def _listen_to_task_stream(self):
                 elif request.type == fedn.StatusType.MODEL_VALIDATION and self.config["validator"]:
                     self.inbox.put(("validate", request))
                 elif request.type == fedn.StatusType.INFERENCE and self.config["validator"]:
-                    logger.info("Received inference request for model_id {}".format(request.model_id))
+                    logger.info("Received prediction request for model_id {}".format(request.model_id))
                     presigned_url = json.loads(request.data)
                     presigned_url = presigned_url["presigned_url"]
-                    logger.info("Inference presigned URL: {}".format(presigned_url))
-                    self.inbox.put(("infer", request))
+                    logger.info("Prediction presigned URL: {}".format(presigned_url))
+                    self.inbox.put(("predict", request))
                 else:
                     logger.error("Unknown request type: {}".format(request.type))
```
```diff
@@ -519,25 +519,17 @@ def _process_training_request(self, model_id: str, session_id: str = None):
 
         return updated_model_id, meta
 
-    def _process_validation_request(self, model_id: str, is_inference: bool, session_id: str = None):
+    def _process_validation_request(self, model_id: str, session_id: str = None):
         """Process a validation request.
 
         :param model_id: The model id of the model to be validated.
         :type model_id: str
-        :param is_inference: True if the validation is an inference request, False if it is a validation request.
-        :type is_inference: bool
         :param session_id: The id of the current session.
         :type session_id: str
         :return: The validation metrics, or None if validation failed.
         :rtype: dict
         """
-        # Figure out cmd
-        if is_inference:
-            cmd = "infer"
-        else:
-            cmd = "validate"
-
-        self.send_status(f"Processing {cmd} request for model_id {model_id}", sesssion_id=session_id)
+        self.send_status(f"Processing validation request for model_id {model_id}", sesssion_id=session_id)
         self.state = ClientState.validating
         try:
             model = self.get_model_from_combiner(str(model_id))
```
```diff
@@ -550,7 +542,7 @@ def _process_validation_request(self, model_id: str, is_inference: bool, session_id: str = None):
                 fh.write(model.getbuffer())
 
             outpath = get_tmp_path()
-            self.dispatcher.run_cmd(f"{cmd} {inpath} {outpath}")
+            self.dispatcher.run_cmd(f"validate {inpath} {outpath}")
 
             with open(outpath, "r") as fh:
                 validation = json.loads(fh.read())
```
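The dispatcher resolves the now hard-coded `validate` command against the compute package's entry points, so the contract is: read the model from `inpath`, write JSON metrics to `outpath` (the client parses the file with `json.loads` right above). A hypothetical minimal entry point honoring that contract:

```python
# validate.py -- hypothetical compute-package entry point, invoked as:
#   validate <inpath> <outpath>
import json
import sys


def validate(inpath: str, outpath: str) -> None:
    # Load the model staged at inpath by the client (format depends on the helper),
    # evaluate it on local data, and emit the metrics as JSON.
    metrics = {"accuracy": 0.0, "loss": 0.0}  # placeholder metrics
    with open(outpath, "w") as fh:
        json.dump(metrics, fh)


if __name__ == "__main__":
    validate(sys.argv[1], sys.argv[2])
```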
```diff
@@ -566,22 +558,22 @@ def _process_validation_request(self, model_id: str, is_inference: bool, session_id: str = None):
             self.state = ClientState.idle
             return validation
 
-    def _process_inference_request(self, model_id: str, session_id: str, presigned_url: str):
-        """Process an inference request.
+    def _process_prediction_request(self, model_id: str, session_id: str, presigned_url: str):
+        """Process an prediction request.
 
-        :param model_id: The model id of the model to be used for inference.
+        :param model_id: The model id of the model to be used for prediction.
         :type model_id: str
         :param session_id: The id of the current session.
         :type session_id: str
-        :param presigned_url: The presigned URL for the data to be used for inference.
+        :param presigned_url: The presigned URL for the data to be used for prediction.
         :type presigned_url: str
         :return: None
         """
-        self.send_status(f"Processing inference request for model_id {model_id}", sesssion_id=session_id)
+        self.send_status(f"Processing prediction request for model_id {model_id}", sesssion_id=session_id)
         try:
             model = self.get_model_from_combiner(str(model_id))
             if model is None:
-                logger.error("Could not retrieve model from combiner. Aborting inference request.")
+                logger.error("Could not retrieve model from combiner. Aborting prediction request.")
                 return
             inpath = self.helper.get_tmp_path()
```
```diff
@@ -591,20 +583,20 @@ def _process_inference_request(self, model_id: str, session_id: str, presigned_url: str):
             outpath = get_tmp_path()
             self.dispatcher.run_cmd(f"predict {inpath} {outpath}")
 
-            # Upload the inference result to the presigned URL
+            # Upload the prediction result to the presigned URL
             with open(outpath, "rb") as fh:
                 response = requests.put(presigned_url, data=fh.read())
 
             os.unlink(inpath)
             os.unlink(outpath)
 
             if response.status_code != 200:
-                logger.warning("Inference upload failed with status code {}".format(response.status_code))
+                logger.warning("Prediction upload failed with status code {}".format(response.status_code))
                 self.state = ClientState.idle
                 return
 
         except Exception as e:
-            logger.warning("Inference failed with exception {}".format(e))
+            logger.warning("Prediction failed with exception {}".format(e))
             self.state = ClientState.idle
             return
```
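Symmetrically, the `predict` entry point is expected to write its result file to `outpath`, which the client then streams verbatim to the presigned URL with `requests.put`. A hypothetical minimal sketch:

```python
# predict.py -- hypothetical compute-package entry point, invoked as:
#   predict <inpath> <outpath>
import json
import sys


def predict(inpath: str, outpath: str) -> None:
    # Load the model staged at inpath, run it on local data, and write the
    # output file; the client uploads outpath's bytes to the presigned URL.
    predictions = {"outputs": []}  # placeholder result
    with open(outpath, "w") as fh:
        json.dump(predictions, fh)


if __name__ == "__main__":
    predict(sys.argv[1], sys.argv[2])
```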
```diff
@@ -668,7 +660,7 @@ def process_request(self):
 
                 elif task_type == "validate":
                     self.state = ClientState.validating
-                    metrics = self._process_validation_request(request.model_id, False, request.session_id)
+                    metrics = self._process_validation_request(request.model_id, request.session_id)
 
                     if metrics is not None:
                         # Send validation
```
```diff
@@ -707,21 +699,47 @@ def process_request(self):
 
                     self.state = ClientState.idle
                     self.inbox.task_done()
-                elif task_type == "infer":
-                    self.state = ClientState.inferencing
+                elif task_type == "predict":
+                    self.state = ClientState.predicting
                     try:
                         presigned_url = json.loads(request.data)
                     except json.JSONDecodeError as e:
-                        logger.error(f"Failed to decode inference request data: {e}")
+                        logger.error(f"Failed to decode prediction request data: {e}")
                         self.state = ClientState.idle
                         continue
 
                     if "presigned_url" not in presigned_url:
-                        logger.error("Inference request missing presigned_url.")
+                        logger.error("Prediction request missing presigned_url.")
                         self.state = ClientState.idle
                         continue
                     presigned_url = presigned_url["presigned_url"]
-                    _ = self._process_inference_request(request.model_id, request.session_id, presigned_url)
+                    # Obs that session_id in request is the prediction_id
+                    _ = self._process_prediction_request(request.model_id, request.session_id, presigned_url)
+                    prediction = fedn.ModelPrediction()
+                    prediction.sender.name = self.name
+                    prediction.sender.role = fedn.WORKER
+                    prediction.receiver.name = request.sender.name
+                    prediction.receiver.role = request.sender.role
+                    prediction.model_id = str(request.model_id)
+                    # TODO: Add prediction data
+                    prediction.data = ""
+                    prediction.timestamp.GetCurrentTime()
+                    prediction.correlation_id = request.correlation_id
+                    # Obs that session_id in request is the prediction_id
+                    prediction.prediction_id = request.session_id
+
+                    try:
+                        _ = self.combinerStub.SendModelPrediction(prediction, metadata=self.metadata)
+                        status_type = fedn.StatusType.INFERENCE
+                        self.send_status(
+                            "Model prediction completed.", log_level=fedn.Status.AUDIT, type=status_type, request=prediction, sesssion_id=request.session_id
+                        )
+                    except grpc.RpcError as e:
+                        status_code = e.code()
+                        logger.error("GRPC error, {}.".format(status_code.name))
+                        logger.debug(e)
+
+                    self.state = ClientState.idle
             except queue.Empty:
                 pass
```

Review comment on `status_type = fedn.StatusType.INFERENCE`: We can change StatusType.INFERENCE to StatusType.MODEL_PREDICTION in fedn.proto.
```diff
@@ -7,7 +7,7 @@ class ClientState(Enum):
     idle = 1
     training = 2
     validating = 3
-    inferencing = 4
+    predicting = 4
 
 
 def ClientStateToString(state):
```

Review comment on `predicting = 4`: We can also add the ClientStateToString for "predicting".

Review comment on the `_process_prediction_request` docstring: "a" prediction, not "an prediction".