From c443ebb60521d87d4553efc1c5e86fda8707d69d Mon Sep 17 00:00:00 2001 From: Marek Materzok Date: Wed, 18 Dec 2024 13:17:10 +0100 Subject: [PATCH] Simplify method body argument handling --- test/core/test_transactions.py | 2 +- test/lib/test_connectors.py | 4 +-- test/lib/test_reqres.py | 4 +-- test/lib/test_simultaneous.py | 2 +- test/lib/test_transformers.py | 10 +++--- test/test_methods.py | 6 ++-- test/test_simultaneous.py | 2 +- test/testing/test_validate_arguments.py | 2 +- transactron/core/body.py | 35 ++++++++++++------- transactron/core/method.py | 24 +++---------- transactron/core/sugar.py | 46 +++++-------------------- transactron/core/transaction.py | 4 --- transactron/lib/adapters.py | 26 +++++++++----- 13 files changed, 69 insertions(+), 98 deletions(-) diff --git a/test/core/test_transactions.py b/test/core/test_transactions.py index e072c1e..1b4de28 100644 --- a/test/core/test_transactions.py +++ b/test/core/test_transactions.py @@ -115,7 +115,7 @@ def __init__(self, scheduler): def elaborate(self, platform): m = TModule() tm = TransactionModule(m, DependencyContext.get(), TransactionManager(self.scheduler)) - adapter = Adapter(i=data_layout(32), o=data_layout(32)) + adapter = Adapter.create(i=data_layout(32), o=data_layout(32)) m.submodules.out = self.out = TestbenchIO(adapter) m.submodules.in1 = self.in1 = TestbenchIO(AdapterTrans(adapter.iface)) m.submodules.in2 = self.in2 = TestbenchIO(AdapterTrans(adapter.iface)) diff --git a/test/lib/test_connectors.py b/test/lib/test_connectors.py index ee9196a..e7a813b 100644 --- a/test/lib/test_connectors.py +++ b/test/lib/test_connectors.py @@ -134,13 +134,13 @@ def elaborate(self, platform): get_results = [] for i in range(self.count): - input = TestbenchIO(Adapter(o=self.lay)) + input = TestbenchIO(Adapter.create(o=self.lay)) get_results.append(input.adapter.iface) m.submodules[f"input_{i}"] = input self.inputs.append(input) # Create ManyToOneConnectTrans, which will serialize results from different inputs - output = TestbenchIO(Adapter(i=self.lay)) + output = TestbenchIO(Adapter.create(i=self.lay)) m.submodules.output = self.output = output m.submodules.fu_arbitration = ManyToOneConnectTrans(get_results=get_results, put_result=output.adapter.iface) diff --git a/test/lib/test_reqres.py b/test/lib/test_reqres.py index 6aea07f..ce73ce5 100644 --- a/test/lib/test_reqres.py +++ b/test/lib/test_reqres.py @@ -27,8 +27,8 @@ def setup_method(self): layout = [("field", self.data_width)] - self.req_method = TestbenchIO(Adapter(i=layout)) - self.resp_method = TestbenchIO(Adapter(o=layout)) + self.req_method = TestbenchIO(Adapter.create(i=layout)) + self.resp_method = TestbenchIO(Adapter.create(o=layout)) self.test_circuit = SimpleTestCircuit( Serializer( diff --git a/test/lib/test_simultaneous.py b/test/lib/test_simultaneous.py index cad34e5..80199cd 100644 --- a/test/lib/test_simultaneous.py +++ b/test/lib/test_simultaneous.py @@ -49,7 +49,7 @@ class TestCondition(TestCaseWithSimulator): @pytest.mark.parametrize("priority", [False, True]) @pytest.mark.parametrize("catchall", [False, True]) def test_condition(self, nonblocking: bool, priority: bool, catchall: bool): - target = TestbenchIO(Adapter(i=[("cond", 2)])) + target = TestbenchIO(Adapter.create(i=[("cond", 2)])) circ = SimpleTestCircuit( ConditionTestCircuit(target.adapter.iface, nonblocking=nonblocking, priority=priority, catchall=catchall), diff --git a/test/lib/test_transformers.py b/test/lib/test_transformers.py index 4de85b4..3e4a647 100644 --- a/test/lib/test_transformers.py +++ b/test/lib/test_transformers.py @@ -51,7 +51,7 @@ def otransform_dict(_, v: MethodStruct) -> RecordDict: itransform = itransform_rec otransform = otransform_rec - m.submodules.target = self.target = TestbenchIO(Adapter(i=layout, o=layout)) + m.submodules.target = self.target = TestbenchIO(Adapter.create(i=layout, o=layout)) if self.use_methods: imeth = Method(i=layout, o=layout) @@ -111,8 +111,8 @@ class TestMethodFilter(TestCaseWithSimulator): def initialize(self): self.iosize = 4 self.layout = data_layout(self.iosize) - self.target = TestbenchIO(Adapter(i=self.layout, o=self.layout)) - self.cmeth = TestbenchIO(Adapter(i=self.layout, o=data_layout(1))) + self.target = TestbenchIO(Adapter.create(i=self.layout, o=self.layout)) + self.cmeth = TestbenchIO(Adapter.create(i=self.layout, o=data_layout(1))) async def source(self, sim: TestbenchContext): for i in range(2**self.iosize): @@ -165,7 +165,7 @@ def elaborate(self, platform): methods = [] for k in range(self.targets): - tgt = TestbenchIO(Adapter(i=layout, o=layout)) + tgt = TestbenchIO(Adapter.create(i=layout, o=layout)) methods.append(tgt.adapter.iface) self.target.append(tgt) m.submodules += tgt @@ -280,7 +280,7 @@ def elaborate(self, platform): methods = [] for k in range(self.targets): - tgt = TestbenchIO(Adapter(i=layout, o=layout)) + tgt = TestbenchIO(Adapter.create(i=layout, o=layout)) methods.append(tgt.adapter.iface) self.target.append(tgt) m.submodules += tgt diff --git a/test/test_methods.py b/test/test_methods.py index 9371d0e..95c7e58 100644 --- a/test/test_methods.py +++ b/test/test_methods.py @@ -407,7 +407,7 @@ def elaborate(self, platform): meth = Method(i=data_layout(1)) m.submodules.tb = self.tb = TestbenchIO(AdapterTrans(meth)) - m.submodules.out = self.out = TestbenchIO(Adapter()) + m.submodules.out = self.out = TestbenchIO(Adapter.create()) @def_method(m, meth) def _(arg): @@ -456,7 +456,7 @@ def elaborate(self, platform): m = TModule() self.ready = Signal() - m.submodules.tb = self.tb = TestbenchIO(Adapter()) + m.submodules.tb = self.tb = TestbenchIO(Adapter.create()) with Transaction().body(m, request=self.ready): self.tb.adapter.iface(m) @@ -469,7 +469,7 @@ def elaborate(self, platform): m = TModule() self.ready = Signal() - m.submodules.tb = self.tb = TestbenchIO(Adapter()) + m.submodules.tb = self.tb = TestbenchIO(Adapter.create()) with m.If(self.ready): with Transaction().body(m): diff --git a/test/test_simultaneous.py b/test/test_simultaneous.py index ad492e3..d9f281a 100644 --- a/test/test_simultaneous.py +++ b/test/test_simultaneous.py @@ -140,7 +140,7 @@ def elaborate(self, platform): class TestTransitivity(TestCaseWithSimulator): def test_transitivity(self): - target = TestbenchIO(Adapter(i=[("data", 2)])) + target = TestbenchIO(Adapter.create(i=[("data", 2)])) req1 = Signal() req2 = Signal() diff --git a/test/testing/test_validate_arguments.py b/test/testing/test_validate_arguments.py index 7e70369..9f2f5a9 100644 --- a/test/testing/test_validate_arguments.py +++ b/test/testing/test_validate_arguments.py @@ -14,7 +14,7 @@ class ValidateArgumentsTestCircuit(Elaboratable): def elaborate(self, platform): m = Module() - self.method = TestbenchIO(Adapter(i=data_layout(1), o=data_layout(1)).set(with_validate_arguments=True)) + self.method = TestbenchIO(Adapter.create(i=data_layout(1), o=data_layout(1)).set(with_validate_arguments=True)) self.caller1 = TestbenchIO(AdapterTrans(self.method.adapter.iface)) self.caller2 = TestbenchIO(AdapterTrans(self.method.adapter.iface)) diff --git a/transactron/core/body.py b/transactron/core/body.py index e617792..5d7db0e 100644 --- a/transactron/core/body.py +++ b/transactron/core/body.py @@ -9,14 +9,24 @@ from transactron.utils import * from amaranth import * -from typing import TYPE_CHECKING, ClassVar, NewType, Optional, Callable, final +from typing import TYPE_CHECKING, ClassVar, NewType, NotRequired, Optional, Callable, TypedDict, Unpack, final from transactron.utils.assign import AssignArg if TYPE_CHECKING: from .method import Method -__all__ = ["Body", "TBody", "MBody"] +__all__ = ["AdapterBodyParams", "BodyParams", "Body", "TBody", "MBody"] + + +class AdapterBodyParams(TypedDict): + combiner: NotRequired[Callable[[Module, Sequence[MethodStruct], Value], AssignArg]] + nonexclusive: NotRequired[bool] + single_caller: NotRequired[bool] + + +class BodyParams(AdapterBodyParams): + validate_arguments: NotRequired[Callable[..., ValueLike]] @final @@ -35,11 +45,8 @@ def __init__( owner: Optional[Elaboratable], i: StructLayout, o: StructLayout, - combiner: Optional[Callable[[Module, Sequence[MethodStruct], Value], AssignArg]], - validate_arguments: Optional[Callable[..., ValueLike]], - nonexclusive: bool, - single_caller: bool, src_loc: SrcLoc, + **kwargs: Unpack[BodyParams], ): super().__init__(src_loc=src_loc) @@ -57,15 +64,19 @@ def default_combiner(m: Module, args: Sequence[MethodStruct], runs: Value) -> As self.run = Signal(name=self.owned_name + "_run") self.data_in: MethodStruct = Signal(from_method_layout(i), name=self.owned_name + "_data_in") self.data_out: MethodStruct = Signal(from_method_layout(o), name=self.owned_name + "_data_out") - self.combiner: Callable[[Module, Sequence[MethodStruct], Value], AssignArg] = combiner or default_combiner - self.nonexclusive = nonexclusive - self.single_caller = single_caller - self.validate_arguments: Optional[Callable[..., ValueLike]] = validate_arguments + self.combiner: Callable[[Module, Sequence[MethodStruct], Value], AssignArg] = ( + kwargs["combiner"] if "combiner" in kwargs else default_combiner + ) + self.nonexclusive = kwargs["nonexclusive"] if "nonexclusive" in kwargs else False + self.single_caller = kwargs["single_caller"] if "single_caller" in kwargs else False + self.validate_arguments: Optional[Callable[..., ValueLike]] = ( + kwargs["validate_arguments"] if "validate_arguments" in kwargs else None + ) self.method_uses = {} self.method_calls = defaultdict(list) - if nonexclusive: - assert len(self.data_in.as_value()) == 0 or combiner is not None + if self.nonexclusive: + assert len(self.data_in.as_value()) == 0 or self.combiner is not None def _validate_arguments(self, arg_rec: MethodStruct) -> ValueLike: if self.validate_arguments is not None: diff --git a/transactron/core/method.py b/transactron/core/method.py index 013702c..dce26a8 100644 --- a/transactron/core/method.py +++ b/transactron/core/method.py @@ -3,13 +3,13 @@ from transactron.utils import * from amaranth import * from amaranth import tracer -from typing import TYPE_CHECKING, Optional, Callable, Iterator +from typing import TYPE_CHECKING, Optional, Iterator, Unpack from .transaction_base import * from contextlib import contextmanager from transactron.utils.assign import AssignArg from transactron.utils._typing import type_self_add_1pos_kwargs_as -from .body import Body, MBody +from .body import Body, BodyParams, MBody from .keys import TransactionManagerKey from .tmodule import TModule from .transaction_base import TransactionBase @@ -154,15 +154,7 @@ def proxy(self, m: TModule, method: "Method"): @contextmanager def body( - self, - m: TModule, - *, - ready: ValueLike = C(1), - out: ValueLike = C(0, 0), - validate_arguments: Optional[Callable[..., ValueLike]] = None, - combiner: Optional[Callable[[Module, Sequence[MethodStruct], Value], AssignArg]] = None, - nonexclusive: bool = False, - single_caller: bool = False, + self, m: TModule, *, ready: ValueLike = C(1), out: ValueLike = C(0, 0), **kwargs: Unpack[BodyParams] ) -> Iterator[MethodStruct]: """Define method body @@ -226,15 +218,7 @@ def body( m.d.comb += sum.eq(data_in.arg1 + data_in.arg2) """ body = Body( - name=self.name, - owner=self.owner, - i=self.layout_in, - o=self.layout_out, - combiner=combiner, - validate_arguments=validate_arguments, - nonexclusive=nonexclusive, - single_caller=single_caller, - src_loc=self.src_loc, + name=self.name, owner=self.owner, i=self.layout_in, o=self.layout_out, src_loc=self.src_loc, **kwargs ) self._set_impl(m, body) diff --git a/transactron/core/sugar.py b/transactron/core/sugar.py index 22a4773..4b1e599 100644 --- a/transactron/core/sugar.py +++ b/transactron/core/sugar.py @@ -1,6 +1,7 @@ from collections.abc import Sequence, Callable from amaranth import * -from typing import Optional, Concatenate, ParamSpec +from typing import Optional, Concatenate, ParamSpec, Unpack +from transactron.core.body import BodyParams from transactron.utils import * from transactron.utils.assign import AssignArg from functools import partial @@ -13,15 +14,7 @@ P = ParamSpec("P") -def def_method( - m: TModule, - method: Method, - ready: ValueLike = C(1), - combiner: Optional[Callable[[Module, Sequence[MethodStruct], Value], AssignArg]] = None, - validate_arguments: Optional[Callable[..., ValueLike]] = None, - nonexclusive: bool = False, - single_caller: bool = False, -): +def def_method(m: TModule, method: Method, ready: ValueLike = C(1), **kwargs: Unpack[BodyParams]): """Define a method. This decorator allows to define transactional methods in an @@ -46,15 +39,7 @@ def def_method( Signal to indicate if the method is ready to be run. By default it is `Const(1)`, so the method is always ready. Assigned combinationally to the `ready` attribute. - validate_arguments: Optional[Callable[..., ValueLike]] - For details, see `Method.body`. - validate_arguments: Optional[Callable[..., ValueLike]] - For details, see `Method.body`. - combiner: (Module, Sequence[MethodStruct], Value) -> AssignArg - For details, see `Method.body`. - nonexclusive: bool - For details, see `Method.body`. - single_caller: bool + **kwargs: BodyParams For details, see `Method.body`. Examples @@ -91,15 +76,7 @@ def decorator(func: Callable[..., Optional[AssignArg]]): out = Signal(method.layout_out) ret_out = None - with method.body( - m, - ready=ready, - out=out, - combiner=combiner, - validate_arguments=validate_arguments, - nonexclusive=nonexclusive, - single_caller=single_caller, - ) as arg: + with method.body(m, ready=ready, out=out, **kwargs) as arg: ret_out = method_def_helper(method, func, arg) if ret_out is not None: @@ -112,7 +89,7 @@ def def_methods( m: TModule, methods: Sequence[Method], ready: Callable[[int], ValueLike] = lambda _: C(1), - validate_arguments: Optional[Callable[..., ValueLike]] = None, + **kwargs: Unpack[BodyParams], ): """Decorator for defining similar methods @@ -143,12 +120,8 @@ def _(arg): A `Callable` that takes the index in the form of an `int` of the currently defined method and produces a `Value` describing whether the method is ready to be run. When omitted, each defined method is always ready. Assigned combinationally to the `ready` attribute. - validate_arguments: Optional[Callable[Concatenate[int, ...], ValueLike]] - Function that takes input arguments used to call the method - and checks whether the method can be called with those arguments. - It instantiates a combinational circuit for each - method caller. By default, there is no function, so all arguments - are accepted. + **kwargs: BodyParams + For details, see `Method.body`. Examples -------- @@ -187,7 +160,6 @@ def _(_): def decorator(func: Callable[Concatenate[int, P], Optional[RecordDict]]): for i in range(len(methods)): partial_f = partial(func, i) - partial_vargs = partial(validate_arguments, i) if validate_arguments is not None else None - def_method(m, methods[i], ready(i), partial_vargs)(partial_f) + def_method(m, methods[i], ready(i), **kwargs)(partial_f) return decorator diff --git a/transactron/core/transaction.py b/transactron/core/transaction.py index 03e7712..2667029 100644 --- a/transactron/core/transaction.py +++ b/transactron/core/transaction.py @@ -120,10 +120,6 @@ def body(self, m: TModule, *, request: ValueLike = C(1)) -> Iterator["Transactio owner=self.owner, i=StructLayout({}), o=StructLayout({}), - combiner=None, - validate_arguments=None, - nonexclusive=False, - single_caller=False, src_loc=self.src_loc, ) self._set_impl(m, impl) diff --git a/transactron/lib/adapters.py b/transactron/lib/adapters.py index 81816b3..b06d99d 100644 --- a/transactron/lib/adapters.py +++ b/transactron/lib/adapters.py @@ -1,12 +1,12 @@ from abc import abstractmethod -from typing import Optional +from typing import Optional, Unpack from amaranth import * from amaranth.lib.wiring import Component, In, Out from amaranth.lib.data import StructLayout, View from ..utils import SrcLoc, get_src_loc, MethodStruct from ..core import * -from ..utils._typing import type_self_kwargs_as, SignalBundle +from ..utils._typing import SignalBundle, MethodLayout __all__ = [ "AdapterBase", @@ -100,8 +100,7 @@ class Adapter(AdapterBase): Hooks for `validate_arguments`. """ - @type_self_kwargs_as(Method.__init__) - def __init__(self, **kwargs): + def __init__(self, method: Method, /, **kwargs: Unpack[AdapterBodyParams]): """ Parameters ---------- @@ -110,12 +109,21 @@ def __init__(self, **kwargs): See transactron.core.Method.__init__ for parameters description. """ - kwargs["src_loc"] = get_src_loc(kwargs.setdefault("src_loc", 0)) - - iface = Method(**kwargs) - super().__init__(iface, iface.layout_out, iface.layout_in) + super().__init__(method, method.layout_out, method.layout_in) self.validators: list[tuple[View[StructLayout], Signal]] = [] self.with_validate_arguments: bool = False + self.kwargs = kwargs + + @staticmethod + def create( + name: Optional[str] = None, + i: MethodLayout = [], + o: MethodLayout = [], + src_loc: int | SrcLoc = 0, + **kwargs: Unpack[AdapterBodyParams], + ): + method = Method(name=name, i=i, o=o, src_loc=get_src_loc(src_loc)) + return Adapter(method, **kwargs) def set(self, with_validate_arguments: Optional[bool]): if with_validate_arguments is not None: @@ -129,7 +137,7 @@ def elaborate(self, platform): data_in = Signal.like(self.data_in) m.d.comb += data_in.eq(self.data_in) - kwargs = {} + kwargs: BodyParams = self.kwargs # type: ignore (pyright complains about optional attribute) if self.with_validate_arguments: