Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: sequential execution of middleware #160

Merged
merged 16 commits into from
Oct 20, 2023
5 changes: 5 additions & 0 deletions pro_tes/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -140,3 +140,8 @@ tes:

storeLogs:
execution_trace: True

middlewares:
mutate_destination:
- - 'pro_tes.middleware.middleware.DistanceTaskDistribution'
- 'pro_tes.middleware.middleware.RandomTaskDistribution'
10 changes: 9 additions & 1 deletion pro_tes/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ class IPDistanceCalculationError(ValueError):
"""Raised when IP distance cannot be calculated."""


class InvalidMiddleware(ValueError):
    """Raised when a middleware does not derive from `AbstractMiddleware`."""


exceptions = {
Exception: {
"message": "An unexpected error occurred.",
Expand Down Expand Up @@ -101,5 +105,9 @@ class IPDistanceCalculationError(ValueError):
IPDistanceCalculationError: {
"message": "IP distance calculation failed.",
"code": "500",
}
},
InvalidMiddleware: {
"message": "Middleware doesn't follow the abstract class",
"code": "500",
},
}
26 changes: 23 additions & 3 deletions pro_tes/ga4gh/tes/task_runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,10 @@
TesNextTes,
)
from pro_tes.ga4gh.tes.states import States
from pro_tes.middleware.middleware import TaskDistributionMiddleware
from pro_tes.middleware.middleware_pipeline import (
MiddlewarePipeline,
MiddlewareLoader
)
from pro_tes.tasks.track_task_progress import task__track_task_progress
from pro_tes.utils.db import DbDocumentConnector
from pro_tes.utils.misc import remove_auth_from_url
Expand Down Expand Up @@ -58,7 +61,6 @@ def __init__(self) -> None:
self.foca_config.db.dbs["taskStore"].collections["tasks"].client
)
self.store_logs = self.foca_config.storeLogs["execution_trace"]
self.task_distributor = TaskDistributionMiddleware()

def create_task( # pylint: disable=too-many-statements,too-many-branches
self, **kwargs
Expand All @@ -80,7 +82,8 @@ def create_task( # pylint: disable=too-many-statements,too-many-branches
db_document.task_original = TesTask(**payload)

# middleware is called after the task is created in the database
payload = self.task_distributor.modify_request(request=request).json
pipeline = self.create_middleware_pipeline()
payload = pipeline.process_request(request=request).json

tes_uri_list = deepcopy(payload["tes_uri"])
del payload["tes_uri"]
Expand Down Expand Up @@ -628,3 +631,20 @@ def parse_basic_auth(auth: Optional[Dict[str, str]]) -> BasicAuth:
username=auth.get("username"),
password=auth.get("password"),
)

@staticmethod
def create_middleware_pipeline() -> MiddlewarePipeline:
    """Build the middleware pipeline from the application configuration.

    Reads the `middlewares` section of the FOCA configuration attached to
    the current application, instantiates the middlewares listed there via
    `MiddlewareLoader` and wraps them in a `MiddlewarePipeline`.

    Returns:
        A `MiddlewarePipeline` whose `process_request` method applies the
        configured middlewares to an incoming task request.
    """
    config_section = current_app.config.foca.middlewares
    loader = MiddlewareLoader()
    loaded = loader.load_middlewares_from_config(config_section)
    logger.info(f"Loaded middlewares: {loaded}")

    middleware_pipeline = MiddlewarePipeline(loaded)
    logger.info(f"Created middleware pipeline: {middleware_pipeline}")
    return middleware_pipeline
89 changes: 79 additions & 10 deletions pro_tes/middleware/middleware.py
uniqueg marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import abc
from typing import List
import requests

from pro_tes.exceptions import (
NoTesInstancesAvailable,
Expand All @@ -18,20 +19,24 @@ class AbstractMiddleware(metaclass=abc.ABCMeta):
"""Abstract class to implement different middleware."""

@abc.abstractmethod
def modify_request(self, request):
"""Modify the incoming task request.
def set_request(
self,
request: requests.Request,
*args,
**kwargs
) -> requests.Request:
"""Set the incoming request object.

Abstract method.

Args:
request: The incoming request object.

Returns:
The modified request object.
*args: Variable length argument list.
**kwargs: Arbitrary keyword arguments.
"""


class TaskDistributionMiddleware(AbstractMiddleware):
class DistanceTaskDistribution(AbstractMiddleware):
"""Inject task distribution logic.

Attributes:
Expand All @@ -44,7 +49,7 @@ def __init__(self) -> None:
self.tes_uris: List[str] = []
self.input_uris: List[str] = []

def modify_request(self, request):
def set_request(self, request: requests.Request, *args, **kwargs):
"""Modify the incoming task request.

Abstract method
Expand All @@ -66,19 +71,83 @@ def modify_request(self, request):
request.json["inputs"][index]["url"]
)

self.tes_uris = self._set_url(self.input_uris)

if self.tes_uris:
request.json["tes_uri"] = self.tes_uris
else:
raise NoTesInstancesAvailable
return request

def _set_url(self, input_uris: List[str]) -> List[str]:
"""Set the TES URI.

Args:
input_uris: A list of input URIs from the incoming request.

Returns:
List of TES URIs.
"""
try:
self.tes_uris = distance.task_distribution(self.input_uris)
tes_uris = distance.task_distribution(input_uris)
except (
TesUriError,
InputUriError,
IPDistanceCalculationError,
KeyError,
ValueError
):
self.tes_uris = random.task_distribution()
) as exc:
raise NoTesInstancesAvailable from exc
return tes_uris


class RandomTaskDistribution(AbstractMiddleware):
    """Middleware that picks TES instances via `random.task_distribution`.

    NOTE(review): `random` here is presumably the project's task
    distribution module, not the standard library one - confirm against
    the file's imports.

    Attributes:
        tes_uris: TES URIs chosen for the most recently processed request.
        input_uris: Initialised to an empty list; not otherwise used by
            this class.
    """

    def __init__(self) -> None:
        """Start with no TES URIs and no input URIs."""
        self.tes_uris: List[str] = []
        self.input_uris: List[str] = []

    def set_request(self, request: requests.Request, *args, **kwargs):
        """Attach the selected TES URIs to the incoming task request.

        Args:
            request: Incoming request object; its JSON payload receives a
                `tes_uri` entry.
            *args: Ignored.
            **kwargs: Ignored.

        Returns:
            The same request object, with `tes_uri` set in its JSON payload.

        Raises:
            pro_tes.exceptions.NoTesInstancesAvailable: If no TES URIs could
                be determined.
        """
        self.tes_uris = self._set_url()

        # Guard clause: an empty selection means there is nowhere to send
        # the task.
        if not self.tes_uris:
            raise NoTesInstancesAvailable

        request.json["tes_uri"] = self.tes_uris
        return request

    def _set_url(self) -> List[str]:
        """Ask the random distribution logic for a list of TES URIs.

        Returns:
            List of TES URIs.

        Raises:
            pro_tes.exceptions.NoTesInstancesAvailable: If the distribution
                logic fails for any reason; the original error is chained.
        """
        try:
            return random.task_distribution()
        except Exception as exc:
            raise NoTesInstancesAvailable from exc
173 changes: 173 additions & 0 deletions pro_tes/middleware/middleware_pipeline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
"""Set up the middleware pipeline."""
import importlib
import logging
from typing import Type
import requests

from pro_tes.exceptions import InvalidMiddleware
from pro_tes.middleware.middleware import AbstractMiddleware

logger = logging.getLogger(__name__)


class MiddlewarePipeline:
    """Run incoming requests through an ordered sequence of middlewares.

    Each pipeline entry is either a single middleware object or a list of
    middleware objects. A list is treated as a fallback group: its members
    are tried in order and the first one whose `set_request` call succeeds
    is used; the remaining members are skipped.

    Attributes:
        middlewares: Ordered list of pipeline entries.
    """

    def __init__(self, middlewares=None):
        """Initialize a MiddlewarePipeline instance.

        Args:
            middlewares: Optional iterable of pipeline entries (middleware
                objects or lists of middleware objects). Defaults to an
                empty pipeline.
        """
        # Copy the input so that `add_middleware` never mutates a list that
        # is still owned by the caller.
        self.middlewares = list(middlewares) if middlewares else []
        logger.info("middleware: %s", self.middlewares)

    def add_middleware(self, middleware) -> None:
        """Append a middleware (or a fallback group) to the pipeline.

        Args:
            middleware: Middleware object, or list of middleware objects to
                be used as a fallback group.
        """
        self.middlewares.append(middleware)

    def process_request(
        self,
        request: "requests.Request",
        *args,
        **kwargs
    ) -> "requests.Request":
        """Process the incoming request through the middleware pipeline.

        Pipeline entries are applied in order, each receiving the request
        returned by the previous one. A failing single middleware is logged
        and skipped, leaving the request unchanged for the next entry. A
        fallback group must have at least one member that succeeds.

        Args:
            request: The incoming request object.
            *args: Additional positional arguments passed to each
                middleware's `set_request` method.
            **kwargs: Additional keyword arguments passed to each
                middleware's `set_request` method.

        Returns:
            The request object after processing by all pipeline entries.

        Raises:
            ValueError: If every middleware of a fallback group fails (or
                the group is empty). The error of the last failing member,
                if any, is chained as the cause.
        """
        for entry in self.middlewares:
            logger.info("trying to execute the middleware %s", entry)
            if isinstance(entry, list):
                request = self._apply_fallback_group(
                    entry, request, args, kwargs
                )
                continue
            try:
                request = entry.set_request(request, *args, **kwargs)
            except Exception as exc:  # pylint: disable=W0703
                # Deliberate best effort: a single middleware that fails is
                # reported but does not abort the pipeline.
                logger.exception("Error occurred in middleware: %s", exc)
        return request

    @staticmethod
    def _apply_fallback_group(group, request, args, kwargs):
        """Return the request from the first group member that succeeds.

        `args` and `kwargs` are taken as a plain tuple and dict so that
        caller-supplied keyword names cannot clash with this helper's own
        parameters.
        """
        last_error = None
        for candidate in group:
            try:
                return candidate.set_request(request, *args, **kwargs)
            except Exception as exc:  # pylint: disable=W0703
                logger.exception("Error occurred in middleware: %s", exc)
                # Keep a reference; `exc` is unbound when the handler exits.
                last_error = exc
        raise ValueError(
            f"List of Failed Middlewares {group} "
        ) from last_error


class MiddlewareLoader:
    """Load middleware instances from dotted import paths in the config.

    Attributes:
        middleware_list: Raw (not yet instantiated) pipeline entries
            collected by the most recent `load_middlewares_from_config`
            call.
    """

    def __init__(self):
        """Initialize a MiddlewareLoader instance."""
        self.middleware_list = []

    def load_middleware_instance(
        self,
        middleware_path: str
    ) -> "AbstractMiddleware":
        """Import a middleware class and return a new instance of it.

        Args:
            middleware_path: Middleware path in the form:
                "module.submodule.MiddlewareClass".

        Returns:
            Middleware instance, created by calling the class without
            arguments.

        Raises:
            pro_tes.exceptions.InvalidMiddleware: If the path contains no
                module part, or if the referenced object is not a subclass
                of `AbstractMiddleware`.
        """
        module_path, separator, class_name = middleware_path.rpartition('.')
        if not separator:
            # No dot at all: there is no module to import the class from.
            raise InvalidMiddleware
        module = importlib.import_module(module_path)
        middleware_class = getattr(module, class_name)
        # `issubclass` raises `TypeError` for non-class objects, so rule
        # those out first and report them as invalid middleware instead.
        if not (
            isinstance(middleware_class, type)
            and issubclass(middleware_class, AbstractMiddleware)
        ):
            raise InvalidMiddleware

        return middleware_class()

    def load_middlewares_from_config(self, config: dict) -> list:
        """Load all middlewares from config.

        Every list-valued item of `config` contributes its elements, in
        order, to the pipeline; values that are not lists are ignored. An
        element that is itself a list is loaded as a fallback group.

        Args:
            config: Middleware config.

        Returns:
            List of middleware objects in the form:
                [
                    mw1(),
                    [
                        mw2(),
                        mw2_fb1(),
                        mw2_fb2()
                    ],
                    mw3(),
                ].
        """
        configured = []
        for key in config:
            entries = config[key]
            if isinstance(entries, list):
                configured.extend(entries)

        # Rebuild instead of extending the previous state, so that calling
        # this method again on the same loader does not duplicate entries.
        self.middleware_list = configured

        return [
            [self.load_middleware_instance(path) for path in entry]
            if isinstance(entry, list)
            else self.load_middleware_instance(entry)
            for entry in configured
        ]
Loading
Loading