Skip to content

Commit

Permalink
Rewrite Choice step to make it usable from mlrun (#537)
Browse files Browse the repository at this point in the history
* Rewrite Choice step to make it usable from mlrun

[ML-7818](https://iguazio.atlassian.net/browse/ML-7818)

* Add missing space

* Hack to avoid issue with mlrun preview

* Improve docs

* Remove accidental kwargs, add type annotation
  • Loading branch information
gtopper authored Oct 15, 2024
1 parent 9bdebf4 commit c389ddc
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 62 deletions.
99 changes: 51 additions & 48 deletions storey/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,15 +252,16 @@ def _event_string(event):
def _should_terminate(self):
return self._termination_received == len(self._inlets)

async def _do_downstream(self, event):
if not self._outlets:
async def _do_downstream(self, event, outlets=None):
outlets = self._outlets if outlets is None else outlets
if not outlets:
return
if event is _termination_obj:
# Only propagate the termination object once we received one per inlet
self._outlets[0]._termination_received += 1
if self._outlets[0]._should_terminate():
self._termination_result = await self._outlets[0]._do(_termination_obj)
for outlet in self._outlets[1:] + self._get_recovery_steps():
outlets[0]._termination_received += 1
if outlets[0]._should_terminate():
self._termination_result = await outlets[0]._do(_termination_obj)
for outlet in outlets[1:] + self._get_recovery_steps():
outlet._termination_received += 1
if outlet._should_terminate():
self._termination_result = self._termination_result_fn(
Expand All @@ -269,28 +270,28 @@ async def _do_downstream(self, event):
return self._termination_result
# If there is more than one outlet, allow concurrent execution.
tasks = []
if len(self._outlets) > 1:
if len(outlets) > 1:
awaitable_result = event._awaitable_result
event._awaitable_result = None
original_events = getattr(event, "_original_events", None)
# Temporarily delete self-reference to avoid deepcopy getting stuck in an infinite loop
event._original_events = None
for i in range(1, len(self._outlets)):
for i in range(1, len(outlets)):
event_copy = copy.deepcopy(event)
event_copy._awaitable_result = awaitable_result
event_copy._original_events = original_events
tasks.append(asyncio.get_running_loop().create_task(self._outlets[i]._do_and_recover(event_copy)))
tasks.append(asyncio.get_running_loop().create_task(outlets[i]._do_and_recover(event_copy)))
# Set self-reference back after deepcopy
event._original_events = original_events
event._awaitable_result = awaitable_result
if self.verbose and self.logger:
step_name = self.name
event_string = self._event_string(event)
self.logger.debug(f"{step_name} -> {self._outlets[0].name} | {event_string}")
await self._outlets[0]._do_and_recover(event) # Optimization - avoids creating a task for the first outlet.
self.logger.debug(f"{step_name} -> {outlets[0].name} | {event_string}")
await outlets[0]._do_and_recover(event) # Optimization - avoids creating a task for the first outlet.
for i, task in enumerate(tasks, start=1):
if self.verbose and self.logger:
self.logger.debug(f"{step_name} -> {self._outlets[i].name} | {event_string}")
self.logger.debug(f"{step_name} -> {outlets[i].name} | {event_string}")
await task

def _get_event_or_body(self, event):
Expand Down Expand Up @@ -347,46 +348,48 @@ def _get_uuid(self):


class Choice(Flow):
"""Redirects each input element into at most one of multiple downstreams.
:param choice_array: a list of (downstream, condition) tuples, where downstream is a step and condition is a
function. The first condition in the list to evaluate as true for an input element causes that element to
be redirected to that downstream step.
:type choice_array: tuple of (Flow, Function (Event=>boolean))
:param default: a default step for events that did not match any condition in choice_array. If not set, elements
that don't match any condition will be discarded.
:type default: Flow
:param name: Name of this step, as it should appear in logs. Defaults to class name (Choice).
:type name: string
:param full_event: Whether user functions should receive and return Event objects (when True),
or only the payload (when False). Defaults to False.
:type full_event: boolean
"""
Redirects each input element into any number of predetermined downstream steps. Override select_outlets()
to route events to any number of downstream steps.
"""

def __init__(self, choice_array, default=None, **kwargs):
Flow.__init__(self, **kwargs)

self._choice_array = choice_array
for outlet, _ in choice_array:
self.to(outlet)

if default:
self.to(default)
self._default = default
def _init(self):
    """Build the name -> outlet lookup used to route events by outlet name.

    Runs the parent initialisation first, then registers every outlet under its name.

    :raises ValueError: if two outlets share a name, which would make routing by name ambiguous.
    """
    super()._init()
    registry = {}
    self._name_to_outlet = registry
    for outlet in self._outlets:
        if outlet.name in registry:
            raise ValueError(f"Ambiguous outlet name '{outlet.name}' in Choice step")
        registry[outlet.name] = outlet
    # TODO: hacky way of supporting mlrun preview, which replaces targets with a DFTarget.
    # The flag is set only when the sole registered outlet is named "dataframe".
    self._passthrough_for_preview = [*registry] == ["dataframe"]

def select_outlets(self, event) -> List[str]:
    """
    Return the names of the outlets that should receive the given event.

    Override this method to route events based on custom logic. The default implementation does not
    examine the event and selects every outlet name in the step's name-to-outlet mapping.

    :param event: the incoming element to route (unused by the default implementation).
    :returns: list of outlet names, in the order the outlets were registered.
    """
    return [*self._name_to_outlet]

async def _do(self, event):
if not self._outlets or event is _termination_obj:
return await super()._do_downstream(event)
chosen_outlet = None
element = self._get_event_or_body(event)
for outlet, condition in self._choice_array:
if condition(element):
chosen_outlet = outlet
break
if chosen_outlet:
await chosen_outlet._do(event)
elif self._default:
await self._default._do(event)
if event is _termination_obj:
return await self._do_downstream(_termination_obj)
else:
event_body = event if self._full_event else event.body
outlet_names = self.select_outlets(event_body)
outlets = []
if self._passthrough_for_preview:
outlet = self._name_to_outlet["dataframe"]
outlets.append(outlet)
else:
for outlet_name in outlet_names:
if outlet_name not in self._name_to_outlet:
raise ValueError(
f"select_outlets() returned outlet name '{outlet_name}', which is not one of the "
f"defined outlets: " + ", ".join(self._name_to_outlet)
)
outlet = self._name_to_outlet[outlet_name]
outlets.append(outlet)
return await self._do_downstream(event, outlets=outlets)


class Recover(Flow):
Expand Down
1 change: 1 addition & 0 deletions storey/sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,7 @@ async def _run_loop(self):
await _commit_handled_events(self._outstanding_offsets, committer, commit_all=True)
self._termination_future.set_result(termination_result)
except BaseException as ex:
traceback.print_exc()
if self.logger:
message = "An error was raised"
raised_by = getattr(ex, "_raised_by_storey_step", None)
Expand Down
44 changes: 30 additions & 14 deletions tests/test_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -1690,26 +1690,42 @@ def boom(_):


def test_choice():
small_reduce = Reduce(0, lambda acc, x: acc + x)
class MyChoice(Choice):
    """Test router: every event goes to "all_events" plus exactly one threshold-specific outlet."""

    def select_outlets(self, event):
        # Events strictly greater than 5 take the "more_than_five" branch; all others take "up_to_five".
        threshold_outlet = "more_than_five" if event > 5 else "up_to_five"
        return ["all_events", threshold_outlet]

big_reduce = build_flow([Map(lambda x: x * 100), Reduce(0, lambda acc, x: acc + x)])
source = SyncEmitSource()
my_choice = MyChoice(termination_result_fn=lambda x, y: x + y)
all_events = Map(lambda x: x, name="all_events")
more_than_five = Map(lambda x: x * 10, name="more_than_five")
up_to_five = Map(lambda x: x * 100, name="up_to_five")
sum_up_all_events = Reduce(0, lambda acc, x: acc + x)
sum_up_more_than_five = Reduce(0, lambda acc, x: acc + x)
sum_up_up_to_five = Reduce(0, lambda acc, x: acc + x)

source.to(my_choice)
my_choice.to(all_events)
my_choice.to(more_than_five)
my_choice.to(up_to_five)
all_events.to(sum_up_all_events)
more_than_five.to(sum_up_more_than_five)
up_to_five.to(sum_up_up_to_five)

controller = build_flow(
[
SyncEmitSource(),
Choice(
[(big_reduce, lambda x: x % 2 == 0)],
default=small_reduce,
termination_result_fn=lambda x, y: x + y,
),
]
).run()
controller = source.run()

for i in range(10):
for i in range(4, 8):
controller.emit(i)

controller.terminate()
termination_result = controller.await_termination()
assert termination_result == 2025

expected = sum(range(4, 8)) + sum(range(6, 8)) * 10 + sum(range(4, 6)) * 100
assert termination_result == expected


def test_metadata():
Expand Down

0 comments on commit c389ddc

Please sign in to comment.