Feature/SK-1081 | Use stores in Combiner + ModelPredict #718

Merged
18 commits merged on Oct 23, 2024
Changes from 10 commits
@@ -32,4 +32,4 @@ def _wait_n_rounds(collection):
 # Wait for successful rounds
 succeded = _wait_n_rounds(client['fedn-test-network']['control']['status'])
 assert(succeded == N_CLIENTS) # check that all rounds succeeded
-_eprint(f'Succeded inference clients: {succeded}. Test passed.')
+_eprint(f'Succeded prediction clients: {succeded}. Test passed.')
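For context, the _wait_n_rounds helper itself lies outside this hunk. A minimal sketch of what such a polling helper might look like, assuming a pymongo status collection (the filter key, client count, and timeout below are illustrative assumptions, not the repository's actual code):

import time

def _wait_n_rounds(collection, n_clients=2, timeout=120):
    # Poll the status collection until the expected number of success
    # messages have arrived, or until the timeout expires.
    deadline = time.time() + timeout
    while time.time() < deadline:
        succeeded = collection.count_documents({"type": "INFERENCE"})  # assumed filter
        if succeeded >= n_clients:
            return succeeded
        time.sleep(1)
    return collection.count_documents({"type": "INFERENCE"})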
@@ -8,12 +8,12 @@ if [ "$#" -lt 1 ]; then
 fi
 example="$1"
 
->&2 echo "Run inference"
+>&2 echo "Run prediction"
 pushd "examples/$example"
-curl -k -X POST https://localhost:8090/infer
+curl -k -X POST https://localhost:8090/predict
 
->&2 echo "Checking inference success"
-".$example/bin/python" ../../.ci/tests/examples/inference_test.py
+>&2 echo "Checking prediction success"
+".$example/bin/python" ../../.ci/tests/examples/prediction_test.py
 
 >&2 echo "Test completed successfully"
 popd
4 changes: 2 additions & 2 deletions .github/workflows/integration-tests.yaml
@@ -38,8 +38,8 @@ jobs:
       - name: run ${{ matrix.to_test }}
        run: .ci/tests/examples/run.sh ${{ matrix.to_test }}
 
-      # - name: run ${{ matrix.to_test }} inference
-      #   run: .ci/tests/examples/run_inference.sh ${{ matrix.to_test }}
+      # - name: run ${{ matrix.to_test }} prediction
+      #   run: .ci/tests/examples/run_prediction.sh ${{ matrix.to_test }}
       #   if: ${{ matrix.os != 'macos-11' && matrix.to_test == 'mnist-keras keras' }} # example available for Keras
 
       - name: print logs
4 changes: 2 additions & 2 deletions examples/mnist-keras/client/predict.py
@@ -11,13 +11,13 @@
 
 
 def predict(in_model_path, out_json_path, data_path=None):
-    # Using test data for inference but another dataset could be loaded
+    # Using test data for prediction but another dataset could be loaded
     x_test, _ = load_data(data_path, is_train=False)
 
     # Load model
     model = load_parameters(in_model_path)
 
-    # Infer
+    # Predict
     y_pred = model.predict(x_test)
     y_pred = np.argmax(y_pred, axis=1)
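The rest of this hunk is collapsed, but the function presumably finishes by serializing the predictions to out_json_path. A sketch of that final step, assuming a plain JSON payload (the actual output schema is not shown here):

import json

# Assumed final step: write the predicted class indices to the output path.
with open(out_json_path, "w") as fh:
    json.dump({"predictions": y_pred.tolist()}, fh)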
3 changes: 1 addition & 2 deletions fedn/cli/status_cmd.py
@@ -8,8 +8,7 @@
 @main.group("status")
 @click.pass_context
 def status_cmd(ctx):
-    """:param ctx:
-    """
+    """:param ctx:"""
     pass
4 changes: 2 additions & 2 deletions fedn/network/api/v1/__init__.py
@@ -1,11 +1,11 @@
 from fedn.network.api.v1.client_routes import bp as client_bp
 from fedn.network.api.v1.combiner_routes import bp as combiner_bp
-from fedn.network.api.v1.inference_routes import bp as inference_bp
+from fedn.network.api.v1.prediction_routes import bp as prediction_bp
 from fedn.network.api.v1.model_routes import bp as model_bp
 from fedn.network.api.v1.package_routes import bp as package_bp
 from fedn.network.api.v1.round_routes import bp as round_bp
 from fedn.network.api.v1.session_routes import bp as session_bp
 from fedn.network.api.v1.status_routes import bp as status_bp
 from fedn.network.api.v1.validation_routes import bp as validation_bp
 
-_routes = [client_bp, combiner_bp, model_bp, package_bp, round_bp, session_bp, status_bp, validation_bp, inference_bp]
+_routes = [client_bp, combiner_bp, model_bp, package_bp, round_bp, session_bp, status_bp, validation_bp, prediction_bp]
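For context, these blueprints are presumably registered on the Flask application elsewhere in fedn.network.api; the mechanics look roughly like the sketch below (an illustration of the pattern, not the repository's actual wiring):

from flask import Flask

app = Flask(__name__)
for bp in _routes:
    # Each route module exposes a Blueprint carrying its own url_prefix.
    app.register_blueprint(bp)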
34 changes: 0 additions & 34 deletions fedn/network/api/v1/inference_routes.py

This file was deleted.

37 changes: 37 additions & 0 deletions fedn/network/api/v1/prediction_routes.py
@@ -0,0 +1,37 @@
+import threading
+
+from flask import Blueprint, jsonify, request
+
+from fedn.network.api.auth import jwt_auth_required
+from fedn.network.api.shared import control
+from fedn.network.api.v1.shared import api_version, mdb
+from fedn.network.storage.statestore.stores.prediction_store import PredictionStore
+
+bp = Blueprint("prediction", __name__, url_prefix=f"/api/{api_version}/predict")
+
+prediction_store = PredictionStore(mdb, "control.predictions")
+
+
+@bp.route("/start", methods=["POST"])
+@jwt_auth_required(role="admin")
+def start_session():
+    """Start a new prediction session.
+    param: prediction_id: The session id to start.
+    type: prediction_id: str
+    param: rounds: The number of rounds to run.
+    type: rounds: int
+    """
+    try:
+        data = request.json if request.headers["Content-Type"] == "application/json" else request.form.to_dict()
+        prediction_id: str = data.get("prediction_id")
+
+        if not prediction_id or prediction_id == "":
+            return jsonify({"message": "Session ID is required"}), 400
+
+        session_config = {"prediction_id": prediction_id}
+
+        threading.Thread(target=control.prediction_session, kwargs={"config": session_config}).start()
+
+        return jsonify({"message": "Prediction session started"}), 200
+    except Exception:
+        return jsonify({"message": "Failed to start prediction session"}), 500
4 changes: 4 additions & 0 deletions fedn/network/api/v1/status_routes.py
@@ -124,8 +124,12 @@ def get_statuses():
     limit, skip, sort_key, sort_order, use_typing = get_typed_list_headers(request.headers)
     kwargs = request.args.to_dict()
 
+    # print all the typed headers
+    print(f"limit: {limit}, skip: {skip}, sort_key: {sort_key}, sort_order: {sort_order}, use_typing: {use_typing}")
+    print(f"kwargs: {kwargs}")
     statuses = status_store.list(limit, skip, sort_key, sort_order, use_typing=use_typing, **kwargs)
 
+    print(f"statuses: {statuses}")
     result = [status.__dict__ for status in statuses["result"]] if use_typing else statuses["result"]
 
     response = {"count": statuses["count"], "result": result}
78 changes: 48 additions & 30 deletions fedn/network/clients/client.py
@@ -176,7 +176,7 @@ def connect(self, combiner_config):
        logger.debug("Client using metadata: {}.".format(self.metadata))
        port = combiner_config["port"]
        secure = False
-        if combiner_config["fqdn"] is not None:
+        if "fqdn" in combiner_config.keys() and combiner_config["fqdn"] is not None:
            host = combiner_config["fqdn"]
            # assuming https if fqdn is used
            port = 443
@@ -418,11 +418,11 @@ def _listen_to_task_stream(self):
                    elif request.type == fedn.StatusType.MODEL_VALIDATION and self.config["validator"]:
                        self.inbox.put(("validate", request))
                    elif request.type == fedn.StatusType.INFERENCE and self.config["validator"]:
-                        logger.info("Received inference request for model_id {}".format(request.model_id))
+                        logger.info("Received prediction request for model_id {}".format(request.model_id))
                        presigned_url = json.loads(request.data)
                        presigned_url = presigned_url["presigned_url"]
-                        logger.info("Inference presigned URL: {}".format(presigned_url))
-                        self.inbox.put(("infer", request))
+                        logger.info("Prediction presigned URL: {}".format(presigned_url))
+                        self.inbox.put(("predict", request))
                    else:
                        logger.error("Unknown request type: {}".format(request.type))
@@ -519,25 +519,17 @@ def _process_training_request(self, model_id: str, session_id: str = None):
 
        return updated_model_id, meta
 
-    def _process_validation_request(self, model_id: str, is_inference: bool, session_id: str = None):
+    def _process_validation_request(self, model_id: str, session_id: str = None):
        """Process a validation request.
 
        :param model_id: The model id of the model to be validated.
        :type model_id: str
-        :param is_inference: True if the validation is an inference request, False if it is a validation request.
-        :type is_inference: bool
        :param session_id: The id of the current session.
        :type session_id: str
        :return: The validation metrics, or None if validation failed.
        :rtype: dict
        """
-        # Figure out cmd
-        if is_inference:
-            cmd = "infer"
-        else:
-            cmd = "validate"
-
-        self.send_status(f"Processing {cmd} request for model_id {model_id}", sesssion_id=session_id)
+        self.send_status(f"Processing validation request for model_id {model_id}", sesssion_id=session_id)
        self.state = ClientState.validating
        try:
            model = self.get_model_from_combiner(str(model_id))
@@ -550,7 +542,7 @@ def _process_validation_request(self, model_id: str, is_inference: bool, session
                fh.write(model.getbuffer())
 
            outpath = get_tmp_path()
-            self.dispatcher.run_cmd(f"{cmd} {inpath} {outpath}")
+            self.dispatcher.run_cmd(f"validate {inpath} {outpath}")
 
            with open(outpath, "r") as fh:
                validation = json.loads(fh.read())
@@ -566,22 +558,22 @@ def _process_validation_request(self, model_id: str, is_inference: bool, session
        self.state = ClientState.idle
        return validation
 
-    def _process_inference_request(self, model_id: str, session_id: str, presigned_url: str):
-        """Process an inference request.
+    def _process_prediction_request(self, model_id: str, session_id: str, presigned_url: str):
+        """Process an prediction request.

Comment (Member Author): "a" prediction

 
-        :param model_id: The model id of the model to be used for inference.
+        :param model_id: The model id of the model to be used for prediction.
        :type model_id: str
        :param session_id: The id of the current session.
        :type session_id: str
-        :param presigned_url: The presigned URL for the data to be used for inference.
+        :param presigned_url: The presigned URL for the data to be used for prediction.
        :type presigned_url: str
        :return: None
        """
-        self.send_status(f"Processing inference request for model_id {model_id}", sesssion_id=session_id)
+        self.send_status(f"Processing prediction request for model_id {model_id}", sesssion_id=session_id)
        try:
            model = self.get_model_from_combiner(str(model_id))
            if model is None:
-                logger.error("Could not retrieve model from combiner. Aborting inference request.")
+                logger.error("Could not retrieve model from combiner. Aborting prediction request.")
                return
            inpath = self.helper.get_tmp_path()
@@ -591,20 +583,20 @@ def _process_inference_request(self, model_id: str, session_id: str, presigned_u
            outpath = get_tmp_path()
            self.dispatcher.run_cmd(f"predict {inpath} {outpath}")
 
-            # Upload the inference result to the presigned URL
+            # Upload the prediction result to the presigned URL
            with open(outpath, "rb") as fh:
                response = requests.put(presigned_url, data=fh.read())
 
            os.unlink(inpath)
            os.unlink(outpath)
 
            if response.status_code != 200:
-                logger.warning("Inference upload failed with status code {}".format(response.status_code))
+                logger.warning("Prediction upload failed with status code {}".format(response.status_code))
                self.state = ClientState.idle
                return
 
        except Exception as e:
-            logger.warning("Inference failed with exception {}".format(e))
+            logger.warning("Prediction failed with exception {}".format(e))
            self.state = ClientState.idle
            return
@@ -668,7 +660,7 @@ def process_request(self):
 
                elif task_type == "validate":
                    self.state = ClientState.validating
-                    metrics = self._process_validation_request(request.model_id, False, request.session_id)
+                    metrics = self._process_validation_request(request.model_id, request.session_id)
 
                    if metrics is not None:
                        # Send validation
@@ -707,21 +699,47 @@ def process_request(self):
 
                    self.state = ClientState.idle
                    self.inbox.task_done()
-                elif task_type == "infer":
-                    self.state = ClientState.inferencing
+                elif task_type == "predict":
+                    self.state = ClientState.predicting
                    try:
                        presigned_url = json.loads(request.data)
                    except json.JSONDecodeError as e:
-                        logger.error(f"Failed to decode inference request data: {e}")
+                        logger.error(f"Failed to decode prediction request data: {e}")
                        self.state = ClientState.idle
                        continue
 
                    if "presigned_url" not in presigned_url:
-                        logger.error("Inference request missing presigned_url.")
+                        logger.error("Prediction request missing presigned_url.")
                        self.state = ClientState.idle
                        continue
                    presigned_url = presigned_url["presigned_url"]
-                    _ = self._process_inference_request(request.model_id, request.session_id, presigned_url)
+                    # Obs that session_id in request is the prediction_id
+                    _ = self._process_prediction_request(request.model_id, request.session_id, presigned_url)
+                    prediction = fedn.ModelPrediction()
+                    prediction.sender.name = self.name
+                    prediction.sender.role = fedn.WORKER
+                    prediction.receiver.name = request.sender.name
+                    prediction.receiver.name = request.sender.name
+                    prediction.receiver.role = request.sender.role
+                    prediction.model_id = str(request.model_id)
+                    # TODO: Add prediction data
+                    prediction.data = ""
+                    prediction.timestamp.GetCurrentTime()
+                    prediction.correlation_id = request.correlation_id
+                    # Obs that session_id in request is the prediction_id
+                    prediction.prediction_id = request.session_id
+
+                    try:
+                        _ = self.combinerStub.SendModelPrediction(prediction, metadata=self.metadata)
+                        status_type = fedn.StatusType.INFERENCE

Comment (Member Author): We can change StatusType.INFERENCE to StatusType.MODEL_PREDICTION in fedn.proto

+                        self.send_status(
+                            "Model prediction completed.", log_level=fedn.Status.AUDIT, type=status_type, request=prediction, sesssion_id=request.session_id
+                        )
+                    except grpc.RpcError as e:
+                        status_code = e.code()
+                        logger.error("GRPC error, {}.".format(status_code.name))
+                        logger.debug(e)
+
+                    self.state = ClientState.idle
            except queue.Empty:
                pass
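If the StatusType.MODEL_PREDICTION value suggested in the comment above were added to fedn.proto and the stubs regenerated, the client-side call would reduce to the following sketch (MODEL_PREDICTION does not exist yet, so this is an assumption):

# Hypothetical: requires MODEL_PREDICTION in fedn.proto plus regenerated stubs.
status_type = fedn.StatusType.MODEL_PREDICTION
self.send_status(
    "Model prediction completed.",
    log_level=fedn.Status.AUDIT,
    type=status_type,
    request=prediction,
    sesssion_id=request.session_id,  # keyword is spelled "sesssion_id" in the current API
)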
2 changes: 1 addition & 1 deletion fedn/network/clients/state.py
@@ -7,7 +7,7 @@ class ClientState(Enum):
    idle = 1
    training = 2
    validating = 3
-    inferencing = 4
+    predicting = 4

Comment (Member Author): We can also add the ClientStateToString for "predicting"

 
 
 def ClientStateToString(state):
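Following up on the comment above, a sketch of how ClientStateToString could cover the new state (the existing branches are collapsed in this hunk, so their exact form is an assumption):

def ClientStateToString(state):
    # Map each ClientState member to a display string.
    if state == ClientState.idle:
        return "IDLE"
    if state == ClientState.training:
        return "TRAINING"
    if state == ClientState.validating:
        return "VALIDATING"
    if state == ClientState.predicting:
        return "PREDICTING"
    return "UNKNOWN"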