Skip to content

Commit

Permalink
test(runner): add handle_pipeline_exception test
Browse files Browse the repository at this point in the history
This commit adds a test for the 'handle_pipeline_exception' route
utility function. It also fixes some errors into that function.
  • Loading branch information
rickstaa committed Oct 14, 2024
1 parent 870c2f0 commit 6702956
Show file tree
Hide file tree
Showing 5 changed files with 118 additions and 20 deletions.
2 changes: 2 additions & 0 deletions runner/app/pipelines/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@

from app.pipelines.utils.utils import (
LoraLoader,
LoraLoadingError,
SafetyChecker,
get_model_dir,
get_model_path,
get_torch_device,
is_lightning_model,
is_turbo_model,
is_numeric,
split_prompt,
validate_torch_device,
)
37 changes: 17 additions & 20 deletions runner/app/routes/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,45 +192,42 @@ def handle_pipeline_exception(
error message and status code.
Args:
e (object): The exception to handle. Can be any type of object.
default_error_message (Union[str, Dict[str, Any]]): The default error message
or content dictionary. Default will be used if no specific error type is
matched.
default_status_code (int): The default status code to use if no specific error
type is matched. Defaults to HTTP_500_INTERNAL_SERVER_ERROR.
custom_error_config (Dict[str, Tuple[str, int]]): Custom error configuration
to override the application error configuration.
e(int): The exception to handle. Can be any type of object.
default_error_message: The default error message or content dictionary. Default
will be used if no specific error type ismatched.
default_status_code: The default status code to use if no specific error type is
matched. Defaults to HTTP_500_INTERNAL_SERVER_ERROR.
custom_error_config: Custom error configuration to override the application
error configuration.
Returns:
JSONResponse: The JSON response with appropriate status code and error message.
The JSON response with appropriate status code and error message.
"""
error_config = ERROR_CONFIG.copy()

# Update error_config with custom_error_config if provided.
if custom_error_config:
error_config.update(custom_error_config)

error_message = default_error_message
status_code = default_status_code
error_message = default_error_message

error_type = type(e).__name__
if error_type in error_config:
message, status_code = error_config[error_type]
error_message = str(e) if message is None else message
error_message, status_code = error_config[error_type]
else:
for error_pattern, (message, code) in error_config.items():
if error_pattern.lower() in str(e).lower():
error_message = str(e) if message is None else message
status_code = code
error_message = message
break

if error_message == "":
if error_message is None:
error_message = f"{e}."
elif error_message == "":
error_message = default_error_message

if isinstance(error_message, str):
content = http_error(error_message)
else:
content = error_message
content = (
http_error(error_message) if isinstance(error_message, str) else error_message
)

return JSONResponse(
status_code=status_code,
Expand Down
Empty file added runner/tests/__init__.py
Empty file.
Empty file added runner/tests/routes/__init__.py
Empty file.
99 changes: 99 additions & 0 deletions runner/tests/routes/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
import pytest
from app.routes.utils import handle_pipeline_exception
from app.pipelines.utils import LoraLoadingError
import torch
from fastapi import status
from fastapi.responses import JSONResponse
import json


class TestHandlePipelineException:
"""Tests for the handle_pipeline_exception function."""

@staticmethod
def parse_response(response: JSONResponse):
"""Parses the JSON response body from a FastAPI JSONResponse object."""
return json.loads(response.body)

@pytest.mark.parametrize(
"exception, expected_status, expected_message, description",
[
(
Exception("Unknown error"),
status.HTTP_500_INTERNAL_SERVER_ERROR,
"Pipeline error.",
"Returns default message and status code for unknown error.",
),
(
torch.cuda.OutOfMemoryError("Some Message"),
status.HTTP_500_INTERNAL_SERVER_ERROR,
"GPU out of memory.",
"Returns global message and status code for type match.",
),
(
Exception("CUDA out of memory"),
status.HTTP_500_INTERNAL_SERVER_ERROR,
"Out of memory.",
"Returns global message and status code for pattern match.",
),
(
LoraLoadingError("A custom error message"),
status.HTTP_400_BAD_REQUEST,
"A custom error message.",
"Forwards exception message if configured with None.",
),
(
ValueError("A custom error message"),
status.HTTP_400_BAD_REQUEST,
"Pipeline error.",
"Returns default message if configured with ''.",
),
],
)
def test_handle_pipeline_exception(
self, exception, expected_status, expected_message, description
):
"""Test that the handle_pipeline_exception function returns the correct status
code and error message for different types of exceptions.
"""
response = handle_pipeline_exception(exception)
response_body = self.parse_response(response)
assert response.status_code == expected_status, f"Failed: {description}"
assert (
response_body["detail"]["msg"] == expected_message
), f"Failed: {description}"

def test_handle_pipeline_exception_custom_default_message(self):
"""Test that a custom default error message is used when provided."""
exception = ValueError("Some value error")
response = handle_pipeline_exception(
exception, default_error_message="A custom error message."
)
response_body = self.parse_response(response)
assert response.status_code == status.HTTP_400_BAD_REQUEST
assert response_body["detail"]["msg"] == "A custom error message."

def test_handle_pipeline_exception_custom_status_code(self):
"""Test that a custom default status code is used when provided."""
exception = Exception("Some value error")
response = handle_pipeline_exception(
exception, default_status_code=status.HTTP_404_NOT_FOUND
)
response_body = self.parse_response(response)
assert response.status_code == status.HTTP_404_NOT_FOUND
assert response_body["detail"]["msg"] == "Pipeline error."

def test_handle_pipeline_exception_custom_error_config(self):
"""Test that custom error configuration overrides the global error
configuration, which prints the exception message.
"""
exception = LoraLoadingError("Some error message.")
response = handle_pipeline_exception(
exception,
custom_error_config={
"LoraLoadingError": ("Custom message.", status.HTTP_400_BAD_REQUEST)
},
)
response_body = self.parse_response(response)
assert response.status_code == status.HTTP_400_BAD_REQUEST
assert response_body["detail"]["msg"] == "Custom message."

0 comments on commit 6702956

Please sign in to comment.