diff --git a/storey/flow.py b/storey/flow.py index 8b616a17..7f2df3f0 100644 --- a/storey/flow.py +++ b/storey/flow.py @@ -252,15 +252,16 @@ def _event_string(event): def _should_terminate(self): return self._termination_received == len(self._inlets) - async def _do_downstream(self, event): - if not self._outlets: + async def _do_downstream(self, event, outlets=None): + outlets = self._outlets if outlets is None else outlets + if not outlets: return if event is _termination_obj: # Only propagate the termination object once we received one per inlet - self._outlets[0]._termination_received += 1 - if self._outlets[0]._should_terminate(): - self._termination_result = await self._outlets[0]._do(_termination_obj) - for outlet in self._outlets[1:] + self._get_recovery_steps(): + outlets[0]._termination_received += 1 + if outlets[0]._should_terminate(): + self._termination_result = await outlets[0]._do(_termination_obj) + for outlet in outlets[1:] + self._get_recovery_steps(): outlet._termination_received += 1 if outlet._should_terminate(): self._termination_result = self._termination_result_fn( @@ -269,28 +270,28 @@ async def _do_downstream(self, event): return self._termination_result # If there is more than one outlet, allow concurrent execution. tasks = [] - if len(self._outlets) > 1: + if len(outlets) > 1: awaitable_result = event._awaitable_result event._awaitable_result = None original_events = getattr(event, "_original_events", None) # Temporarily delete self-reference to avoid deepcopy getting stuck in an infinite loop event._original_events = None - for i in range(1, len(self._outlets)): + for i in range(1, len(outlets)): event_copy = copy.deepcopy(event) event_copy._awaitable_result = awaitable_result event_copy._original_events = original_events - tasks.append(asyncio.get_running_loop().create_task(self._outlets[i]._do_and_recover(event_copy))) + tasks.append(asyncio.get_running_loop().create_task(outlets[i]._do_and_recover(event_copy))) # Set self-reference back after deepcopy event._original_events = original_events event._awaitable_result = awaitable_result if self.verbose and self.logger: step_name = self.name event_string = self._event_string(event) - self.logger.debug(f"{step_name} -> {self._outlets[0].name} | {event_string}") - await self._outlets[0]._do_and_recover(event) # Optimization - avoids creating a task for the first outlet. + self.logger.debug(f"{step_name} -> {outlets[0].name} | {event_string}") + await outlets[0]._do_and_recover(event) # Optimization - avoids creating a task for the first outlet. for i, task in enumerate(tasks, start=1): if self.verbose and self.logger: - self.logger.debug(f"{step_name} -> {self._outlets[i].name} | {event_string}") + self.logger.debug(f"{step_name} -> {outlets[i].name} | {event_string}") await task def _get_event_or_body(self, event): @@ -347,46 +348,48 @@ def _get_uuid(self): class Choice(Flow): - """Redirects each input element into at most one of multiple downstreams. - - :param choice_array: a list of (downstream, condition) tuples, where downstream is a step and condition is a - function. The first condition in the list to evaluate as true for an input element causes that element to - be redirected to that downstream step. - :type choice_array: tuple of (Flow, Function (Event=>boolean)) - :param default: a default step for events that did not match any condition in choice_array. If not set, elements - that don't match any condition will be discarded. - :type default: Flow - :param name: Name of this step, as it should appear in logs. Defaults to class name (Choice). - :type name: string - :param full_event: Whether user functions should receive and return Event objects (when True), - or only the payload (when False). Defaults to False. - :type full_event: boolean + """ + Redirects each input element into any number of predetermined downstream steps. Override select_outlets() + to route events to any number of downstream steps. """ - def __init__(self, choice_array, default=None, **kwargs): - Flow.__init__(self, **kwargs) - - self._choice_array = choice_array - for outlet, _ in choice_array: - self.to(outlet) - - if default: - self.to(default) - self._default = default + def _init(self): + super()._init() + self._name_to_outlet = {} + for outlet in self._outlets: + if outlet.name in self._name_to_outlet: + raise ValueError(f"Ambiguous outlet name '{outlet.name}' in Choice step") + self._name_to_outlet[outlet.name] = outlet + # TODO: hacky way of supporting mlrun preview, which replaces targets with a DFTarget + self._passthrough_for_preview = list(self._name_to_outlet) == ["dataframe"] + + def select_outlets(self, event) -> List[str]: + """ + Override this method to route events based on a customer logic. The default implementation will route all + events to all outlets. + """ + return list(self._name_to_outlet.keys()) async def _do(self, event): - if not self._outlets or event is _termination_obj: - return await super()._do_downstream(event) - chosen_outlet = None - element = self._get_event_or_body(event) - for outlet, condition in self._choice_array: - if condition(element): - chosen_outlet = outlet - break - if chosen_outlet: - await chosen_outlet._do(event) - elif self._default: - await self._default._do(event) + if event is _termination_obj: + return await self._do_downstream(_termination_obj) + else: + event_body = event if self._full_event else event.body + outlet_names = self.select_outlets(event_body) + outlets = [] + if self._passthrough_for_preview: + outlet = self._name_to_outlet["dataframe"] + outlets.append(outlet) + else: + for outlet_name in outlet_names: + if outlet_name not in self._name_to_outlet: + raise ValueError( + f"select_outlets() returned outlet name '{outlet_name}', which is not one of the " + f"defined outlets: " + ", ".join(self._name_to_outlet) + ) + outlet = self._name_to_outlet[outlet_name] + outlets.append(outlet) + return await self._do_downstream(event, outlets=outlets) class Recover(Flow): diff --git a/storey/sources.py b/storey/sources.py index 0abca29a..6c10853a 100644 --- a/storey/sources.py +++ b/storey/sources.py @@ -313,6 +313,7 @@ async def _run_loop(self): await _commit_handled_events(self._outstanding_offsets, committer, commit_all=True) self._termination_future.set_result(termination_result) except BaseException as ex: + traceback.print_exc() if self.logger: message = "An error was raised" raised_by = getattr(ex, "_raised_by_storey_step", None) diff --git a/tests/test_flow.py b/tests/test_flow.py index bd484069..755597ce 100644 --- a/tests/test_flow.py +++ b/tests/test_flow.py @@ -1690,26 +1690,42 @@ def boom(_): def test_choice(): - small_reduce = Reduce(0, lambda acc, x: acc + x) + class MyChoice(Choice): + def select_outlets(self, event): + outlets = ["all_events"] + if event > 5: + outlets.append("more_than_five") + else: + outlets.append("up_to_five") + return outlets - big_reduce = build_flow([Map(lambda x: x * 100), Reduce(0, lambda acc, x: acc + x)]) + source = SyncEmitSource() + my_choice = MyChoice(termination_result_fn=lambda x, y: x + y) + all_events = Map(lambda x: x, name="all_events") + more_than_five = Map(lambda x: x * 10, name="more_than_five") + up_to_five = Map(lambda x: x * 100, name="up_to_five") + sum_up_all_events = Reduce(0, lambda acc, x: acc + x) + sum_up_more_than_five = Reduce(0, lambda acc, x: acc + x) + sum_up_up_to_five = Reduce(0, lambda acc, x: acc + x) + + source.to(my_choice) + my_choice.to(all_events) + my_choice.to(more_than_five) + my_choice.to(up_to_five) + all_events.to(sum_up_all_events) + more_than_five.to(sum_up_more_than_five) + up_to_five.to(sum_up_up_to_five) - controller = build_flow( - [ - SyncEmitSource(), - Choice( - [(big_reduce, lambda x: x % 2 == 0)], - default=small_reduce, - termination_result_fn=lambda x, y: x + y, - ), - ] - ).run() + controller = source.run() - for i in range(10): + for i in range(4, 8): controller.emit(i) + controller.terminate() termination_result = controller.await_termination() - assert termination_result == 2025 + + expected = sum(range(4, 8)) + sum(range(6, 8)) * 10 + sum(range(4, 6)) * 100 + assert termination_result == expected def test_metadata():