Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add window manager within middleware #437

Open
wants to merge 4 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 46 additions & 4 deletions vumi/middleware/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# -*- test-case-name: vumi.middleware.tests.test_base -*-

from twisted.internet.defer import inlineCallbacks, returnValue
from twisted.internet.defer import inlineCallbacks, returnValue, maybeDeferred

from vumi.utils import load_class_by_string
from vumi.errors import ConfigError, VumiError
Expand All @@ -10,6 +10,14 @@ class MiddlewareError(VumiError):
pass


class MiddlewareControlFlag():
pass


class StopPropagation(MiddlewareControlFlag):
pass


class BaseMiddleware(object):
"""Common middleware base class.

Expand Down Expand Up @@ -95,6 +103,10 @@ def handle_failure(self, failure, endpoint):
"""
return failure

def resume_handling(self, handle_name, message, endpoint):
return self.worker._middlewares.resume_handling(
self, handle_name, message, endpoint)


class TransportMiddleware(BaseMiddleware):
"""Message processor middleware for Transports.
Expand All @@ -113,17 +125,47 @@ class MiddlewareStack(object):
def __init__(self, middlewares):
self.middlewares = middlewares

def _get_middleware_index(self, middleware):
return self.middlewares.index(middleware)

@inlineCallbacks
def resume_handling(self, mw, handle_name, message, endpoint):
mw_index = self._get_middleware_index(mw)
#In case there are no other middleware after this one
if mw_index + 1 == len(self.middlewares):
returnValue(message)
message = yield self._handle(self.middlewares, handle_name, message, endpoint, mw_index + 1)
returnValue(message)

@inlineCallbacks
def _handle(self, middlewares, handler_name, message, endpoint):
def _handle(self, middlewares, handler_name, message, endpoint, from_index=0):
method_name = 'handle_%s' % (handler_name,)
for middleware in middlewares:
for index, middleware in enumerate(middlewares[from_index:]):
handler = getattr(middleware, method_name)
message = yield handler(message, endpoint)
message = yield self._handle_middleware(handler, message, endpoint, index)
if message is None:
raise MiddlewareError('Returned value of %s.%s should never ' \
'be None' % (middleware, method_name,))
returnValue(message)

def _handle_middleware(self, handler, message, endpoint, index):
def _handle_control_flag(f):
if not isinstance(f.value, MiddlewareControlFlag):
raise f
if isinstance(f.value, StopPropagation):
raise f
raise MiddlewareError('Unknown Middleware Control Flag: %s'
% (f.value,))

d = maybeDeferred(handler, message, endpoint)
d.addErrback(_handle_control_flag)
return d

def process_control_flag(self, f):
f.trap(StopPropagation)
if isinstance(f.value, StopPropagation):
return None

def apply_consume(self, handler_name, message, endpoint):
return self._handle(
self.middlewares, handler_name, message, endpoint)
Expand Down
75 changes: 75 additions & 0 deletions vumi/middleware/tests/test_window_manager_middleware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
from twisted.internet.defer import inlineCallbacks, returnValue

from vumi.middleware.window_manager_middleware import WindowManagerMiddleware
from vumi.persist.fake_redis import FakeRedis
from vumi.message import TransportEvent, TransportUserMessage
from vumi.tests.utils import PersistenceMixin, VumiWorkerTestCase
from vumi.middleware.tests.utils import RecordingMiddleware
from vumi.middleware.base import StopPropagation, MiddlewareStack


class ToyWorker(object):

transport_name = 'transport'
messages = []

def handle_outbound_message(self, msg):
self.messages.append(msg)


class WindowManagerTestCase(VumiWorkerTestCase, PersistenceMixin):

@inlineCallbacks
def setUp(self):
self._persist_setUp()
toy_worker = ToyWorker()
self.transport_name = toy_worker.transport_name
config = self.mk_config({
'window_size': 2,
'flight_lifetime': 1,
'monitor_loop': 0.5})
self.mw = WindowManagerMiddleware('mw1', config, toy_worker)
mw_recording = RecordingMiddleware('mw2', {}, toy_worker)
yield self.mw.setup_middleware()
toy_worker._middlewares = MiddlewareStack([self.mw, mw_recording])

@inlineCallbacks
def tearDown(self):
self.mw.teardown_middleware()
yield self._persist_tearDown()

@inlineCallbacks
def test_handle_outbound(self):
msg_1 = self.mkmsg_out(message_id='1')
yield self.assertFailure(
self.mw.handle_outbound(msg_1, self.transport_name),
StopPropagation)

msg_2 = self.mkmsg_out(message_id='2')
yield self.assertFailure(
self.mw.handle_outbound(msg_2, self.transport_name),
StopPropagation)

msg_3 = self.mkmsg_out(message_id='3')
yield self.assertFailure(
self.mw.handle_outbound(msg_3, self.transport_name),
StopPropagation)

count_waiting = yield self.mw.wm.count_waiting(self.transport_name)
self.assertEqual(3, count_waiting)

yield self.mw.wm._monitor_windows(self.mw.send_outbound, False)
self.assertEqual(1, (yield self.mw.wm.count_waiting(self.transport_name)))
self.assertEqual(2, (yield self.mw.wm.count_in_flight(self.transport_name)))
self.assertEqual(2, len(self.mw.worker.messages))
msg_1 = self.mw.worker.messages[0]
self.assertEqual(msg_1['record'],
[('mw2', 'outbound', self.transport_name)])

#acknowledge one of the messages
ack = self.mkmsg_ack(user_message_id="1")
yield self.mw.handle_event(ack, self.transport_name)
self.assertEqual(1, (yield self.mw.wm.count_in_flight(self.transport_name)))

yield self.mw.wm._monitor_windows(self.mw.send_outbound)
self.assertEqual(2, (yield self.mw.wm.count_in_flight(self.transport_name)))
57 changes: 57 additions & 0 deletions vumi/middleware/window_manager_middleware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
from twisted.internet.defer import inlineCallbacks, returnValue
from twisted.internet import reactor
from twisted.internet.task import LoopingCall

from vumi.middleware import BaseMiddleware
from vumi.message import TransportUserMessage
from vumi.persist.txredis_manager import TxRedisManager

from vumi.components.window_manager import WindowManager
from vumi.middleware.base import StopPropagation

class WindowManagerMiddleware(BaseMiddleware):

@inlineCallbacks
def setup_middleware(self):
store_prefix = self.config.get('store_prefix', 'message_store')
r_config = self.config.get('redis_manager', {})
redis = yield TxRedisManager.from_config(r_config)

self.transport_name = self.worker.transport_name

self.wm = WindowManager(
redis,
window_size=self.config.get('window_size', 10),
flight_lifetime=self.config.get('flight_lifetime', 1))

self.wm.monitor(
self.send_outbound,
self.config.get('monitor_loop', 1),
False)

if not (yield self.wm.window_exists(self.transport_name)):
yield self.wm.create_window(self.transport_name)

def teardown_middleware(self):
self.wm.stop()

@inlineCallbacks
def handle_event(self, event, endpoint):
if event["event_type"] in ['ack', 'nack']:
yield self.wm.remove_key(
self.transport_name,
event['user_message_id'])
returnValue(event)

@inlineCallbacks
def handle_outbound(self, msg, endpoint):
yield self.wm.add(self.transport_name, msg.to_json(), msg["message_id"])
raise StopPropagation()

@inlineCallbacks
def send_outbound(self, window_id, key):
data = yield self.wm.get_data(window_id, key)
msg = TransportUserMessage.from_json(data)
# TODO store the endpoint in the stored data
self.resume_handling('outbound', msg, self.transport_name)
self.worker.handle_outbound_message(msg)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this approach is wrong. I think we should just add a helper method to the middleware base class that does this by sending the message through the rest of the middleware stack (it'll need support from the middleware stack object probably).

1 change: 1 addition & 0 deletions vumi/transports/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,7 @@ def _send_failure(f):
d = self._middlewares.apply_consume("outbound", message,
self.transport_name)
d.addCallback(self.handle_outbound_message)
d.addErrback(self._middlewares.process_control_flag)
d.addErrback(_send_failure)
return d

Expand Down
24 changes: 24 additions & 0 deletions vumi/transports/tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,14 @@
from vumi.transports.tests.utils import TransportTestCase
from vumi.transports.base import Transport

from vumi.middleware.base import BaseMiddleware, StopPropagation


class StopPropagationMiddleware(BaseMiddleware):

def handle_outbound(self, message, endpoint):
raise StopPropagation()


class BaseTransportTestCase(TransportTestCase):
"""
Expand Down Expand Up @@ -84,3 +92,19 @@ def test_middleware_for_outbound_messages(self):
('mw1', 'outbound', self.transport_name),
('mw2', 'outbound', self.transport_name),
])

@inlineCallbacks
def test_middleware_for_outbound_messages_control_flag_stop_propagation(self):
transport = yield self.get_transport({
"middleware": [
{"mw1": "vumi.middleware.tests.utils.RecordingMiddleware"},
{"mw2": "vumi.transports.tests.test_base.StopPropagationMiddleware"},
{"mw3": "vumi.middleware.tests.utils.RecordingMiddleware"},
],
})
msgs = []
transport.handle_outbound_message = msgs.append
orig_msg = self.mkmsg_out()
orig_msg['timestamp'] = 0
yield self.dispatch(orig_msg)
self.assertEqual(0, len(msgs))