Skip to content

Commit

Permalink
Simplify method body argument handling
Browse files Browse the repository at this point in the history
  • Loading branch information
tilk committed Dec 18, 2024
1 parent 839e9a8 commit ee23cdf
Show file tree
Hide file tree
Showing 13 changed files with 69 additions and 98 deletions.
2 changes: 1 addition & 1 deletion test/core/test_transactions.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def __init__(self, scheduler):
def elaborate(self, platform):
m = TModule()
tm = TransactionModule(m, DependencyContext.get(), TransactionManager(self.scheduler))
adapter = Adapter(i=data_layout(32), o=data_layout(32))
adapter = Adapter.create(i=data_layout(32), o=data_layout(32))
m.submodules.out = self.out = TestbenchIO(adapter)
m.submodules.in1 = self.in1 = TestbenchIO(AdapterTrans(adapter.iface))
m.submodules.in2 = self.in2 = TestbenchIO(AdapterTrans(adapter.iface))
Expand Down
4 changes: 2 additions & 2 deletions test/lib/test_connectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,13 +134,13 @@ def elaborate(self, platform):

get_results = []
for i in range(self.count):
input = TestbenchIO(Adapter(o=self.lay))
input = TestbenchIO(Adapter.create(o=self.lay))
get_results.append(input.adapter.iface)
m.submodules[f"input_{i}"] = input
self.inputs.append(input)

# Create ManyToOneConnectTrans, which will serialize results from different inputs
output = TestbenchIO(Adapter(i=self.lay))
output = TestbenchIO(Adapter.create(i=self.lay))
m.submodules.output = self.output = output
m.submodules.fu_arbitration = ManyToOneConnectTrans(get_results=get_results, put_result=output.adapter.iface)

Expand Down
4 changes: 2 additions & 2 deletions test/lib/test_reqres.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ def setup_method(self):

layout = [("field", self.data_width)]

self.req_method = TestbenchIO(Adapter(i=layout))
self.resp_method = TestbenchIO(Adapter(o=layout))
self.req_method = TestbenchIO(Adapter.create(i=layout))
self.resp_method = TestbenchIO(Adapter.create(o=layout))

self.test_circuit = SimpleTestCircuit(
Serializer(
Expand Down
2 changes: 1 addition & 1 deletion test/lib/test_simultaneous.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ class TestCondition(TestCaseWithSimulator):
@pytest.mark.parametrize("priority", [False, True])
@pytest.mark.parametrize("catchall", [False, True])
def test_condition(self, nonblocking: bool, priority: bool, catchall: bool):
target = TestbenchIO(Adapter(i=[("cond", 2)]))
target = TestbenchIO(Adapter.create(i=[("cond", 2)]))

circ = SimpleTestCircuit(
ConditionTestCircuit(target.adapter.iface, nonblocking=nonblocking, priority=priority, catchall=catchall),
Expand Down
10 changes: 5 additions & 5 deletions test/lib/test_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def otransform_dict(_, v: MethodStruct) -> RecordDict:
itransform = itransform_rec
otransform = otransform_rec

m.submodules.target = self.target = TestbenchIO(Adapter(i=layout, o=layout))
m.submodules.target = self.target = TestbenchIO(Adapter.create(i=layout, o=layout))

if self.use_methods:
imeth = Method(i=layout, o=layout)
Expand Down Expand Up @@ -111,8 +111,8 @@ class TestMethodFilter(TestCaseWithSimulator):
def initialize(self):
self.iosize = 4
self.layout = data_layout(self.iosize)
self.target = TestbenchIO(Adapter(i=self.layout, o=self.layout))
self.cmeth = TestbenchIO(Adapter(i=self.layout, o=data_layout(1)))
self.target = TestbenchIO(Adapter.create(i=self.layout, o=self.layout))
self.cmeth = TestbenchIO(Adapter.create(i=self.layout, o=data_layout(1)))

async def source(self, sim: TestbenchContext):
for i in range(2**self.iosize):
Expand Down Expand Up @@ -165,7 +165,7 @@ def elaborate(self, platform):
methods = []

for k in range(self.targets):
tgt = TestbenchIO(Adapter(i=layout, o=layout))
tgt = TestbenchIO(Adapter.create(i=layout, o=layout))
methods.append(tgt.adapter.iface)
self.target.append(tgt)
m.submodules += tgt
Expand Down Expand Up @@ -280,7 +280,7 @@ def elaborate(self, platform):
methods = []

for k in range(self.targets):
tgt = TestbenchIO(Adapter(i=layout, o=layout))
tgt = TestbenchIO(Adapter.create(i=layout, o=layout))
methods.append(tgt.adapter.iface)
self.target.append(tgt)
m.submodules += tgt
Expand Down
6 changes: 3 additions & 3 deletions test/test_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,7 +407,7 @@ def elaborate(self, platform):
meth = Method(i=data_layout(1))

m.submodules.tb = self.tb = TestbenchIO(AdapterTrans(meth))
m.submodules.out = self.out = TestbenchIO(Adapter())
m.submodules.out = self.out = TestbenchIO(Adapter.create())

@def_method(m, meth)
def _(arg):
Expand Down Expand Up @@ -456,7 +456,7 @@ def elaborate(self, platform):
m = TModule()

self.ready = Signal()
m.submodules.tb = self.tb = TestbenchIO(Adapter())
m.submodules.tb = self.tb = TestbenchIO(Adapter.create())

with Transaction().body(m, request=self.ready):
self.tb.adapter.iface(m)
Expand All @@ -469,7 +469,7 @@ def elaborate(self, platform):
m = TModule()

self.ready = Signal()
m.submodules.tb = self.tb = TestbenchIO(Adapter())
m.submodules.tb = self.tb = TestbenchIO(Adapter.create())

with m.If(self.ready):
with Transaction().body(m):
Expand Down
2 changes: 1 addition & 1 deletion test/test_simultaneous.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def elaborate(self, platform):

class TestTransitivity(TestCaseWithSimulator):
def test_transitivity(self):
target = TestbenchIO(Adapter(i=[("data", 2)]))
target = TestbenchIO(Adapter.create(i=[("data", 2)]))
req1 = Signal()
req2 = Signal()

Expand Down
2 changes: 1 addition & 1 deletion test/testing/test_validate_arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class ValidateArgumentsTestCircuit(Elaboratable):
def elaborate(self, platform):
m = Module()

self.method = TestbenchIO(Adapter(i=data_layout(1), o=data_layout(1)).set(with_validate_arguments=True))
self.method = TestbenchIO(Adapter.create(i=data_layout(1), o=data_layout(1)).set(with_validate_arguments=True))
self.caller1 = TestbenchIO(AdapterTrans(self.method.adapter.iface))
self.caller2 = TestbenchIO(AdapterTrans(self.method.adapter.iface))

Expand Down
35 changes: 23 additions & 12 deletions transactron/core/body.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,24 @@

from transactron.utils import *
from amaranth import *
from typing import TYPE_CHECKING, ClassVar, NewType, Optional, Callable, final
from typing import TYPE_CHECKING, ClassVar, NewType, NotRequired, Optional, Callable, TypedDict, Unpack, final
from transactron.utils.assign import AssignArg

if TYPE_CHECKING:
from .method import Method


__all__ = ["Body", "TBody", "MBody"]
__all__ = ["AdapterBodyParams", "BodyParams", "Body", "TBody", "MBody"]


class AdapterBodyParams(TypedDict):
combiner: NotRequired[Callable[[Module, Sequence[MethodStruct], Value], AssignArg]]
nonexclusive: NotRequired[bool]
single_caller: NotRequired[bool]


class BodyParams(AdapterBodyParams):
validate_arguments: NotRequired[Callable[..., ValueLike]]


@final
Expand All @@ -35,11 +45,8 @@ def __init__(
owner: Optional[Elaboratable],
i: StructLayout,
o: StructLayout,
combiner: Optional[Callable[[Module, Sequence[MethodStruct], Value], AssignArg]],
validate_arguments: Optional[Callable[..., ValueLike]],
nonexclusive: bool,
single_caller: bool,
src_loc: SrcLoc,
**kwargs: Unpack[BodyParams],
):
super().__init__(src_loc=src_loc)

Expand All @@ -57,15 +64,19 @@ def default_combiner(m: Module, args: Sequence[MethodStruct], runs: Value) -> As
self.run = Signal(name=self.owned_name + "_run")
self.data_in: MethodStruct = Signal(from_method_layout(i), name=self.owned_name + "_data_in")
self.data_out: MethodStruct = Signal(from_method_layout(o), name=self.owned_name + "_data_out")
self.combiner: Callable[[Module, Sequence[MethodStruct], Value], AssignArg] = combiner or default_combiner
self.nonexclusive = nonexclusive
self.single_caller = single_caller
self.validate_arguments: Optional[Callable[..., ValueLike]] = validate_arguments
self.combiner: Callable[[Module, Sequence[MethodStruct], Value], AssignArg] = (
kwargs["combiner"] if "combiner" in kwargs else default_combiner
)
self.nonexclusive = kwargs["nonexclusive"] if "nonexclusive" in kwargs else False
self.single_caller = kwargs["single_caller"] if "single_caller" in kwargs else False
self.validate_arguments: Optional[Callable[..., ValueLike]] = (
kwargs["validate_arguments"] if "validate_arguments" in kwargs else None
)
self.method_uses = {}
self.method_calls = defaultdict(list)

if nonexclusive:
assert len(self.data_in.as_value()) == 0 or combiner is not None
if self.nonexclusive:
assert len(self.data_in.as_value()) == 0 or self.combiner is not None

def _validate_arguments(self, arg_rec: MethodStruct) -> ValueLike:
if self.validate_arguments is not None:
Expand Down
24 changes: 4 additions & 20 deletions transactron/core/method.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
from transactron.utils import *
from amaranth import *
from amaranth import tracer
from typing import TYPE_CHECKING, Optional, Callable, Iterator
from typing import TYPE_CHECKING, Optional, Iterator, Unpack
from .transaction_base import *
from contextlib import contextmanager
from transactron.utils.assign import AssignArg
from transactron.utils._typing import type_self_add_1pos_kwargs_as

from .body import Body, MBody
from .body import Body, BodyParams, MBody
from .keys import TransactionManagerKey
from .tmodule import TModule
from .transaction_base import TransactionBase
Expand Down Expand Up @@ -154,15 +154,7 @@ def proxy(self, m: TModule, method: "Method"):

@contextmanager
def body(
self,
m: TModule,
*,
ready: ValueLike = C(1),
out: ValueLike = C(0, 0),
validate_arguments: Optional[Callable[..., ValueLike]] = None,
combiner: Optional[Callable[[Module, Sequence[MethodStruct], Value], AssignArg]] = None,
nonexclusive: bool = False,
single_caller: bool = False,
self, m: TModule, *, ready: ValueLike = C(1), out: ValueLike = C(0, 0), **kwargs: Unpack[BodyParams]
) -> Iterator[MethodStruct]:
"""Define method body
Expand Down Expand Up @@ -226,15 +218,7 @@ def body(
m.d.comb += sum.eq(data_in.arg1 + data_in.arg2)
"""
body = Body(
name=self.name,
owner=self.owner,
i=self.layout_in,
o=self.layout_out,
combiner=combiner,
validate_arguments=validate_arguments,
nonexclusive=nonexclusive,
single_caller=single_caller,
src_loc=self.src_loc,
name=self.name, owner=self.owner, i=self.layout_in, o=self.layout_out, src_loc=self.src_loc, **kwargs
)
self._set_impl(m, body)

Expand Down
46 changes: 9 additions & 37 deletions transactron/core/sugar.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from collections.abc import Sequence, Callable
from amaranth import *
from typing import Optional, Concatenate, ParamSpec
from typing import Optional, Concatenate, ParamSpec, Unpack
from transactron.core.body import BodyParams
from transactron.utils import *
from transactron.utils.assign import AssignArg
from functools import partial
Expand All @@ -13,15 +14,7 @@
P = ParamSpec("P")


def def_method(
m: TModule,
method: Method,
ready: ValueLike = C(1),
combiner: Optional[Callable[[Module, Sequence[MethodStruct], Value], AssignArg]] = None,
validate_arguments: Optional[Callable[..., ValueLike]] = None,
nonexclusive: bool = False,
single_caller: bool = False,
):
def def_method(m: TModule, method: Method, ready: ValueLike = C(1), **kwargs: Unpack[BodyParams]):
"""Define a method.
This decorator allows to define transactional methods in an
Expand All @@ -46,15 +39,7 @@ def def_method(
Signal to indicate if the method is ready to be run. By
default it is `Const(1)`, so the method is always ready.
Assigned combinationally to the `ready` attribute.
validate_arguments: Optional[Callable[..., ValueLike]]
For details, see `Method.body`.
validate_arguments: Optional[Callable[..., ValueLike]]
For details, see `Method.body`.
combiner: (Module, Sequence[MethodStruct], Value) -> AssignArg
For details, see `Method.body`.
nonexclusive: bool
For details, see `Method.body`.
single_caller: bool
**kwargs: BodyParams
For details, see `Method.body`.
Examples
Expand Down Expand Up @@ -91,15 +76,7 @@ def decorator(func: Callable[..., Optional[AssignArg]]):
out = Signal(method.layout_out)
ret_out = None

with method.body(
m,
ready=ready,
out=out,
combiner=combiner,
validate_arguments=validate_arguments,
nonexclusive=nonexclusive,
single_caller=single_caller,
) as arg:
with method.body(m, ready=ready, out=out, **kwargs) as arg:
ret_out = method_def_helper(method, func, arg)

if ret_out is not None:
Expand All @@ -112,7 +89,7 @@ def def_methods(
m: TModule,
methods: Sequence[Method],
ready: Callable[[int], ValueLike] = lambda _: C(1),
validate_arguments: Optional[Callable[..., ValueLike]] = None,
**kwargs: Unpack[BodyParams],
):
"""Decorator for defining similar methods
Expand Down Expand Up @@ -143,12 +120,8 @@ def _(arg):
A `Callable` that takes the index in the form of an `int` of the currently defined method
and produces a `Value` describing whether the method is ready to be run.
When omitted, each defined method is always ready. Assigned combinationally to the `ready` attribute.
validate_arguments: Optional[Callable[Concatenate[int, ...], ValueLike]]
Function that takes input arguments used to call the method
and checks whether the method can be called with those arguments.
It instantiates a combinational circuit for each
method caller. By default, there is no function, so all arguments
are accepted.
**kwargs: BodyParams
For details, see `Method.body`.
Examples
--------
Expand Down Expand Up @@ -187,7 +160,6 @@ def _(_):
def decorator(func: Callable[Concatenate[int, P], Optional[RecordDict]]):
for i in range(len(methods)):
partial_f = partial(func, i)
partial_vargs = partial(validate_arguments, i) if validate_arguments is not None else None
def_method(m, methods[i], ready(i), partial_vargs)(partial_f)
def_method(m, methods[i], ready(i), **kwargs)(partial_f)

return decorator
4 changes: 0 additions & 4 deletions transactron/core/transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,10 +120,6 @@ def body(self, m: TModule, *, request: ValueLike = C(1)) -> Iterator["Transactio
owner=self.owner,
i=StructLayout({}),
o=StructLayout({}),
combiner=None,
validate_arguments=None,
nonexclusive=False,
single_caller=False,
src_loc=self.src_loc,
)
self._set_impl(m, impl)
Expand Down
Loading

0 comments on commit ee23cdf

Please sign in to comment.