From a053bb5c2700ae674dd966ca37af664bc02126dc Mon Sep 17 00:00:00 2001 From: lekcyjna123 <34948061+lekcyjna123@users.noreply.github.com> Date: Sun, 3 Dec 2023 19:49:34 +0100 Subject: [PATCH] Add `condition` based MethodFilter (#504) --- test/transactions/test_transaction_lib.py | 12 ++++++--- transactron/lib/transformers.py | 33 ++++++++++++++++++----- 2 files changed, 34 insertions(+), 11 deletions(-) diff --git a/test/transactions/test_transaction_lib.py b/test/transactions/test_transaction_lib.py index 6daa6517c..e096f7860 100644 --- a/test/transactions/test_transaction_lib.py +++ b/test/transactions/test_transaction_lib.py @@ -466,23 +466,27 @@ def target_mock(self, data): def cmeth_mock(self, data): return {"data": data % 2} - def test_method_filter_with_methods(self): + @parameterized.expand([(True,), (False,)]) + def test_method_filter_with_methods(self, use_condition): self.initialize() self.cmeth = TestbenchIO(Adapter(i=self.layout, o=data_layout(1))) - self.tc = SimpleTestCircuit(MethodFilter(self.target.adapter.iface, self.cmeth.adapter.iface)) + self.tc = SimpleTestCircuit( + MethodFilter(self.target.adapter.iface, self.cmeth.adapter.iface, use_condition=use_condition) + ) m = ModuleConnector(test_circuit=self.tc, target=self.target, cmeth=self.cmeth) with self.run_simulation(m) as sim: sim.add_sync_process(self.source) sim.add_sync_process(self.target_mock) sim.add_sync_process(self.cmeth_mock) - def test_method_filter(self): + @parameterized.expand([(True,), (False,)]) + def test_method_filter(self, use_condition): self.initialize() def condition(_, v): return v[0] - self.tc = SimpleTestCircuit(MethodFilter(self.target.adapter.iface, condition)) + self.tc = SimpleTestCircuit(MethodFilter(self.target.adapter.iface, condition, use_condition=use_condition)) m = ModuleConnector(test_circuit=self.tc, target=self.target) with self.run_simulation(m) as sim: sim.add_sync_process(self.source) diff --git a/transactron/lib/transformers.py b/transactron/lib/transformers.py index 18c3ac73a..5bcf0a4d3 100644 --- a/transactron/lib/transformers.py +++ b/transactron/lib/transformers.py @@ -6,6 +6,7 @@ from collections.abc import Callable from transactron.utils import ValueLike, assign, AssignType, ModuleLike from .connectors import Forwarder, ManyToOneConnectTrans, ConnectTrans +from .simultaneous import condition __all__ = [ "Transformer", @@ -109,9 +110,10 @@ class MethodFilter(Transformer, Elaboratable): parameters, a module and the input `Record` of the method. Non-zero return value is interpreted as true. Alternatively to using a function, a `Method` can be passed as a condition. - - Caveat: because of the limitations of transaction scheduling, the target - method is locked for usage even if it is not called. + By default, the target method is locked for use even if it is not called. + If this is not the desired effect, set `use_condition` to True, but this will + cause that the provided method will be `single_caller` and all other `condition` + drawbacks will be in place (e.g. risk of exponential complexity). Attributes ---------- @@ -120,7 +122,11 @@ class MethodFilter(Transformer, Elaboratable): """ def __init__( - self, target: Method, condition: Callable[[TModule, Record], ValueLike], default: Optional[RecordDict] = None + self, + target: Method, + condition: Callable[[TModule, Record], ValueLike], + default: Optional[RecordDict] = None, + use_condition: bool = False, ): """ Parameters @@ -133,12 +139,16 @@ def __init__( default: Value or dict, optional The default value returned from the filtered method when the condition is false. If omitted, zero is returned. + use_condition : bool + Instead of `m.If` use simultaneus `condition` which allow to execute + this filter if the condition is False and target is not ready. """ if default is None: default = Record.like(target.data_out) self.target = target - self.method = Method.like(target) + self.use_condition = use_condition + self.method = Method(i=target.data_in.layout, o=target.data_out.layout, single_caller=self.use_condition) self.condition = condition self.default = default @@ -150,8 +160,17 @@ def elaborate(self, platform): @def_method(m, self.method) def _(arg): - with m.If(self.condition(m, arg)): - m.d.comb += ret.eq(self.target(m, arg)) + if self.use_condition: + cond = Signal() + m.d.top_comb += cond.eq(self.condition(m, arg)) + with condition(m, nonblocking=False, priority=False) as branch: + with branch(cond): + m.d.comb += ret.eq(self.target(m, arg)) + with branch(~cond): + pass + else: + with m.If(self.condition(m, arg)): + m.d.comb += ret.eq(self.target(m, arg)) return ret return m