diff --git a/fl_server_api/openapi.py b/fl_server_api/openapi.py
index 0f2354a..81519dd 100644
--- a/fl_server_api/openapi.py
+++ b/fl_server_api/openapi.py
@@ -67,6 +67,15 @@ def create_error_response(
 """Generic OpenAPI 403 response."""
 
 
+error_response_404 = create_error_response(
+    "Not found",
+    "Not found",
+    "The server cannot find the requested resource.",
+    "Provide valid request data."
+)
+"""Generic OpenAPI 404 response."""
+
+
 def custom_preprocessing_hook(endpoints: List[Tuple[str, str, str, Callable]]):
     """
     Hide the "/api/dummy/" endpoint from the OpenAPI schema.
diff --git a/fl_server_api/serializers/model.py b/fl_server_api/serializers/model.py
index bdb46b4..131cb50 100644
--- a/fl_server_api/serializers/model.py
+++ b/fl_server_api/serializers/model.py
@@ -93,6 +93,8 @@ def to_representation(self, instance):
             del data["weights"]
         if self.context.get("with-stats", False):
             data["stats"] = self.analyze_torch_model(instance)
+        if isinstance(instance, GlobalModel):
+            data["has_preprocessing"] = bool(instance.preprocessing)
         return data
 
     def analyze_torch_model(self, instance: Model):
@@ -175,6 +177,7 @@ class ModelSerializerNoWeights(ModelSerializer):
     class Meta:
         model = Model
         exclude = ["polymorphic_ctype", "weights"]
+        include = ["has_preprocessing"]
 
 
 class ModelSerializerNoWeightsWithStats(ModelSerializerNoWeights):
@@ -186,6 +189,6 @@ class Meta:
         model = Model
         exclude = ["polymorphic_ctype", "weights"]
-        include = ["stats"]
+        include = ["has_preprocessing", "stats"]
 
 
 #######################################################################################################################
diff --git a/fl_server_api/tests/test_inference.py b/fl_server_api/tests/test_inference.py
index bcc200c..e0ba031 100644
--- a/fl_server_api/tests/test_inference.py
+++ b/fl_server_api/tests/test_inference.py
@@ -3,11 +3,15 @@
 #
 # SPDX-License-Identifier: Apache-2.0
 
+import base64
 from django.core.files.uploadedfile import SimpleUploadedFile
 from django.test import TestCase
 import json
+import io
 import pickle
 import torch
+import torch.nn
+from torchvision.transforms.functional import to_pil_image
 from uuid import uuid4
 
 from fl_server_core.tests import BASE_URL, Dummy
@@ -168,3 +172,103 @@ def _inference_result(self, torch_model: torch.nn.Module):
         self.assertIsNotNone(inference)
         inference_tensor = torch.as_tensor(inference)
         self.assertTrue(torch.all(torch.tensor([2, 0, 0]) == inference_tensor))
+
+    def test_inference_input_shape_positive(self):
+        inp = from_torch_tensor(torch.zeros(3, 3))
+        model = Dummy.create_model(input_shape=[None, 3])
+        training = Dummy.create_training(actor=self.user, model=model)
+        input_file = SimpleUploadedFile(
+            "input.pt",
+            inp,
+            content_type="application/octet-stream"
+        )
+        response = self.client.post(
+            f"{BASE_URL}/inference/",
+            {"model_id": str(training.model.id), "model_input": input_file}
+        )
+        self.assertEqual(response.status_code, 200)
+
+    def test_inference_input_shape_negative(self):
+        inp = from_torch_tensor(torch.zeros(3, 3))
+        model = Dummy.create_model(input_shape=[None, 5])
+        training = Dummy.create_training(actor=self.user, model=model)
+        input_file = SimpleUploadedFile(
+            "input.pt",
+            inp,
+            content_type="application/octet-stream"
+        )
+        with self.assertLogs("root", level="WARNING") as cm:
+            response = self.client.post(
+                f"{BASE_URL}/inference/",
+                {"model_id": str(training.model.id), "model_input": input_file}
+            )
+        self.assertEqual(cm.output, [
+            "WARNING:django.request:Bad Request: /api/inference/",
+        ])
+        self.assertEqual(response.status_code, 400)
+        self.assertEqual(response.json()[0], "Input shape does not match model input shape.")
+
+    def test_inference_input_pil_image(self):
+        img = to_pil_image(torch.zeros(1, 5, 5))
+        img_byte_arr = io.BytesIO()
+        img.save(img_byte_arr, format="jpeg")
+        img_byte_arr = img_byte_arr.getvalue()
+
+        torch.manual_seed(42)
+        torch_model = torch.jit.script(torch.nn.Sequential(
+            torch.nn.Conv2d(1, 2, 3),
+            torch.nn.Flatten(),
+            torch.nn.Linear(3*3, 2)
+        ))
+        model = Dummy.create_model(input_shape=[None, 5, 5], weights=from_torch_module(torch_model))
+        training = Dummy.create_training(actor=self.user, model=model)
+        input_file = SimpleUploadedFile(
+            "input.pt",
+            img_byte_arr,
+            content_type="application/octet-stream"
+        )
+        response = self.client.post(
+            f"{BASE_URL}/inference/",
+            {"model_id": str(training.model.id), "model_input": input_file}
+        )
+        self.assertEqual(response.status_code, 200)
+
+        results = pickle.loads(response.content)
+        self.assertEqual({}, results["uncertainty"])
+        inference = results["inference"]
+        self.assertIsNotNone(inference)
+        inference_tensor = torch.as_tensor(inference)
+        self.assertTrue(torch.all(torch.tensor([0, 0]) == inference_tensor))
+
+    def test_inference_input_pil_image_base64(self):
+        img = to_pil_image(torch.zeros(1, 5, 5))
+        img_byte_arr = io.BytesIO()
+        img.save(img_byte_arr, format="jpeg")
+        img_byte_arr = img_byte_arr.getvalue()
+        inp = base64.b64encode(img_byte_arr)
+
+        torch.manual_seed(42)
+        torch_model = torch.jit.script(torch.nn.Sequential(
+            torch.nn.Conv2d(1, 2, 3),
+            torch.nn.Flatten(),
+            torch.nn.Linear(3*3, 2)
+        ))
+        model = Dummy.create_model(input_shape=[None, 5, 5], weights=from_torch_module(torch_model))
+        training = Dummy.create_training(actor=self.user, model=model)
+        input_file = SimpleUploadedFile(
+            "input.pt",
+            inp,
+            content_type="application/octet-stream"
+        )
+        response = self.client.post(
+            f"{BASE_URL}/inference/",
+            {"model_id": str(training.model.id), "model_input": input_file}
+        )
+        self.assertEqual(response.status_code, 200)
+
+        results = pickle.loads(response.content)
+        self.assertEqual({}, results["uncertainty"])
+        inference = results["inference"]
+        self.assertIsNotNone(inference)
+        inference_tensor = torch.as_tensor(inference)
+        self.assertTrue(torch.all(torch.tensor([0, 0]) == inference_tensor))
diff --git a/fl_server_api/tests/test_model.py b/fl_server_api/tests/test_model.py
index b10d29a..30841e7 100644
--- a/fl_server_api/tests/test_model.py
+++ b/fl_server_api/tests/test_model.py
@@ -149,6 +149,7 @@ def test_get_model_metadata(self):
         self.assertEqual(str(model.name), response_json["name"])
         self.assertEqual(str(model.description), response_json["description"])
         self.assertEqual(model.input_shape, response_json["input_shape"])
+        self.assertFalse(response_json["has_preprocessing"])
         # check stats
         stats = response_json["stats"]
         self.assertIsNotNone(stats)
@@ -232,6 +233,28 @@ def test_get_model_metadata(self):
         self.assertIsNotNone(layer4["output_bytes"])
         self.assertIsNotNone(layer4["macs"])
 
+    def test_get_model_metadata_with_preprocessing(self):
+        model_bytes = from_torch_module(torch.nn.Sequential(
+            torch.nn.Linear(3, 64),
+            torch.nn.ELU(),
+            torch.nn.Linear(64, 1),
+        ))
+        torch_model_preprocessing = from_torch_module(transforms.Compose([
+            transforms.ToImage(),
+            transforms.ToDtype(torch.float32, scale=True),
+            transforms.Normalize(mean=(0.,), std=(1.,)),
+        ]))
+        model = Dummy.create_model(weights=model_bytes, preprocessing=torch_model_preprocessing, input_shape=[None, 3])
self.client.get(f"{BASE_URL}/models/{model.id}/metadata/") + self.assertEqual(200, response.status_code) + self.assertEqual("application/json", response["content-type"]) + response_json = response.json() + self.assertEqual(str(model.id), response_json["id"]) + self.assertEqual(str(model.name), response_json["name"]) + self.assertEqual(str(model.description), response_json["description"]) + self.assertEqual(model.input_shape, response_json["input_shape"]) + self.assertTrue(response_json["has_preprocessing"]) + def test_get_model_metadata_torchscript_model(self): torchscript_model_bytes = from_torch_module(torch.jit.script(torch.nn.Sequential( torch.nn.Linear(3, 64), @@ -552,6 +575,30 @@ def test_upload_model_preprocessing_v2_Compose_good(self): self.assertIsNotNone(model.preprocessing) self.assertTrue(isinstance(model.get_preprocessing_torch_model(), torch.nn.Module)) + def test_download_model_preprocessing(self): + torch_model_preprocessing = from_torch_module(torch.jit.script(torch.nn.Sequential( + transforms.Normalize(mean=(0.,), std=(1.,)), + ))) + model = Dummy.create_model(owner=self.user, preprocessing=torch_model_preprocessing) + response = self.client.get(f"{BASE_URL}/models/{model.id}/preprocessing/") + self.assertEqual(200, response.status_code) + self.assertEqual("application/octet-stream", response["content-type"]) + torch_model = torch.jit.load(io.BytesIO(response.content)) + self.assertIsNotNone(torch_model) + self.assertTrue(isinstance(torch_model, torch.nn.Module)) + + def test_download_model_preprocessing_with_undefined_preprocessing(self): + model = Dummy.create_model(owner=self.user, preprocessing=None) + with self.assertLogs("django.request", level="WARNING") as cm: + response = self.client.get(f"{BASE_URL}/models/{model.id}/preprocessing/") + self.assertEqual(cm.output, [ + f"WARNING:django.request:Not Found: /api/models/{model.id}/preprocessing/", + ]) + self.assertEqual(404, response.status_code) + response_json = response.json() + self.assertIsNotNone(response_json) + self.assertEqual(f"Model '{model.id}' has no preprocessing model defined.", response_json["detail"]) + @patch("fl_server_ai.trainer.tasks.process_trainer_task.apply_async") def test_upload_update(self, apply_async: MagicMock): model = Dummy.create_model(owner=self.user, round=0) diff --git a/fl_server_api/urls.py b/fl_server_api/urls.py index 3432278..2441f29 100644 --- a/fl_server_api/urls.py +++ b/fl_server_api/urls.py @@ -38,7 +38,7 @@ {"get": "get_model_metrics", "post": "create_model_metrics"} ), name="model-metrics"), path("models//preprocessing/", view=Model.as_view( - {"post": "upload_model_preprocessing"} + {"get": "get_model_proprecessing", "post": "upload_model_preprocessing"} ), name="model-preprocessing"), path("models//swag/", view=Model.as_view({"post": "create_swag_stats"}), name="model-swag"), # trainings diff --git a/fl_server_api/views/inference.py b/fl_server_api/views/inference.py index d842e5c..3e82b19 100644 --- a/fl_server_api/views/inference.py +++ b/fl_server_api/views/inference.py @@ -3,19 +3,24 @@ # # SPDX-License-Identifier: Apache-2.0 +import base64 from django.http import HttpRequest, HttpResponse from drf_spectacular.utils import inline_serializer, extend_schema, OpenApiExample import json +from io import BytesIO import pickle +from PIL import Image from rest_framework import status from rest_framework.exceptions import APIException, UnsupportedMediaType, ValidationError -from rest_framework.fields import ListField, DictField, FloatField, CharField +from 
+from rest_framework.fields import CharField, ChoiceField, DictField, FloatField, ListField
 import torch
-from typing import Any, Dict, Tuple, Type
+from torchvision.transforms.functional import to_tensor
+from typing import Any, Dict, List, Literal, Optional, Tuple, Type
 
 from fl_server_ai.uncertainty import get_uncertainty_class, UncertaintyBase
 from fl_server_core.exceptions import TorchDeserializationException
-from fl_server_core.models import Model, GlobalModel
+from fl_server_core.models import Model, GlobalModel, LocalModel
+from fl_server_core.utils.logging import disable_logger
 from fl_server_core.utils.torch_serialization import to_torch_tensor
 
 from .base import ViewSet
@@ -40,7 +45,7 @@ class Inference(ViewSet):
         fields={
             "model_id": CharField(),
             "model_input": ListField(child=ListField(child=FloatField())),
-            "return_format": CharField()
+            "return_format": ChoiceField(["binary", "json"])
         }
     ),
     responses={
@@ -60,90 +65,286 @@
     )
     def inference(self, request: HttpRequest) -> HttpResponse:
         """
-        Processes a request to do inference on a existing model.
+        Performs inference on the provided model and input data.
 
-        This method checks the content type of the request and calls the appropriate method to process the request.
-        If the content type is not supported, it raises an UnsupportedMediaType exception.
+        This method takes an HTTP request containing the necessary metadata and input data,
+        performs any required preprocessing on the input data, runs inference with the specified model,
+        and returns a response in the format specified by the `return_format` parameter, including
+        uncertainty measurements if defined.
 
-        This method can process both JSON data as well as formdata with an attached PyTorch serialised input tensor.
-        For the provided example to run, the model "mymodel" must have been created via the model-endpoint first,
-        and the user must be authorized to access it.
+        Args:
+            request (HttpRequest): The current HTTP request.
+
+        Returns:
+            HttpResponse: An HttpResponse containing the inference result as well as its uncertainty.
+        """
+        request_body, is_json = self._get_handle_content_type(request)
+        model, preprocessing, input_shape, return_format = self._get_inference_metadata(
+            request_body,
+            "json" if is_json else "binary"
+        )
+        model_input = self._get_model_input(request, request_body)
+
+        if preprocessing:
+            model_input = preprocessing(model_input)
+        else:
+            # if no preprocessing is defined, at least try to convert/interpret the model_input as a
+            # PyTorch tensor, before raising an exception
+            model_input = self._try_cast_model_input_to_tensor(model_input)
+        self._validate_model_input_after_preprocessing(model_input, input_shape, bool(preprocessing))
+
+        uncertainty_cls, inference, uncertainty = self._do_inference(model, model_input)
+        return self._make_response(uncertainty_cls, inference, uncertainty, return_format)
+
+    def _get_handle_content_type(self, request: HttpRequest) -> Tuple[dict, bool]:
+        """
+        Parses the HTTP request body based on its content type.
+
+        This function checks if the request content type is either `application/json`
+        or `multipart/form-data`. If it matches, it returns the corresponding data and
+        a boolean indicating whether it's JSON (True) or multipart/form-data (False).
 
         Args:
             request (HttpRequest): The request.
 
         Returns:
-            HttpResponse: The results of the inference.
+            tuple: A tuple containing the parsed data and a boolean indicating the content type.
+                * If the content type is `application/json`, returns the JSON payload as a Python object (dict)
+                  and True to indicate it's JSON.
+                * If the content type is `multipart/form-data`, returns the request POST data and False.
+
+        Raises:
+            UnsupportedMediaType: If an unknown content type is specified, raising an error with
+                details on supported types (`application/json` and `multipart/form-data`).
         """
         match request.content_type.lower():
             case s if s.startswith("multipart/form-data"):
-                return self._process_post(request)
+                return request.POST, False
             case s if s.startswith("application/json"):
-                return self._process_post_json(request)
-            case _:
-                # If the content type is specified, but not supported, return 415
-                self._logger.error(f"Unknown Content-Type '{request.content_type}'")
-                raise UnsupportedMediaType(
-                    "Only Content-Type 'application/json' and 'multipart/form-data' is supported."
-                )
+                return json.loads(request.body), True
+
+        # if the content type is specified, but not supported, return 415
+        self._logger.error(f"Unknown Content-Type '{request.content_type}'")
+        raise UnsupportedMediaType(
+            "Only Content-Type 'application/json' and 'multipart/form-data' is supported."
+        )
 
-    def _process_post(self, request: HttpRequest) -> HttpResponse:
+    def _get_inference_metadata(
+        self,
+        request_body: dict,
+        return_format_default: Literal["binary", "json"]
+    ) -> Tuple[Model, Optional[torch.nn.Module], Optional[List[Optional[int]]], str]:
         """
-        Processes a POST request with form-data.
+        Retrieves inference metadata based on the content of the provided request body.
+
+        This method checks if a `model_id` is present in the request body and retrieves
+        the corresponding model entity. It then determines the return format based on the
+        request body or defaults to one of the two supported formats (`binary` or `json`).
 
         Args:
-            request (HttpRequest): The request.
+            request_body (dict): The data sent with the request, containing at least `model_id`.
+            return_format_default (Literal["binary", "json"]): The default return format to use if not specified in
+                the request body.
+
+        Returns:
+            Tuple[Model, Optional[torch.nn.Module], Optional[List[Optional[int]]], str]: A tuple containing:
+                * The retrieved model entity.
+                * The global model's preprocessing torch module (if applicable).
+                * The input shape of the global model (if applicable).
+                * The return format (`binary` or `json`).
+
+        Raises:
+            ValidationError: If no valid `model_id` is provided in the request body, or if an unknown return format
+                is specified.
+        """
+        if "model_id" not in request_body:
+            self._logger.error("No 'model_id' provided in request.")
+            raise ValidationError("No 'model_id' provided in request.")
+        model_id = request_body["model_id"]
+        model = get_entity(Model, pk=model_id)
+
+        return_format = request_body.get("return_format", return_format_default)
+        if return_format not in ["binary", "json"]:
+            self._logger.error(f"Unknown return format '{return_format}'. Supported are binary and json.")
+            raise ValidationError(f"Unknown return format '{return_format}'. Supported are binary and json.")
+
+        global_model: Optional[GlobalModel] = None
+        if isinstance(model, GlobalModel):
+            global_model = model
+        elif isinstance(model, LocalModel):
+            global_model = model.base_model
+        else:
+            self._logger.error("Unknown model type. Not a GlobalModel and not a LocalModel. Skip preprocessing.")
Skip preprocessing.") + + preprocessing: Optional[torch.nn.Module] = None + input_shape: Optional[List[Optional[int]]] = None + if global_model: + if global_model.preprocessing is not None: + preprocessing = global_model.get_preprocessing_torch_model() + if global_model.input_shape is not None: + input_shape = global_model.input_shape + + return model, preprocessing, input_shape, return_format + + def _get_model_input(self, request: HttpRequest, request_body: dict) -> Any: + """ + Retrieves and decodes the model input from either an uploaded file or the request body. + + Args: + request (HttpRequest): The current HTTP request. + request_body (dict): The parsed request body as a dictionary. + + Returns: + Any: The decoded model input data. + + Raises: + ValidationError: If no `model_input` is found in the uploaded file or the request body. + """ + uploaded_file = request.FILES.get("model_input", None) + if uploaded_file and uploaded_file.file: + model_input = uploaded_file.file.read() + else: + model_input = request_body.get("model_input", None) + if not model_input: + raise ValidationError("No uploaded file 'model_input' found.") + return self._try_decode_model_input(model_input) + + def _try_decode_model_input(self, model_input: Any) -> Any: + """ + Attempts to decode the input `model_input` from various formats and returns it in a usable form. + + This function first tries to deserialize the input as a PyTorch tensor. If that fails, it attempts to + decode the input as a base64-encoded string. If neither attempt is successful, the original input is returned. + + Args: + model_input (Any): The input to be decoded, which can be in any format. Returns: - HttpResponse: The results of the inference. + Any: The decoded input, which may still be in an unknown format if decoding attempts fail. """ + # 1. try to deserialize model_input as PyTorch tensor try: - model_id = request.POST["model_id"] - uploaded_file = request.FILES.get("model_input") - return_format = request.POST.get("return_format", "binary") - assert return_format in ["binary", "json"] - if not uploaded_file or not uploaded_file.file: - raise ValidationError("No uploaded file 'model_input' not found.") - feature_vectors = uploaded_file.file.read() - except Exception as e: - self._logger.error(e) - raise ValidationError("Inference Request could not be interpreted!") + with disable_logger(self._logger): + model_input = to_torch_tensor(model_input) + except Exception: + pass + # 2. try to decode model_input as base64 + try: + is_base64, tmp_model_input = self._is_base64(model_input) + if is_base64: + model_input = tmp_model_input + except Exception: + pass + # result + return model_input - model = get_entity(Model, pk=model_id) - input_tensor = to_torch_tensor(feature_vectors) - if isinstance(model, GlobalModel) and model.preprocessing is not None: - preprocessing = model.get_preprocessing_torch_model() - input_tensor = preprocessing(input_tensor) - uncertainty_cls, inference, uncertainty = self.do_inference(model, input_tensor) - return self._make_response(uncertainty_cls, inference, uncertainty, return_format) + def _try_cast_model_input_to_tensor(self, model_input: Any) -> Any: + """ + Attempt to cast the given model input to a PyTorch tensor. + + This function tries to interpret the input in several formats: + + 1. PIL Image (and later convert it to a PyTorch tensor, see 3.) + 2. PyTorch tensor via `torch.as_tensor` + 3. PyTorch tensor via torchvision `ToTensor` (supports e.g. 
 
-    def _process_post_json(self, request: HttpRequest, body: Any = None) -> HttpResponse:
+        If none of these attempts are successful, the original input is returned.
+
+        Args:
+            model_input: The input data to be cast to a PyTorch tensor.
+                Can be any type that can be converted to a tensor.
+
+        Returns:
+            A PyTorch tensor representation of the input data, or the original
+            input if it cannot be converted.
         """
-        Processes a POST request with JSON data.
+        def _try_to_pil_image(model_input: Any) -> Any:
+            stream = BytesIO(model_input)
+            return Image.open(stream)
+
+        if isinstance(model_input, torch.Tensor):
+            return model_input
+
+        # In the following order, try to:
+        # 1. interpret model_input as PIL image (and later to PyTorch tensor, see step 3),
+        # 2. interpret model_input as PyTorch tensor,
+        # 3. interpret model_input as PyTorch tensor via torchvision ToTensor (supports e.g. PIL images).
+        for fn in [_try_to_pil_image, torch.as_tensor, to_tensor]:
+            try:
+                model_input = fn(model_input)  # type: ignore
+            except Exception:
+                pass
+        return model_input
+
+    def _is_base64(self, sb: str | bytes) -> Tuple[bool, bytes]:
+        """
+        Check if a string or bytes object is a valid Base64 encoded string.
+
+        This function checks if the input can be decoded and re-encoded without any changes.
+        If decoding and re-encoding returns the same result as the original input, it's likely
+        that the input was indeed a valid Base64 encoded string.
+
+        Note: This code is based on the reference implementation from the linked Stack Overflow answer.
 
         Args:
-            request (HttpRequest): The request.
-            body (Any, optional): The request body. Defaults to None.
+            sb (str | bytes): The input string or bytes object to check.
 
         Returns:
-            HttpResponse: The results of the inference.
+            Tuple[bool, bytes]: A tuple containing a boolean indicating whether the input is
+                a valid Base64 encoded string and the decoded bytes if it is.
+
+        References:
+            https://stackoverflow.com/a/45928164
         """
         try:
-            body = body or json.loads(request.body)
-            return_format = body.get("return_format", "json")
-            model_id = body["model_id"]
-            model_input = body["model_input"]
-        except Exception as e:
-            self._logger.error(e)
-            raise ValidationError("Inference Request could not be interpreted!")
+            if isinstance(sb, str):
+                # If there's any unicode here, an exception will be thrown and the function will return False
+                sb_bytes = bytes(sb, "ascii")
+            elif isinstance(sb, bytes):
+                sb_bytes = sb
+            else:
+                raise ValueError("Argument must be string or bytes")
+            decoded = base64.b64decode(sb_bytes)
+            return base64.b64encode(decoded) == sb_bytes, decoded
+        except Exception:
+            return False, b""
 
-        model = get_entity(Model, pk=model_id)
-        input_tensor = torch.as_tensor(model_input)
-        if isinstance(model, GlobalModel) and model.preprocessing is not None:
-            preprocessing = model.get_preprocessing_torch_model()
-            input_tensor = preprocessing(input_tensor)
-        uncertainty_cls, inference, uncertainty = self.do_inference(model, input_tensor)
-        return self._make_response(uncertainty_cls, inference, uncertainty, return_format)
+    def _validate_model_input_after_preprocessing(
+        self,
+        model_input: Any,
+        model_input_shape: Optional[List[Optional[int]]],
+        preprocessing: bool
+    ) -> None:
+        """
+        Validates the model input after preprocessing.
+
+        Ensures that the provided `model_input` is a valid PyTorch tensor and its shape matches
+        the expected `model_input_shape`.
+
+        Args:
+            model_input (Any): The model input to be validated.
+            model_input_shape (Optional[List[Optional[int]]]): The expected shape of the model input.
+                Can contain None values if not all dimensions are fixed (e.g. first dimension as batch size).
+            preprocessing (bool): Whether a preprocessing model was defined or not. (Only for a better error message.)
+
+        Raises:
+            ValidationError: If the `model_input` is not a valid PyTorch tensor or
+                its shape does not match the expected `model_input_shape`.
+        """
+        if not isinstance(model_input, torch.Tensor):
+            msg = "Model input could not be cast or interpreted as a PyTorch tensor object"
+            if preprocessing:
+                msg += " and is still not a PyTorch tensor after preprocessing."
+            else:
+                msg += " and no preprocessing is defined."
+            raise ValidationError(msg)
+
+        if model_input_shape and not all(
+            dim_input == dim_model
+            for (dim_input, dim_model) in zip(model_input.shape, model_input_shape)
+            if dim_model is not None
+        ):
+            raise ValidationError("Input shape does not match model input shape.")
 
     def _make_response(
         self,
@@ -157,6 +358,9 @@ def _make_response(
 
         This method checks the return type and makes a response with the appropriate content type.
 
+        If return_type is "binary", a binary-encoded response will be generated using pickle.
+        Otherwise, a JSON response will be generated by serializing the uncertainty object using its to_json method.
+
         Args:
             uncertainty_cls (Type[UncertaintyBase]): The uncertainty class.
             inference (torch.Tensor): The inference.
@@ -172,29 +376,36 @@
         return HttpResponse(uncertainty_cls.to_json(inference, uncertainty), content_type="application/json")
 
-    def do_inference(
+    def _do_inference(
         self,
         model: Model,
         input_tensor: torch.Tensor
     ) -> Tuple[Type[UncertaintyBase], torch.Tensor, Dict[str, Any]]:
         """
-        Performs inference on a model.
+        Perform inference on a given input tensor using the provided model.
 
-        This method gets the uncertainty class, performs prediction, and returns the uncertainty class, the inference,
-        and the uncertainty.
+        This method retrieves the uncertainty class and performs the prediction.
+        The output of this method consists of:
+
+        * The uncertainty class used for inference
+        * The result of the model's prediction on the input tensor
+        * Any associated uncertainty for the prediction
 
         Args:
-            model (Model): The model.
-            input_tensor (torch.Tensor): The input tensor.
+            model (Model): The model to perform inference with.
+            input_tensor (torch.Tensor): Input tensor to pass through the model.
 
         Returns:
             Tuple[Type[UncertaintyBase], torch.Tensor, Dict[str, Any]]:
-                The uncertainty class, the inference, and the uncertainty.
+                A tuple containing the uncertainty class, prediction result, and any associated uncertainty.
+
+        Raises:
+            APIException: If an error occurs during inference.
         """
         try:
             uncertainty_cls = get_uncertainty_class(model)
             inference, uncertainty = uncertainty_cls.prediction(input_tensor, model)
             return uncertainty_cls, inference, uncertainty
         except TorchDeserializationException as e:
-            raise APIException(e)
+            raise APIException(e) from e
         except Exception as e:
             self._logger.error(e)
-            raise APIException("Internal Server Error occurred during inference!")
+            raise APIException("Internal Server Error occurred during inference!") from e
diff --git a/fl_server_api/views/model.py b/fl_server_api/views/model.py
index bbfecaf..69f9a0e 100644
--- a/fl_server_api/views/model.py
+++ b/fl_server_api/views/model.py
@@ -37,7 +37,7 @@
 from ..serializers.generic import ErrorSerializer, MetricSerializer
 from ..serializers.model import ModelSerializer, ModelSerializerNoWeightsWithStats, load_and_create_model_request, \
     GlobalModelSerializer, ModelSerializerNoWeights, verify_model_object
-from ..openapi import error_response_403
+from ..openapi import error_response_403, error_response_404
 
 
 class Model(ViewSet):
@@ -232,7 +232,7 @@ def get_model(self, _request: HttpRequest, id: str) -> HttpResponseBase:
             id (str): The unique identifier of the model.
 
         Returns:
-            HttpResponseBase: model as file response or 404 if model not found
+            HttpResponseBase: model as file response
         """
         model = get_entity(ModelDB, pk=id)
         if isinstance(model, SWAGModelDB) and model.swag_first_moment is not None:
@@ -247,6 +247,45 @@ def get_model(self, _request: HttpRequest, id: str) -> HttpResponseBase:
         response["Content-Disposition"] = f'filename="model-{id}.pt"'
         return response
 
+    @extend_schema(
+        responses={
+            status.HTTP_200_OK: OpenApiResponse(
+                response=bytes,
+                description="Preprocessing model is returned as bytes"
+            ),
+            status.HTTP_400_BAD_REQUEST: ErrorSerializer,
+            status.HTTP_403_FORBIDDEN: error_response_403,
+            status.HTTP_404_NOT_FOUND: error_response_404,
+        },
+    )
+    def get_model_preprocessing(self, _request: HttpRequest, id: str) -> HttpResponseBase:
+        """
+        Download the whole preprocessing model as a PyTorch serialized file.
+
+        Args:
+            request (HttpRequest): The incoming request object.
+            id (str): The unique identifier of the model.
+
+        Returns:
+            HttpResponseBase: preprocessing model as file response or 404 if preprocessing model not found
+        """
+        model = get_entity(ModelDB, pk=id)
+        global_model: GlobalModelDB
+        if isinstance(model, GlobalModelDB):
+            global_model = model
+        elif isinstance(model, LocalModelDB):
+            global_model = model.base_model
+        else:
+            self._logger.error("Unknown model type. Not a GlobalModel and not a LocalModel.")
+            raise ValidationError(f"Unknown model type. Model id: {id}")
Model id: {id}") + if global_model.preprocessing is None: + raise NotFound(f"Model '{id}' has no preprocessing model defined.") + # NOTE: FileResponse does strange stuff with bytes + # and in case of sqlite the weights will be bytes and not a memoryview + response = HttpResponse(global_model.preprocessing, content_type="application/octet-stream") + response["Content-Disposition"] = f'filename="model-{id}-proprecessing.pt"' + return response + @extend_schema(responses={ status.HTTP_200_OK: inline_serializer( "DeleteModelSuccessSerializer", diff --git a/fl_server_core/utils/logging.py b/fl_server_core/utils/logging.py new file mode 100644 index 0000000..ff04107 --- /dev/null +++ b/fl_server_core/utils/logging.py @@ -0,0 +1,27 @@ +# SPDX-FileCopyrightText: 2024 Benedikt Franke +# SPDX-FileCopyrightText: 2024 Florian Heinrich +# +# SPDX-License-Identifier: Apache-2.0 + +from contextlib import contextmanager +import logging +from typing import Optional + + +@contextmanager +def disable_logger(logger: Optional[logging.Logger] = None): + """ + Temporary disable the Logger. + """ + previous = logger.disabled if logger else logging.root.manager.disable + if logger: + logger.disabled = True + else: + logging.disable() + try: + yield + finally: + if logger: + logger.disabled = previous # type: ignore + else: + logging.disable(previous) diff --git a/fl_server_core/utils/torch_serialization.py b/fl_server_core/utils/torch_serialization.py index fc285ce..e1b8093 100644 --- a/fl_server_core/utils/torch_serialization.py +++ b/fl_server_core/utils/torch_serialization.py @@ -41,7 +41,7 @@ def to_torch(obj: Any, supported_types: Type[T] | Tuple[Type[T], ...]): message="'torch.load' received a zip file that looks like a TorchScript archive", category=UserWarning ) - t_obj = torch.load(obj) + t_obj = torch.load(obj, weights_only=False) except Exception as e: getLogger("fl.server").error(f"Error loading torch object: {e}") raise TorchDeserializationException("Error loading torch object") from e