diff --git a/docs/transactions.md b/docs/transactions.md
index 286f3c7..d42ae7f 100644
--- a/docs/transactions.md
+++ b/docs/transactions.md
@@ -140,7 +140,7 @@ Suppose we have the following layout, which is an input layout for a method call
 
 ```python
 layout = [("foo", 1), ("bar", 32)]
-method = Method(input_layout=layout)
+method = Method(i=layout)
 ```
 
 The method can be called in multiple ways.
@@ -170,7 +170,7 @@ Take the following definitions:
 
 ```python
 layout2 = [("foobar", layout), ("baz", 42)]
-method2 = Method(input_layout=layout2)
+method2 = Method(i=layout2)
 ```
 
 One can then pass the arguments using `dict`s in following ways:
@@ -208,7 +208,7 @@ The `dict` syntax can be used for returning values from methods.
 Take the following method declaration:
 
 ```python
-method3 = Method(input_layout=layout, output_layout=layout2)
+method3 = Method(i=layout, o=layout2)
 ```
 
 One can then define this method as follows:
diff --git a/pyproject.toml b/pyproject.toml
index db92389..e5239d9 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -7,7 +7,7 @@ name = "transactron"
 dynamic = ["version"]
 dependencies = [
     "amaranth == 0.5.3",
-    "amaranth-stubs @ git+https://github.com/piotro888/amaranth-stubs.git@e25a7fa11d4a0d66ed18190f31b60914f222e74c"
+    "amaranth-stubs @ git+https://github.com/piotro888/amaranth-stubs.git@e25a7fa11d4a0d66ed18190f31b60914f222e74c",
     "dataclasses-json == 0.6.3",
     "tabulate == 0.9.0"
 ]
diff --git a/test/core/test_transactions.py b/test/core/test_transactions.py
index 08a9873..1b4de28 100644
--- a/test/core/test_transactions.py
+++ b/test/core/test_transactions.py
@@ -10,6 +10,7 @@
 from collections import deque
 from typing import Iterable, Callable
 
+from transactron.core.keys import TransactionManagerKey
 from transactron.testing import TestCaseWithSimulator, TestbenchIO, data_layout
 
 
@@ -20,7 +21,7 @@
 from transactron.core import Priority
 from transactron.core.schedulers import trivial_roundrobin_cc_scheduler, eager_deterministic_cc_scheduler
 from transactron.core.manager import TransactionScheduler
-from transactron.utils.dependencies import DependencyContext
+from transactron.utils.dependencies import DependencyContext, DependencyManager
 
 
 class TestNames(TestCase):
@@ -28,25 +29,28 @@ def test_names(self):
         mgr = TransactionManager()
         mgr._MustUse__silence = True  # type: ignore
 
-        class T(Elaboratable):
-            def __init__(self):
-                self._MustUse__silence = True  # type: ignore
-                Transaction(manager=mgr)
+        with DependencyContext(DependencyManager()) as ctx:
+            ctx.manager.add_dependency(TransactionManagerKey(), mgr)
 
-        T()
-        assert mgr.transactions[0].name == "T"
+            class T(Elaboratable):
+                def __init__(self):
+                    self._MustUse__silence = True  # type: ignore
+                    Transaction()
 
-        t = Transaction(name="x", manager=mgr)
-        assert t.name == "x"
+            T()
+            assert mgr.transactions[0].name == "T"
 
-        t = Transaction(manager=mgr)
-        assert t.name == "t"
+            t = Transaction(name="x")
+            assert t.name == "x"
 
-        m = Method(name="x")
-        assert m.name == "x"
+            t = Transaction()
+            assert t.name == "t"
 
-        m = Method()
-        assert m.name == "m"
+            m = Method(name="x")
+            assert m.name == "x"
+
+            m = Method()
+            assert m.name == "m"
 
 
 class TestScheduler(TestCaseWithSimulator):
@@ -111,7 +115,7 @@ def __init__(self, scheduler):
     def elaborate(self, platform):
         m = TModule()
         tm = TransactionModule(m, DependencyContext.get(), TransactionManager(self.scheduler))
-        adapter = Adapter(i=data_layout(32), o=data_layout(32))
+        adapter = Adapter.create(i=data_layout(32), o=data_layout(32))
         m.submodules.out = self.out = TestbenchIO(adapter)
m.submodules.in1 = self.in1 = TestbenchIO(AdapterTrans(adapter.iface)) m.submodules.in2 = self.in2 = TestbenchIO(AdapterTrans(adapter.iface)) @@ -428,7 +432,11 @@ class SingleCallerTestCircuit(Elaboratable): def elaborate(self, platform): m = TModule() - method = Method(single_caller=True) + method = Method() + + @def_method(m, method, single_caller=True) + def _(): + pass with Transaction().body(m): method(m) diff --git a/test/lib/test_connectors.py b/test/lib/test_connectors.py index ee9196a..e7a813b 100644 --- a/test/lib/test_connectors.py +++ b/test/lib/test_connectors.py @@ -134,13 +134,13 @@ def elaborate(self, platform): get_results = [] for i in range(self.count): - input = TestbenchIO(Adapter(o=self.lay)) + input = TestbenchIO(Adapter.create(o=self.lay)) get_results.append(input.adapter.iface) m.submodules[f"input_{i}"] = input self.inputs.append(input) # Create ManyToOneConnectTrans, which will serialize results from different inputs - output = TestbenchIO(Adapter(i=self.lay)) + output = TestbenchIO(Adapter.create(i=self.lay)) m.submodules.output = self.output = output m.submodules.fu_arbitration = ManyToOneConnectTrans(get_results=get_results, put_result=output.adapter.iface) diff --git a/test/lib/test_reqres.py b/test/lib/test_reqres.py index 6aea07f..ce73ce5 100644 --- a/test/lib/test_reqres.py +++ b/test/lib/test_reqres.py @@ -27,8 +27,8 @@ def setup_method(self): layout = [("field", self.data_width)] - self.req_method = TestbenchIO(Adapter(i=layout)) - self.resp_method = TestbenchIO(Adapter(o=layout)) + self.req_method = TestbenchIO(Adapter.create(i=layout)) + self.resp_method = TestbenchIO(Adapter.create(o=layout)) self.test_circuit = SimpleTestCircuit( Serializer( diff --git a/test/lib/test_simultaneous.py b/test/lib/test_simultaneous.py index 0418e23..80199cd 100644 --- a/test/lib/test_simultaneous.py +++ b/test/lib/test_simultaneous.py @@ -20,7 +20,7 @@ class ConditionTestCircuit(Elaboratable): def __init__(self, target: Method, *, nonblocking: bool, priority: bool, catchall: bool): self.target = target - self.source = Method(i=[("cond1", 1), ("cond2", 1), ("cond3", 1)], single_caller=True) + self.source = Method(i=[("cond1", 1), ("cond2", 1), ("cond3", 1)]) self.nonblocking = nonblocking self.priority = priority self.catchall = catchall @@ -28,7 +28,7 @@ def __init__(self, target: Method, *, nonblocking: bool, priority: bool, catchal def elaborate(self, platform): m = TModule() - @def_method(m, self.source) + @def_method(m, self.source, single_caller=True) def _(cond1, cond2, cond3): with condition(m, nonblocking=self.nonblocking, priority=self.priority) as branch: with branch(cond1): @@ -49,7 +49,7 @@ class TestCondition(TestCaseWithSimulator): @pytest.mark.parametrize("priority", [False, True]) @pytest.mark.parametrize("catchall", [False, True]) def test_condition(self, nonblocking: bool, priority: bool, catchall: bool): - target = TestbenchIO(Adapter(i=[("cond", 2)])) + target = TestbenchIO(Adapter.create(i=[("cond", 2)])) circ = SimpleTestCircuit( ConditionTestCircuit(target.adapter.iface, nonblocking=nonblocking, priority=priority, catchall=catchall), diff --git a/test/lib/test_storage.py b/test/lib/test_storage.py index 9787a0a..aacc655 100644 --- a/test/lib/test_storage.py +++ b/test/lib/test_storage.py @@ -1,3 +1,5 @@ +from collections.abc import Callable +from amaranth_types import ShapeLike import pytest import random from collections import deque @@ -5,6 +7,7 @@ from hypothesis import given, settings, Phase from transactron.testing import * from 
transactron.lib.storage import * +from transactron.utils.transactron_helpers import make_layout class TestContentAddressableMemory(TestCaseWithSimulator): @@ -134,6 +137,12 @@ def test_random(self, in_push, in_write, in_read, in_remove): sim.add_testbench(self.remove_process(in_remove)) +bank_shapes = [ + (6, lambda x: x, lambda x: x), + (make_layout(("data_field", 6)), lambda x: {"data_field": x}, lambda x: x["data_field"]), +] + + class TestMemoryBank(TestCaseWithSimulator): test_conf = [(9, 3, 3, 3, 14), (16, 1, 1, 3, 15), (16, 1, 1, 1, 16), (12, 3, 1, 1, 17), (9, 0, 0, 0, 18)] @@ -141,6 +150,7 @@ class TestMemoryBank(TestCaseWithSimulator): @pytest.mark.parametrize("transparent", [False, True]) @pytest.mark.parametrize("read_ports", [1, 2]) @pytest.mark.parametrize("write_ports", [1, 2]) + @pytest.mark.parametrize("shape,to_shape,from_shape", bank_shapes) def test_mem( self, max_addr: int, @@ -151,14 +161,16 @@ def test_mem( transparent: bool, read_ports: int, write_ports: int, + shape: ShapeLike, + to_shape: Callable, + from_shape: Callable, ): test_count = 200 - data_width = 6 m = SimpleTestCircuit( MemoryBank( - data_layout=[("data", data_width)], - elem_count=max_addr, + shape=shape, + depth=max_addr, transparent=transparent, read_ports=read_ports, write_ports=write_ports, @@ -173,9 +185,9 @@ def test_mem( def writer(i): async def process(sim: TestbenchContext): for cycle in range(test_count): - d = random.randrange(2**data_width) + d = random.randrange(2 ** Shape.cast(shape).width) a = random.randrange(max_addr) - await m.write[i].call(sim, data={"data": d}, addr=a) + await m.write[i].call(sim, data=to_shape(d), addr=a) await sim.delay(1e-9 * (i + 2 if not transparent else i)) data[a] = d await self.random_wait(sim, writer_rand) @@ -202,7 +214,7 @@ async def process(sim: TestbenchContext): await self.random_wait(sim, reader_resp_rand or 1, min_cycle_cnt=1) await sim.delay(1e-9 * (write_ports + 3)) d = read_req_queues[i].popleft() - assert (await m.read_resp[i].call(sim)).data == d + assert from_shape((await m.read_resp[i].call(sim)).data) == d await self.random_wait(sim, reader_resp_rand) return process @@ -224,13 +236,27 @@ class TestAsyncMemoryBank(TestCaseWithSimulator): ) @pytest.mark.parametrize("read_ports", [1, 2]) @pytest.mark.parametrize("write_ports", [1, 2]) - def test_mem(self, max_addr: int, writer_rand: int, reader_rand: int, seed: int, read_ports: int, write_ports: int): + @pytest.mark.parametrize("shape,to_shape,from_shape", bank_shapes) + def test_mem( + self, + max_addr: int, + writer_rand: int, + reader_rand: int, + seed: int, + read_ports: int, + write_ports: int, + shape: ShapeLike, + to_shape: Callable, + from_shape: Callable, + ): test_count = 200 - data_width = 6 m = SimpleTestCircuit( AsyncMemoryBank( - data_layout=[("data", data_width)], elem_count=max_addr, read_ports=read_ports, write_ports=write_ports + shape=shape, + depth=max_addr, + read_ports=read_ports, + write_ports=write_ports, ), ) @@ -241,9 +267,9 @@ def test_mem(self, max_addr: int, writer_rand: int, reader_rand: int, seed: int, def writer(i): async def process(sim: TestbenchContext): for cycle in range(test_count): - d = random.randrange(2**data_width) + d = random.randrange(2 ** Shape.cast(shape).width) a = random.randrange(max_addr) - await m.write[i].call(sim, data={"data": d}, addr=a) + await m.write[i].call(sim, data=to_shape(d), addr=a) await sim.delay(1e-9 * (i + 2)) data[a] = d await self.random_wait(sim, writer_rand, min_cycle_cnt=1) @@ -257,7 +283,7 @@ async def process(sim: 
TestbenchContext): d = await m.read[i].call(sim, addr=a) await sim.delay(1e-9) expected_d = data[a] - assert d["data"] == expected_d + assert from_shape(d.data) == expected_d await self.random_wait(sim, reader_rand, min_cycle_cnt=1) return process diff --git a/test/lib/test_transformers.py b/test/lib/test_transformers.py index 4de85b4..3e4a647 100644 --- a/test/lib/test_transformers.py +++ b/test/lib/test_transformers.py @@ -51,7 +51,7 @@ def otransform_dict(_, v: MethodStruct) -> RecordDict: itransform = itransform_rec otransform = otransform_rec - m.submodules.target = self.target = TestbenchIO(Adapter(i=layout, o=layout)) + m.submodules.target = self.target = TestbenchIO(Adapter.create(i=layout, o=layout)) if self.use_methods: imeth = Method(i=layout, o=layout) @@ -111,8 +111,8 @@ class TestMethodFilter(TestCaseWithSimulator): def initialize(self): self.iosize = 4 self.layout = data_layout(self.iosize) - self.target = TestbenchIO(Adapter(i=self.layout, o=self.layout)) - self.cmeth = TestbenchIO(Adapter(i=self.layout, o=data_layout(1))) + self.target = TestbenchIO(Adapter.create(i=self.layout, o=self.layout)) + self.cmeth = TestbenchIO(Adapter.create(i=self.layout, o=data_layout(1))) async def source(self, sim: TestbenchContext): for i in range(2**self.iosize): @@ -165,7 +165,7 @@ def elaborate(self, platform): methods = [] for k in range(self.targets): - tgt = TestbenchIO(Adapter(i=layout, o=layout)) + tgt = TestbenchIO(Adapter.create(i=layout, o=layout)) methods.append(tgt.adapter.iface) self.target.append(tgt) m.submodules += tgt @@ -280,7 +280,7 @@ def elaborate(self, platform): methods = [] for k in range(self.targets): - tgt = TestbenchIO(Adapter(i=layout, o=layout)) + tgt = TestbenchIO(Adapter.create(i=layout, o=layout)) methods.append(tgt.adapter.iface) self.target.append(tgt) m.submodules += tgt diff --git a/test/test_methods.py b/test/test_methods.py index 94cff34..95c7e58 100644 --- a/test/test_methods.py +++ b/test/test_methods.py @@ -407,7 +407,7 @@ def elaborate(self, platform): meth = Method(i=data_layout(1)) m.submodules.tb = self.tb = TestbenchIO(AdapterTrans(meth)) - m.submodules.out = self.out = TestbenchIO(Adapter()) + m.submodules.out = self.out = TestbenchIO(Adapter.create()) @def_method(m, meth) def _(arg): @@ -456,7 +456,7 @@ def elaborate(self, platform): m = TModule() self.ready = Signal() - m.submodules.tb = self.tb = TestbenchIO(Adapter()) + m.submodules.tb = self.tb = TestbenchIO(Adapter.create()) with Transaction().body(m, request=self.ready): self.tb.adapter.iface(m) @@ -469,7 +469,7 @@ def elaborate(self, platform): m = TModule() self.ready = Signal() - m.submodules.tb = self.tb = TestbenchIO(Adapter()) + m.submodules.tb = self.tb = TestbenchIO(Adapter.create()) with m.If(self.ready): with Transaction().body(m): @@ -538,9 +538,9 @@ def elaborate(self, platform): self.running = Signal() self.data = Signal(WIDTH) - method = Method(o=data_layout(WIDTH), nonexclusive=True) + method = Method(o=data_layout(WIDTH)) - @def_method(m, method, self.ready) + @def_method(m, method, self.ready, nonexclusive=True) def _(): m.d.comb += self.running.eq(1) return {"data": self.data} @@ -594,20 +594,20 @@ def elaborate(self, platform): self.running1 = Signal() self.running2 = Signal() - method1 = Method(o=data_layout(WIDTH), nonexclusive=True) - method2 = Method(o=data_layout(WIDTH), nonexclusive=self.two_nonexclusive) + method1 = Method(o=data_layout(WIDTH)) + method2 = Method(o=data_layout(WIDTH)) method_in = Method(o=data_layout(WIDTH)) @def_method(m, method_in) def _(): 
return {"data": 0} - @def_method(m, method1) + @def_method(m, method1, nonexclusive=True) def _(): m.d.comb += self.running1.eq(1) return method_in(m) - @def_method(m, method2) + @def_method(m, method2, nonexclusive=self.two_nonexclusive) def _(): m.d.comb += self.running2.eq(1) return method_in(m) @@ -649,9 +649,9 @@ def combiner(m: Module, args: Sequence[MethodStruct], runs: Value) -> AssignArg: result = result ^ Mux(runs[i], v.data, 0) return {"data": result} - method = Method(i=data_layout(WIDTH), o=data_layout(WIDTH), nonexclusive=True, combiner=combiner) + method = Method(i=data_layout(WIDTH), o=data_layout(WIDTH)) - @def_method(m, method, self.ready) + @def_method(m, method, self.ready, nonexclusive=True, combiner=combiner) def _(data: Value): m.d.comb += self.running.eq(1) return {"data": data} diff --git a/test/test_simultaneous.py b/test/test_simultaneous.py index ad492e3..d9f281a 100644 --- a/test/test_simultaneous.py +++ b/test/test_simultaneous.py @@ -140,7 +140,7 @@ def elaborate(self, platform): class TestTransitivity(TestCaseWithSimulator): def test_transitivity(self): - target = TestbenchIO(Adapter(i=[("data", 2)])) + target = TestbenchIO(Adapter.create(i=[("data", 2)])) req1 = Signal() req2 = Signal() diff --git a/test/testing/test_method_mock.py b/test/testing/test_method_mock.py index 7ad198e..9db9561 100644 --- a/test/testing/test_method_mock.py +++ b/test/testing/test_method_mock.py @@ -10,6 +10,44 @@ from transactron.lib import * +class SimpleMethodMockTestCircuit(Elaboratable): + method: Required[Method] + wrapper: Provided[Method] + + def __init__(self, width: int): + self.method = Method(i=StructLayout({"input": width}), o=StructLayout({"output": width})) + self.wrapper = Method(i=StructLayout({"input": width}), o=StructLayout({"output": width})) + + def elaborate(self, platform): + m = TModule() + + @def_method(m, self.wrapper) + def _(input): + return {"output": self.method(m, input).output + 1} + + return m + + +class TestMethodMock(TestCaseWithSimulator): + async def process(self, sim: TestbenchContext): + for _ in range(20): + val = random.randrange(2**self.width) + ret = await self.dut.wrapper.call(sim, input=val) + assert ret.output == (val + 2) % 2**self.width + + @def_method_mock(lambda self: self.dut.method, enable=lambda _: random.randint(0, 1)) + def method_mock(self, input): + return {"output": input + 1} + + def test_method_mock_simple(self): + random.seed(42) + self.width = 4 + self.dut = SimpleTestCircuit(SimpleMethodMockTestCircuit(self.width)) + + with self.run_simulation(self.dut) as sim: + sim.add_testbench(self.process) + + class ReverseMethodMockTestCircuit(Elaboratable): def __init__(self, width): self.method = Method(i=StructLayout({"input": width}), o=StructLayout({"output": width})) diff --git a/test/testing/test_validate_arguments.py b/test/testing/test_validate_arguments.py index 7e70369..9f2f5a9 100644 --- a/test/testing/test_validate_arguments.py +++ b/test/testing/test_validate_arguments.py @@ -14,7 +14,7 @@ class ValidateArgumentsTestCircuit(Elaboratable): def elaborate(self, platform): m = Module() - self.method = TestbenchIO(Adapter(i=data_layout(1), o=data_layout(1)).set(with_validate_arguments=True)) + self.method = TestbenchIO(Adapter.create(i=data_layout(1), o=data_layout(1)).set(with_validate_arguments=True)) self.caller1 = TestbenchIO(AdapterTrans(self.method.adapter.iface)) self.caller2 = TestbenchIO(AdapterTrans(self.method.adapter.iface)) diff --git a/test/utils/test_utils.py b/test/utils/test_utils.py index 
e403a02..869f5ad 100644 --- a/test/utils/test_utils.py +++ b/test/utils/test_utils.py @@ -10,6 +10,7 @@ popcount, count_leading_zeros, count_trailing_zeros, + cyclic_mask, ) from amaranth.utils import ceil_log2 @@ -191,3 +192,56 @@ async def process(self, sim: TestbenchContext): def test_count_trailing_zeros(self, size): with self.run_simulation(self.m) as sim: sim.add_testbench(self.process) + + +class GenCyclicMaskTestCircuit(Elaboratable): + def __init__(self, xlen: int): + self.start = Signal(range(xlen)) + self.end = Signal(range(xlen)) + self.sig_out = Signal(xlen) + self.xlen = xlen + + def elaborate(self, platform): + m = Module() + + m.d.comb += self.sig_out.eq(cyclic_mask(self.xlen, self.start, self.end)) + + return m + + +@pytest.mark.parametrize("size", [1, 2, 3, 5, 8]) +class TestGenCyclicMask(TestCaseWithSimulator): + @pytest.fixture(scope="function", autouse=True) + def setup_fixture(self, size): + self.size = size + random.seed(14) + self.test_number = 40 + self.m = GenCyclicMaskTestCircuit(self.size) + + async def check(self, sim: TestbenchContext, start, end): + sim.set(self.m.start, start) + sim.set(self.m.end, end) + await sim.delay(1e-6) + out = sim.get(self.m.sig_out) + + expected = 0 + for i in range(min(start, end), max(start, end) + 1): + expected |= 1 << i + + if end < start: + expected ^= (1 << self.size) - 1 + expected |= 1 << start + expected |= 1 << end + + assert out == expected + + async def process(self, sim: TestbenchContext): + for _ in range(self.test_number): + start = random.randrange(self.size) + end = random.randrange(self.size) + await self.check(sim, start, end) + await sim.delay(1e-6) + + def test_count_trailing_zeros(self, size): + with self.run_simulation(self.m) as sim: + sim.add_testbench(self.process) diff --git a/transactron/core/__init__.py b/transactron/core/__init__.py index 6ead593..1ea77f3 100644 --- a/transactron/core/__init__.py +++ b/transactron/core/__init__.py @@ -4,3 +4,4 @@ from .transaction import * # noqa: F401 from .manager import * # noqa: F401 from .sugar import * # noqa: F401 +from .body import * # noqa: F401 diff --git a/transactron/core/body.py b/transactron/core/body.py new file mode 100644 index 0000000..5d7db0e --- /dev/null +++ b/transactron/core/body.py @@ -0,0 +1,130 @@ +from collections import defaultdict +from collections.abc import Iterator, Sequence +from contextlib import contextmanager +from itertools import count + +from amaranth.lib.data import StructLayout +from transactron.core.tmodule import CtrlPath, TModule +from transactron.core.transaction_base import TransactionBase + +from transactron.utils import * +from amaranth import * +from typing import TYPE_CHECKING, ClassVar, NewType, NotRequired, Optional, Callable, TypedDict, Unpack, final +from transactron.utils.assign import AssignArg + +if TYPE_CHECKING: + from .method import Method + + +__all__ = ["AdapterBodyParams", "BodyParams", "Body", "TBody", "MBody"] + + +class AdapterBodyParams(TypedDict): + combiner: NotRequired[Callable[[Module, Sequence[MethodStruct], Value], AssignArg]] + nonexclusive: NotRequired[bool] + single_caller: NotRequired[bool] + + +class BodyParams(AdapterBodyParams): + validate_arguments: NotRequired[Callable[..., ValueLike]] + + +@final +class Body(TransactionBase["Body"]): + def_counter: ClassVar[count] = count() + def_order: int + stack: ClassVar[list["Body"]] = [] + ctrl_path: CtrlPath = CtrlPath(-1, []) + method_uses: dict["Method", tuple[MethodStruct, Signal]] + method_calls: defaultdict["Method", list[tuple[CtrlPath, 
MethodStruct, ValueLike]]] + + def __init__( + self, + *, + name: str, + owner: Optional[Elaboratable], + i: StructLayout, + o: StructLayout, + src_loc: SrcLoc, + **kwargs: Unpack[BodyParams], + ): + super().__init__(src_loc=src_loc) + + def default_combiner(m: Module, args: Sequence[MethodStruct], runs: Value) -> AssignArg: + ret = Signal(from_method_layout(i)) + for k in OneHotSwitchDynamic(m, runs): + m.d.comb += ret.eq(args[k]) + return ret + + self.def_order = next(Body.def_counter) + self.name = name + self.owner = owner + self.ready = Signal(name=self.owned_name + "_ready") + self.runnable = Signal(name=self.owned_name + "_runnable") + self.run = Signal(name=self.owned_name + "_run") + self.data_in: MethodStruct = Signal(from_method_layout(i), name=self.owned_name + "_data_in") + self.data_out: MethodStruct = Signal(from_method_layout(o), name=self.owned_name + "_data_out") + self.combiner: Callable[[Module, Sequence[MethodStruct], Value], AssignArg] = ( + kwargs["combiner"] if "combiner" in kwargs else default_combiner + ) + self.nonexclusive = kwargs["nonexclusive"] if "nonexclusive" in kwargs else False + self.single_caller = kwargs["single_caller"] if "single_caller" in kwargs else False + self.validate_arguments: Optional[Callable[..., ValueLike]] = ( + kwargs["validate_arguments"] if "validate_arguments" in kwargs else None + ) + self.method_uses = {} + self.method_calls = defaultdict(list) + + if self.nonexclusive: + assert len(self.data_in.as_value()) == 0 or self.combiner is not None + + def _validate_arguments(self, arg_rec: MethodStruct) -> ValueLike: + if self.validate_arguments is not None: + return self.ready & method_def_helper(self, self.validate_arguments, arg_rec) + return self.ready + + @contextmanager + def context(self, m: TModule) -> Iterator["Body"]: + self.ctrl_path = m.ctrl_path + + parent = Body.peek() + if parent is not None: + parent.schedule_before(self) + + Body.stack.append(self) + + try: + yield self + finally: + Body.stack.pop() + self.defined = True + + @staticmethod + def get() -> "Body": + ret = Body.peek() + if ret is None: + raise RuntimeError("No current body") + return ret + + @staticmethod + def peek() -> Optional["Body"]: + if not Body.stack: + return None + return Body.stack[-1] + + def _set_method_uses(self, m: ModuleLike): + for method, calls in self.method_calls.items(): + arg_rec, enable_sig = self.method_uses[method] + if len(calls) == 1: + m.d.comb += arg_rec.eq(calls[0][1]) + m.d.comb += enable_sig.eq(calls[0][2]) + else: + call_ens = Cat([en for _, _, en in calls]) + + for i in OneHotSwitchDynamic(m, call_ens): + m.d.comb += arg_rec.eq(calls[i][1]) + m.d.comb += enable_sig.eq(1) + + +TBody = NewType("TBody", Body) +MBody = NewType("MBody", Body) diff --git a/transactron/core/manager.py b/transactron/core/manager.py index ec9c1e6..24d7038 100644 --- a/transactron/core/manager.py +++ b/transactron/core/manager.py @@ -13,34 +13,34 @@ from transactron.utils.transactron_helpers import _graph_ccs from transactron.graph import OwnershipGraph, Direction -from .transaction_base import TransactionBase, TransactionOrMethod, Priority, Relation -from .method import Method +from .transaction_base import Priority, Relation, RelationBase +from .body import Body, TBody, MBody from .transaction import Transaction, TransactionManagerKey +from .method import Method from .tmodule import TModule from .schedulers import eager_deterministic_cc_scheduler __all__ = ["TransactionManager", "TransactionModule", "TransactionComponent"] -TransactionGraph: 
TypeAlias = Graph["Transaction"] -TransactionGraphCC: TypeAlias = GraphCC["Transaction"] -PriorityOrder: TypeAlias = dict["Transaction", int] +TransactionGraph: TypeAlias = Graph[TBody] +TransactionGraphCC: TypeAlias = GraphCC[TBody] +PriorityOrder: TypeAlias = dict[TBody, int] TransactionScheduler: TypeAlias = Callable[["MethodMap", TransactionGraph, TransactionGraphCC, PriorityOrder], Module] class MethodMap: - def __init__(self, transactions: Iterable["Transaction"]): - self.methods_by_transaction = dict[Transaction, list[Method]]() - self.transactions_by_method = defaultdict[Method, list[Transaction]](list) - self.argument_by_call = dict[tuple[Transaction, Method], MethodStruct]() - self.ancestors_by_call = dict[tuple[Transaction, Method], tuple[Method, ...]]() - self.method_parents = defaultdict[Method, list[TransactionBase]](list) - - def rec(transaction: Transaction, source: TransactionBase, ancestors: tuple[Method, ...]): - for method, (arg_rec, _) in source.method_uses.items(): - if not method.defined: - raise RuntimeError(f"Trying to use method '{method.name}' which is not defined yet") + def __init__(self, transactions: Iterable[Transaction]): + self.methods_by_transaction = dict[TBody, list[MBody]]() + self.transactions_by_method = defaultdict[MBody, list[TBody]](list) + self.argument_by_call = dict[tuple[TBody, MBody], MethodStruct]() + self.ancestors_by_call = dict[tuple[TBody, MBody], tuple[MBody, ...]]() + self.method_parents = defaultdict[MBody, list[Body]](list) + + def rec(transaction: TBody, source: Body, ancestors: tuple[MBody, ...]): + for method_obj, (arg_rec, _) in source.method_uses.items(): + method = MBody(method_obj._body) if method in self.methods_by_transaction[transaction]: - raise RuntimeError(f"Method '{method.name}' can't be called twice from the same transaction") + raise RuntimeError(f"Method '{method_obj.name}' can't be called twice from the same transaction") self.methods_by_transaction[transaction].append(method) self.transactions_by_method[method].append(transaction) self.argument_by_call[(transaction, method)] = arg_rec @@ -48,29 +48,30 @@ def rec(transaction: Transaction, source: TransactionBase, ancestors: tuple[Meth rec(transaction, method, new_ancestors) for transaction in transactions: - self.methods_by_transaction[transaction] = [] - rec(transaction, transaction, ()) + self.methods_by_transaction[TBody(transaction._body)] = [] + rec(TBody(transaction._body), transaction._body, ()) for transaction_or_method in self.methods_and_transactions: for method in transaction_or_method.method_uses.keys(): - self.method_parents[method].append(transaction_or_method) + self.method_parents[MBody(method._body)].append(transaction_or_method) - def transactions_for(self, elem: TransactionOrMethod) -> Collection["Transaction"]: - if isinstance(elem, Transaction): - return [elem] - else: + def transactions_for(self, elem: Body) -> Collection[TBody]: + if elem in self.transactions_by_method: return self.transactions_by_method[elem] + else: + assert elem in self.methods_by_transaction + return [TBody(elem)] @property - def methods(self) -> Collection["Method"]: + def methods(self) -> Collection[MBody]: return self.transactions_by_method.keys() @property - def transactions(self) -> Collection["Transaction"]: + def transactions(self) -> Collection[TBody]: return self.methods_by_transaction.keys() @property - def methods_and_transactions(self) -> Iterable[TransactionOrMethod]: + def methods_and_transactions(self) -> Iterable[Body]: return chain(self.methods, 
self.transactions) @@ -84,11 +85,15 @@ class TransactionManager(Elaboratable): def __init__(self, cc_scheduler: TransactionScheduler = eager_deterministic_cc_scheduler): self.transactions: list[Transaction] = [] + self.methods: list[Method] = [] self.cc_scheduler = cc_scheduler - def add_transaction(self, transaction: "Transaction"): + def _add_transaction(self, transaction: Transaction): self.transactions.append(transaction) + def _add_method(self, method: Method): + self.methods.append(method) + @staticmethod def _conflict_graph(method_map: MethodMap) -> tuple[TransactionGraph, PriorityOrder]: """_conflict_graph @@ -120,7 +125,7 @@ def _conflict_graph(method_map: MethodMap) -> tuple[TransactionGraph, PriorityOr Linear ordering of transactions which is consistent with priority constraints. """ - def transactions_exclusive(trans1: Transaction, trans2: Transaction): + def transactions_exclusive(trans1: TBody, trans2: TBody): tms1 = [trans1] + method_map.methods_by_transaction[trans1] tms2 = [trans2] + method_map.methods_by_transaction[trans2] @@ -132,7 +137,7 @@ def transactions_exclusive(trans1: Transaction, trans2: Transaction): return False - def calls_nonexclusive(trans1: Transaction, trans2: Transaction, method: Method): + def calls_nonexclusive(trans1: TBody, trans2: TBody, method: MBody): ancestors1 = method_map.ancestors_by_call[(trans1, method)] ancestors2 = method_map.ancestors_by_call[(trans2, method)] common_ancestors = longest_common_prefix(ancestors1, ancestors2) @@ -141,7 +146,7 @@ def calls_nonexclusive(trans1: Transaction, trans2: Transaction, method: Method) cgr: TransactionGraph = {} # Conflict graph pgr: TransactionGraph = {} # Priority graph - def add_edge(begin: Transaction, end: Transaction, priority: Priority, conflict: bool): + def add_edge(begin: TBody, end: TBody, priority: Priority, conflict: bool): if conflict: cgr[begin].add(end) cgr[end].add(begin) @@ -166,22 +171,22 @@ def add_edge(begin: Transaction, end: Transaction, priority: Priority, conflict: add_edge(transaction1, transaction2, Priority.UNDEFINED, True) relations = [ - Relation(**relation, start=elem) + Relation(start=elem, **dataclass_asdict(relation)) for elem in method_map.methods_and_transactions for relation in elem.relations ] for relation in relations: - start = relation["start"] - end = relation["end"] - if not relation["conflict"]: # relation added with schedule_before - if end.def_order < start.def_order and not relation["silence_warning"]: + start = relation.start + end = relation.end + if not relation.conflict: # relation added with schedule_before + if end.def_order < start.def_order and not relation.silence_warning: raise RuntimeError(f"{start.name!r} scheduled before {end.name!r}, but defined afterwards") for trans_start in method_map.transactions_for(start): for trans_end in method_map.transactions_for(end): - conflict = relation["conflict"] and not transactions_exclusive(trans_start, trans_end) - add_edge(trans_start, trans_end, relation["priority"], conflict) + conflict = relation.conflict and not transactions_exclusive(trans_start, trans_end) + add_edge(trans_start, trans_end, relation.priority, conflict) porder: PriorityOrder = {} @@ -191,15 +196,15 @@ def add_edge(begin: Transaction, end: Transaction, priority: Priority, conflict: return cgr, porder @staticmethod - def _method_enables(method_map: MethodMap) -> Mapping["Transaction", Mapping["Method", ValueLike]]: - method_enables = defaultdict[Transaction, dict[Method, ValueLike]](dict) + def _method_enables(method_map: MethodMap) 
-> Mapping[TBody, Mapping[MBody, ValueLike]]: + method_enables = defaultdict[TBody, dict[MBody, ValueLike]](dict) enables: list[ValueLike] = [] - def rec(transaction: Transaction, source: TransactionOrMethod): + def rec(transaction: TBody, source: Body): for method, (_, enable) in source.method_uses.items(): enables.append(enable) - rec(transaction, method) - method_enables[transaction][method] = Cat(*enables).all() + rec(transaction, method._body) + method_enables[transaction][MBody(method._body)] = Cat(*enables).all() enables.pop() for transaction in method_map.transactions: @@ -210,20 +215,20 @@ def rec(transaction: Transaction, source: TransactionOrMethod): @staticmethod def _method_calls( m: Module, method_map: MethodMap - ) -> tuple[Mapping["Method", Sequence[MethodStruct]], Mapping["Method", Sequence[Value]]]: - args = defaultdict[Method, list[MethodStruct]](list) - runs = defaultdict[Method, list[Value]](list) + ) -> tuple[Mapping[MBody, Sequence[MethodStruct]], Mapping[MBody, Sequence[Value]]]: + args = defaultdict[MBody, list[MethodStruct]](list) + runs = defaultdict[MBody, list[Value]](list) for source in method_map.methods_and_transactions: - if isinstance(source, Method): - run_val = Cat(transaction.grant for transaction in method_map.transactions_by_method[source]).any() + if source in method_map.methods: + run_val = Cat(transaction.run for transaction in method_map.transactions_by_method[MBody(source)]).any() run = Signal() m.d.comb += run.eq(run_val) else: - run = source.grant + run = source.run for method, (arg, _) in source.method_uses.items(): - args[method].append(arg) - runs[method].append(run) + args[method._body].append(arg) + runs[method._body].append(run) return (args, runs) @@ -236,24 +241,24 @@ def _simultaneous(self): all_sims = frozenset(elem.simultaneous_list) elem.relations = list( filterfalse( - lambda relation: not relation["conflict"] - and relation["priority"] != Priority.UNDEFINED - and relation["end"] in all_sims, + lambda relation: not relation.conflict + and relation.priority != Priority.UNDEFINED + and relation.end in all_sims, elem.relations, ) ) # step 1: simultaneous and independent sets generation - independents = defaultdict[Transaction, set[Transaction]](set) + independents = defaultdict[TBody, set[TBody]](set) for elem in method_map.methods_and_transactions: - indeps = frozenset[Transaction]().union( + indeps = frozenset[TBody]().union( *(frozenset(method_map.transactions_for(ind)) for ind in chain([elem], elem.independent_list)) ) for transaction1, transaction2 in product(indeps, indeps): independents[transaction1].add(transaction2) - simultaneous = set[frozenset[Transaction]]() + simultaneous = set[frozenset[TBody]]() for elem in method_map.methods_and_transactions: for sim_elem in elem.simultaneous_list: @@ -265,12 +270,12 @@ def _simultaneous(self): simultaneous.add(frozenset({tr1, tr2})) # step 2: transitivity computation - tr_simultaneous = set[frozenset[Transaction]]() + tr_simultaneous = set[frozenset[TBody]]() - def conflicting(group: frozenset[Transaction]): + def conflicting(group: frozenset[TBody]): return any(tr1 != tr2 and tr1 in independents[tr2] for tr1 in group for tr2 in group) - q = deque[frozenset[Transaction]](simultaneous) + q = deque[frozenset[TBody]](simultaneous) while q: new_group = q.popleft() @@ -280,51 +285,48 @@ def conflicting(group: frozenset[Transaction]): tr_simultaneous.add(new_group) # step 3: maximal group selection - def maximal(group: frozenset[Transaction]): + def maximal(group: frozenset[TBody]): return 
not any(group.issubset(group2) and group != group2 for group2 in tr_simultaneous) final_simultaneous = set(filter(maximal, tr_simultaneous)) # step 4: convert transactions to methods - joined_transactions = set[Transaction]().union(*final_simultaneous) + joined_transactions = set[TBody]().union(*final_simultaneous) - self.transactions = list(filter(lambda t: t not in joined_transactions, self.transactions)) - methods = dict[Transaction, Method]() + self.transactions = list(filter(lambda t: t._body not in joined_transactions, self.transactions)) + methods = dict[TBody, Method]() + + m = TModule() + m._MustUse__silence = True # type: ignore for transaction in joined_transactions: - # TODO: some simpler way? - method = Method(name=transaction.name) - method.owner = transaction.owner - method.src_loc = transaction.src_loc - method.ready = transaction.request - method.run = transaction.grant - method.defined = transaction.defined - method.method_calls = transaction.method_calls - method.method_uses = transaction.method_uses - method.relations = transaction.relations - method.def_order = transaction.def_order - method.ctrl_path = transaction.ctrl_path + method = Method(name=transaction.name, src_loc=transaction.src_loc) + method._set_impl(m, transaction) methods[transaction] = method - for elem in method_map.methods_and_transactions: - # I guess method/transaction unification is really needed - for relation in elem.relations: - if relation["end"] in methods: - relation["end"] = methods[relation["end"]] - # step 5: construct merged transactions - m = TModule() - m._MustUse__silence = True # type: ignore - - for group in final_simultaneous: - name = "_".join([t.name for t in group]) - with Transaction(manager=self, name=name).body(m): - for transaction in group: - methods[transaction](m) + with DependencyContext(DependencyManager()): + DependencyContext.get().add_dependency(TransactionManagerKey(), self) + for group in final_simultaneous: + name = "_".join([t.name for t in group]) + with Transaction(name=name).body(m): + for transaction in group: + methods[transaction](m) return m def elaborate(self, platform): + for elem in chain(self.transactions, self.methods): + for relation in elem.relations: + elem._body.relations.append(RelationBase(**{**dataclass_asdict(relation), "end": relation.end._body})) + for elem2 in elem.simultaneous_list: + elem._body.simultaneous_list.append(elem2._body) + for elem2 in elem.independent_list: + elem._body.independent_list.append(elem2._body) + elem.relations = [] + elem.simultaneous_list = [] + elem.independent_list = [] + # In the following, various problems in the transaction set-up are detected. # The exception triggers an unused Elaboratable warning. 
with silence_mustuse(self): @@ -334,12 +336,13 @@ def elaborate(self, platform): cgr, porder = TransactionManager._conflict_graph(method_map) m = Module() + m._MustUse__silence = True # type: ignore m.submodules.merge_manager = merge_manager for elem in method_map.methods_and_transactions: elem._set_method_uses(m) - for transaction in self.transactions: + for transaction in method_map.transactions: ready = [ method._validate_arguments(method_map.argument_by_call[transaction, method]) for method in method_map.methods_by_transaction[transaction] @@ -347,14 +350,11 @@ def elaborate(self, platform): m.d.comb += transaction.runnable.eq(Cat(ready).all()) ccs = _graph_ccs(cgr) - m.submodules._transactron_schedulers = ModuleConnector( - *[self.cc_scheduler(method_map, cgr, cc, porder) for cc in ccs] - ) method_enables = self._method_enables(method_map) for method, transactions in method_map.transactions_by_method.items(): - granted = Cat(transaction.grant & method_enables[transaction][method] for transaction in transactions) + granted = Cat(transaction.run & method_enables[transaction][method] for transaction in transactions) m.d.comb += method.run.eq(granted.any()) (method_args, method_runs) = self._method_calls(m, method_map) @@ -369,13 +369,17 @@ def elaborate(self, platform): runs = Cat(method_runs[method]) m.d.comb += assign(method.data_in, method.combiner(m, method_args[method], runs), fields=AssignType.ALL) + m.submodules._transactron_schedulers = ModuleConnector( + *[self.cc_scheduler(method_map, cgr, cc, porder) for cc in ccs] + ) + if "TRANSACTRON_VERBOSE" in environ: self.print_info(cgr, porder, ccs, method_map) return m def print_info( - self, cgr: TransactionGraph, porder: PriorityOrder, ccs: list[GraphCC["Transaction"]], method_map: MethodMap + self, cgr: TransactionGraph, porder: PriorityOrder, ccs: list[GraphCC["TBody"]], method_map: MethodMap ): print("Transactron statistics") print(f"\tMethods: {len(method_map.methods)}") @@ -425,14 +429,12 @@ def debug_signals(self) -> SignalBundle: method_map = MethodMap(self.transactions) cgr, _ = TransactionManager._conflict_graph(method_map) - def transaction_debug(t: Transaction): + def transaction_debug(t: TBody): return ( - [t.request, t.grant] - + [m.ready for m in method_map.methods_by_transaction[t]] - + [t2.grant for t2 in cgr[t]] + [t.ready, t.run] + [m.ready for m in method_map.methods_by_transaction[t]] + [t2.run for t2 in cgr[t]] ) - def method_debug(m: Method): + def method_debug(m: MBody): return [m.ready, m.run, {t.name: transaction_debug(t) for t in method_map.transactions_by_method[m]}] return { diff --git a/transactron/core/method.py b/transactron/core/method.py index afc678d..7484b86 100644 --- a/transactron/core/method.py +++ b/transactron/core/method.py @@ -1,22 +1,39 @@ from collections.abc import Sequence +import enum from transactron.utils import * from amaranth import * from amaranth import tracer -from typing import Optional, Callable, Iterator, TYPE_CHECKING +from typing import TYPE_CHECKING, Annotated, Optional, Iterator, TypeAlias, TypeVar, Unpack from .transaction_base import * -from .sugar import def_method from contextlib import contextmanager from transactron.utils.assign import AssignArg from transactron.utils._typing import type_self_add_1pos_kwargs_as +from .body import Body, BodyParams, MBody +from .keys import TransactionManagerKey +from .tmodule import TModule +from .transaction_base import TransactionBase + + if TYPE_CHECKING: - from .tmodule import TModule + from .transaction import Transaction # noqa: 
F401 + + +__all__ = ["MethodDir", "Provided", "Required", "Method", "Methods"] + + +class MethodDir(enum.Enum): + PROVIDED = enum.auto() + REQUIRED = enum.auto() -__all__ = ["Method", "Methods"] +_T = TypeVar("_T") +Provided: TypeAlias = Annotated[_T, MethodDir.PROVIDED] +Required: TypeAlias = Annotated[_T, MethodDir.REQUIRED] -class Method(TransactionBase): + +class Method(TransactionBase["Transaction | Method"]): """Transactional method. A `Method` serves to interface a module with external `Transaction`\\s @@ -55,16 +72,10 @@ class Method(TransactionBase): calling `body`. """ + _body_ptr: Optional["Body | Method"] = None + def __init__( - self, - *, - name: Optional[str] = None, - i: MethodLayout = (), - o: MethodLayout = (), - nonexclusive: bool = False, - combiner: Optional[Callable[[Module, Sequence[MethodStruct], Value], AssignArg]] = None, - single_caller: bool = False, - src_loc: int | SrcLoc = 0, + self, *, name: Optional[str] = None, i: MethodLayout = (), o: MethodLayout = (), src_loc: int | SrcLoc = 0 ): """ Parameters @@ -76,45 +87,17 @@ def __init__( The format of `data_in`. o: method layout The format of `data_out`. - nonexclusive: bool - If true, the method is non-exclusive: it can be called by multiple - transactions in the same clock cycle. If such a situation happens, - the method still is executed only once, and each of the callers - receive its output. Nonexclusive methods cannot have inputs. - combiner: (Module, Sequence[MethodStruct], Value) -> AssignArg - If `nonexclusive` is true, the combiner function combines the - arguments from multiple calls to this method into a single - argument, which is passed to the method body. The third argument - is a bit vector, whose n-th bit is 1 if the n-th call is active - in a given cycle. - single_caller: bool - If true, this method is intended to be called from a single - transaction. An error will be thrown if called from multiple - transactions. src_loc: int | SrcLoc How many stack frames deep the source location is taken from. Alternatively, the source location to use instead of the default. 
""" super().__init__(src_loc=get_src_loc(src_loc)) - - def default_combiner(m: Module, args: Sequence[MethodStruct], runs: Value) -> AssignArg: - ret = Signal(from_method_layout(i)) - for k in OneHotSwitchDynamic(m, runs): - m.d.comb += ret.eq(args[k]) - return ret - self.owner, owner_name = get_caller_class_name(default="$method") self.name = name or tracer.get_var_name(depth=2, default=owner_name) self.ready = Signal(name=self.owned_name + "_ready") self.run = Signal(name=self.owned_name + "_run") - self.data_in: MethodStruct = Signal(from_method_layout(i)) - self.data_out: MethodStruct = Signal(from_method_layout(o)) - self.nonexclusive = nonexclusive - self.combiner: Callable[[Module, Sequence[MethodStruct], Value], AssignArg] = combiner or default_combiner - self.single_caller = single_caller - self.validate_arguments: Optional[Callable[..., ValueLike]] = None - if nonexclusive: - assert len(self.data_in.as_value()) == 0 or combiner is not None + self.data_in: MethodStruct = Signal(from_method_layout(i), name=self.owned_name + "_data_in") + self.data_out: MethodStruct = Signal(from_method_layout(o), name=self.owned_name + "_data_out") @property def layout_in(self): @@ -125,7 +108,7 @@ def layout_out(self): return self.data_out.shape() @staticmethod - def like(other: "Method", *, name: Optional[str] = None, src_loc: int | SrcLoc = 0) -> "Method": + def like(other: "Method", *, name: Optional[str] = None) -> "Method": """Constructs a new `Method` based on another. The returned `Method` has the same input/output data layouts as the @@ -137,18 +120,35 @@ def like(other: "Method", *, name: Optional[str] = None, src_loc: int | SrcLoc = The `Method` which serves as a blueprint for the new `Method`. name : str, optional Name of the new `Method`. - src_loc: int | SrcLoc - How many stack frames deep the source location is taken from. - Alternatively, the source location to use instead of the default. Returns ------- Method The freshly constructed `Method`. """ - return Method(name=name, i=other.layout_in, o=other.layout_out, src_loc=get_src_loc(src_loc)) + return Method(name=name, i=other.layout_in, o=other.layout_out) - def proxy(self, m: "TModule", method: "Method"): + @property + def _body(self) -> MBody: + if isinstance(self._body_ptr, Body): + return MBody(self._body_ptr) + if isinstance(self._body_ptr, Method): + self._body_ptr = self._body_ptr._body + return self._body_ptr + raise RuntimeError(f"Method '{self.name}' not defined") + + def _set_impl(self, m: TModule, value: "Body | Method"): + if self._body_ptr is not None: + raise RuntimeError(f"Method '{self.name}' already defined") + if value.data_in.shape() != self.layout_in or value.data_out.shape() != self.layout_out: + raise ValueError(f"Method {value.name} has different interface than {self.name}") + self._body_ptr = value + m.d.comb += self.ready.eq(value.ready) + m.d.comb += self.run.eq(value.run) + m.d.comb += self.data_in.eq(value.data_in) + m.d.comb += self.data_out.eq(value.data_out) + + def proxy(self, m: TModule, method: "Method"): """Define as a proxy for another method. The calls to this method will be forwarded to `method`. @@ -161,19 +161,11 @@ def proxy(self, m: "TModule", method: "Method"): method : Method Method for which this method is a proxy for. 
""" - - @def_method(m, self, ready=method.ready) - def _(arg): - return method(m, arg) + self._set_impl(m, method) @contextmanager def body( - self, - m: "TModule", - *, - ready: ValueLike = C(1), - out: ValueLike = C(0, 0), - validate_arguments: Optional[Callable[..., ValueLike]] = None, + self, m: TModule, *, ready: ValueLike = C(1), out: ValueLike = C(0, 0), **kwargs: Unpack[BodyParams] ) -> Iterator[MethodStruct]: """Define method body @@ -203,6 +195,21 @@ def body( It instantiates a combinational circuit for each method caller. By default, there is no function, so all arguments are accepted. + combiner: (Module, Sequence[MethodStruct], Value) -> AssignArg + If `nonexclusive` is true, the combiner function combines the + arguments from multiple calls to this method into a single + argument, which is passed to the method body. The third argument + is a bit vector, whose n-th bit is 1 if the n-th call is active + in a given cycle. + nonexclusive: bool + If true, the method is non-exclusive: it can be called by multiple + transactions in the same clock cycle. If such a situation happens, + the method still is executed only once, and each of the callers + receive its output. Nonexclusive methods cannot have inputs. + single_caller: bool + If true, this method is intended to be called from a single + transaction. An error will be thrown if called from multiple + transactions. Returns ------- @@ -221,24 +228,22 @@ def body( with my_sum_method.body(m, out = sum) as data_in: m.d.comb += sum.eq(data_in.arg1 + data_in.arg2) """ - if self.defined: - raise RuntimeError(f"Method '{self.name}' already defined") - self.def_order = next(TransactionBase.def_counter) - self.validate_arguments = validate_arguments + body = Body( + name=self.name, owner=self.owner, i=self.layout_in, o=self.layout_out, src_loc=self.src_loc, **kwargs + ) + self._set_impl(m, body) - m.d.av_comb += self.ready.eq(ready) - m.d.top_comb += self.data_out.eq(out) - with self.context(m): - with m.AvoidedIf(self.run): - yield self.data_in + m.d.av_comb += body.ready.eq(ready) + m.d.top_comb += body.data_out.eq(out) + with body.context(m): + with m.AvoidedIf(body.run): + yield body.data_in - def _validate_arguments(self, arg_rec: MethodStruct) -> ValueLike: - if self.validate_arguments is not None: - return self.ready & method_def_helper(self, self.validate_arguments, arg_rec) - return self.ready + manager = DependencyContext.get().get_dependency(TransactionManagerKey()) + manager._add_method(self) def __call__( - self, m: "TModule", arg: Optional[AssignArg] = None, enable: ValueLike = C(1), /, **kwargs: AssignArg + self, m: TModule, arg: Optional[AssignArg] = None, enable: ValueLike = C(1), /, **kwargs: AssignArg ) -> MethodStruct: """Call a method. 
@@ -298,7 +303,7 @@ def __call__( m.d.av_comb += enable_sig.eq(enable) m.d.top_comb += assign(arg_rec, arg, fields=AssignType.ALL) - caller = TransactionBase.get() + caller = Body.get() if not all(ctrl_path.exclusive_with(m.ctrl_path) for ctrl_path, _, _ in caller.method_calls[self]): raise RuntimeError(f"Method '{self.name}' can't be called twice from the same caller '{caller.name}'") caller.method_calls[self].append((m.ctrl_path, arg_rec, enable_sig)) @@ -339,7 +344,7 @@ def layout_out(self): return self._methods[0].layout_out def __call__( - self, m: "TModule", arg: Optional[AssignArg] = None, enable: ValueLike = C(1), /, **kwargs: AssignArg + self, m: TModule, arg: Optional[AssignArg] = None, enable: ValueLike = C(1), /, **kwargs: AssignArg ) -> MethodStruct: if len(self._methods) != 1: raise RuntimeError("calling Methods only allowed when count=1") diff --git a/transactron/core/schedulers.py b/transactron/core/schedulers.py index 856d445..ab12075 100644 --- a/transactron/core/schedulers.py +++ b/transactron/core/schedulers.py @@ -38,9 +38,9 @@ def eager_deterministic_cc_scheduler( ccl = list(cc) ccl.sort(key=lambda transaction: porder[transaction]) for k, transaction in enumerate(ccl): - conflicts = [ccl[j].grant for j in range(k) if ccl[j] in gr[transaction]] + conflicts = [ccl[j].run for j in range(k) if ccl[j] in gr[transaction]] noconflict = ~Cat(conflicts).any() - m.d.comb += transaction.grant.eq(transaction.request & transaction.runnable & noconflict) + m.d.comb += transaction.run.eq(transaction.ready & transaction.runnable & noconflict) return m @@ -72,6 +72,6 @@ def trivial_roundrobin_cc_scheduler( sched = Scheduler(len(cc)) m.submodules.scheduler = sched for k, transaction in enumerate(cc): - m.d.comb += sched.requests[k].eq(transaction.request & transaction.runnable) - m.d.comb += transaction.grant.eq(sched.grant[k] & sched.valid) + m.d.comb += sched.requests[k].eq(transaction.ready & transaction.runnable) + m.d.comb += transaction.run.eq(sched.grant[k] & sched.valid) return m diff --git a/transactron/core/sugar.py b/transactron/core/sugar.py index 640cddb..4b1e599 100644 --- a/transactron/core/sugar.py +++ b/transactron/core/sugar.py @@ -1,13 +1,12 @@ from collections.abc import Sequence, Callable from amaranth import * -from typing import TYPE_CHECKING, Optional, Concatenate, ParamSpec +from typing import Optional, Concatenate, ParamSpec, Unpack +from transactron.core.body import BodyParams from transactron.utils import * from transactron.utils.assign import AssignArg from functools import partial - -if TYPE_CHECKING: - from .tmodule import TModule - from .method import Method +from .tmodule import TModule +from .method import Method __all__ = ["def_method", "def_methods"] @@ -15,12 +14,7 @@ P = ParamSpec("P") -def def_method( - m: "TModule", - method: "Method", - ready: ValueLike = C(1), - validate_arguments: Optional[Callable[..., ValueLike]] = None, -): +def def_method(m: TModule, method: Method, ready: ValueLike = C(1), **kwargs: Unpack[BodyParams]): """Define a method. This decorator allows to define transactional methods in an @@ -45,12 +39,8 @@ def def_method( Signal to indicate if the method is ready to be run. By default it is `Const(1)`, so the method is always ready. Assigned combinationally to the `ready` attribute. - validate_arguments: Optional[Callable[..., ValueLike]] - Function that takes input arguments used to call the method - and checks whether the method can be called with those arguments. 
- It instantiates a combinational circuit for each - method caller. By default, there is no function, so all arguments - are accepted. + **kwargs: BodyParams + For details, see `Method.body`. Examples -------- @@ -86,7 +76,7 @@ def decorator(func: Callable[..., Optional[AssignArg]]): out = Signal(method.layout_out) ret_out = None - with method.body(m, ready=ready, out=out, validate_arguments=validate_arguments) as arg: + with method.body(m, ready=ready, out=out, **kwargs) as arg: ret_out = method_def_helper(method, func, arg) if ret_out is not None: @@ -96,10 +86,10 @@ def decorator(func: Callable[..., Optional[AssignArg]]): def def_methods( - m: "TModule", - methods: Sequence["Method"], + m: TModule, + methods: Sequence[Method], ready: Callable[[int], ValueLike] = lambda _: C(1), - validate_arguments: Optional[Callable[..., ValueLike]] = None, + **kwargs: Unpack[BodyParams], ): """Decorator for defining similar methods @@ -130,12 +120,8 @@ def _(arg): A `Callable` that takes the index in the form of an `int` of the currently defined method and produces a `Value` describing whether the method is ready to be run. When omitted, each defined method is always ready. Assigned combinationally to the `ready` attribute. - validate_arguments: Optional[Callable[Concatenate[int, ...], ValueLike]] - Function that takes input arguments used to call the method - and checks whether the method can be called with those arguments. - It instantiates a combinational circuit for each - method caller. By default, there is no function, so all arguments - are accepted. + **kwargs: BodyParams + For details, see `Method.body`. Examples -------- @@ -174,7 +160,6 @@ def _(_): def decorator(func: Callable[Concatenate[int, P], Optional[RecordDict]]): for i in range(len(methods)): partial_f = partial(func, i) - partial_vargs = partial(validate_arguments, i) if validate_arguments is not None else None - def_method(m, methods[i], ready(i), partial_vargs)(partial_f) + def_method(m, methods[i], ready(i), **kwargs)(partial_f) return decorator diff --git a/transactron/core/transaction.py b/transactron/core/transaction.py index c6f4176..2667029 100644 --- a/transactron/core/transaction.py +++ b/transactron/core/transaction.py @@ -1,19 +1,23 @@ +from amaranth.lib.data import StructLayout from transactron.utils import * from amaranth import * from amaranth import tracer -from typing import Optional, Iterator, TYPE_CHECKING -from .transaction_base import * +from typing import TYPE_CHECKING, Optional, Iterator from .keys import * from contextlib import contextmanager +from .body import Body, TBody +from .tmodule import TModule +from .transaction_base import TransactionBase + if TYPE_CHECKING: - from .tmodule import TModule - from .manager import TransactionManager + from .method import Method # noqa: F401 + __all__ = ["Transaction"] -class Transaction(TransactionBase): +class Transaction(TransactionBase["Transaction | Method"]): """Transaction. A `Transaction` represents a task which needs to be regularly done. @@ -51,9 +55,9 @@ class Transaction(TransactionBase): and all used methods are called. """ - def __init__( - self, *, name: Optional[str] = None, manager: Optional["TransactionManager"] = None, src_loc: int | SrcLoc = 0 - ): + _body_ptr: Optional["Body"] = None + + def __init__(self, *, name: Optional[str] = None, src_loc: int | SrcLoc = 0): """ Parameters ---------- @@ -62,9 +66,6 @@ def __init__( inferred from the variable name this `Transaction` is assigned to. 
If the `Transaction` was not assigned, the name is inferred from the class name where the `Transaction` was constructed. - manager: TransactionManager - The `TransactionManager` controlling this `Transaction`. - If omitted, the manager is received from `TransactionContext`. src_loc: int | SrcLoc How many stack frames deep the source location is taken from. Alternatively, the source location to use instead of the default. @@ -72,15 +73,30 @@ def __init__( super().__init__(src_loc=get_src_loc(src_loc)) self.owner, owner_name = get_caller_class_name(default="$transaction") self.name = name or tracer.get_var_name(depth=2, default=owner_name) - if manager is None: - manager = DependencyContext.get().get_dependency(TransactionManagerKey()) - manager.add_transaction(self) + manager = DependencyContext.get().get_dependency(TransactionManagerKey()) + manager._add_transaction(self) self.request = Signal(name=self.owned_name + "_request") self.runnable = Signal(name=self.owned_name + "_runnable") self.grant = Signal(name=self.owned_name + "_grant") + @property + def _body(self) -> TBody: + if self._body_ptr is not None: + return TBody(self._body_ptr) + raise RuntimeError(f"Method '{self.name}' not defined") + + def _set_impl(self, m: TModule, value: Body): + if self._body_ptr is not None: + raise RuntimeError(f"Transaction '{self.name}' already defined") + if value.data_in.shape().size != 0 or value.data_out.shape().size != 0: + raise ValueError(f"Transaction body {value.name} has invalid interface") + self._body_ptr = value + m.d.comb += self.request.eq(value.ready) + m.d.comb += self.runnable.eq(value.runnable) + m.d.comb += self.grant.eq(value.run) + @contextmanager - def body(self, m: "TModule", *, request: ValueLike = C(1)) -> Iterator["Transaction"]: + def body(self, m: TModule, *, request: ValueLike = C(1)) -> Iterator["Transaction"]: """Defines the `Transaction` body. This context manager allows to conveniently define the actions @@ -99,13 +115,18 @@ def body(self, m: "TModule", *, request: ValueLike = C(1)) -> Iterator["Transact default it is `Const(1)`, so it wants to be executed in every clock cycle. 
""" - if self.defined: - raise RuntimeError(f"Transaction '{self.name}' already defined") - self.def_order = next(TransactionBase.def_counter) - - m.d.av_comb += self.request.eq(request) - with self.context(m): - with m.AvoidedIf(self.grant): + impl = Body( + name=self.name, + owner=self.owner, + i=StructLayout({}), + o=StructLayout({}), + src_loc=self.src_loc, + ) + self._set_impl(m, impl) + + m.d.av_comb += impl.ready.eq(request) + with impl.context(m): + with m.AvoidedIf(impl.run): yield self def __repr__(self) -> str: diff --git a/transactron/core/transaction_base.py b/transactron/core/transaction_base.py index 2f01118..acf272c 100644 --- a/transactron/core/transaction_base.py +++ b/transactron/core/transaction_base.py @@ -1,34 +1,21 @@ -from collections import defaultdict -from collections.abc import Iterator -from contextlib import contextmanager from enum import Enum, auto -from itertools import count +from dataclasses import KW_ONLY, dataclass from typing import ( - ClassVar, - TypeAlias, - TypedDict, - Union, - TypeVar, + Generic, Protocol, - Self, + TypeVar, runtime_checkable, - TYPE_CHECKING, - Optional, ) from amaranth import * -from .tmodule import TModule, CtrlPath from transactron.graph import Owned from transactron.utils import * -if TYPE_CHECKING: - from .method import Method - from .transaction import Transaction __all__ = ["TransactionBase", "Priority"] -TransactionOrMethod: TypeAlias = Union["Transaction", "Method"] -TransactionOrMethodBound = TypeVar("TransactionOrMethodBound", "Transaction", "Method") + +_T = TypeVar("_T", bound="TransactionBase") class Priority(Enum): @@ -40,41 +27,35 @@ class Priority(Enum): RIGHT = auto() -class RelationBase(TypedDict): - end: TransactionOrMethod - priority: Priority - conflict: bool - silence_warning: bool +@dataclass +class RelationBase(Generic[_T]): + _: KW_ONLY + end: _T + priority: Priority = Priority.UNDEFINED + conflict: bool = False + silence_warning: bool = False -class Relation(RelationBase): - start: TransactionOrMethod +@dataclass +class Relation(RelationBase[_T], Generic[_T]): + _: KW_ONLY + start: _T @runtime_checkable -class TransactionBase(Owned, Protocol): - stack: ClassVar[list[Union["Transaction", "Method"]]] = [] - def_counter: ClassVar[count] = count() - def_order: int - defined: bool = False - name: str +class TransactionBase(Owned, Protocol, Generic[_T]): src_loc: SrcLoc - method_uses: dict["Method", tuple[MethodStruct, Signal]] - method_calls: defaultdict["Method", list[tuple[CtrlPath, MethodStruct, ValueLike]]] - relations: list[RelationBase] - simultaneous_list: list[TransactionOrMethod] - independent_list: list[TransactionOrMethod] - ctrl_path: CtrlPath = CtrlPath(-1, []) - - def __init__(self, *, src_loc: int | SrcLoc): - self.src_loc = get_src_loc(src_loc) - self.method_uses = {} - self.method_calls = defaultdict(list) + relations: list[RelationBase[_T]] + simultaneous_list: list[_T] + independent_list: list[_T] + + def __init__(self, *, src_loc: SrcLoc): + self.src_loc = src_loc self.relations = [] self.simultaneous_list = [] self.independent_list = [] - def add_conflict(self, end: TransactionOrMethod, priority: Priority = Priority.UNDEFINED) -> None: + def add_conflict(self, end: _T, priority: Priority = Priority.UNDEFINED) -> None: """Registers a conflict. 
Record that that the given `Transaction` or `Method` cannot execute @@ -93,7 +74,7 @@ def add_conflict(self, end: TransactionOrMethod, priority: Priority = Priority.U RelationBase(end=end, priority=priority, conflict=True, silence_warning=self.owner != end.owner) ) - def schedule_before(self, end: TransactionOrMethod) -> None: + def schedule_before(self, end: _T) -> None: """Adds a priority relation. Record that that the given `Transaction` or `Method` needs to be @@ -109,7 +90,7 @@ def schedule_before(self, end: TransactionOrMethod) -> None: RelationBase(end=end, priority=Priority.LEFT, conflict=False, silence_warning=self.owner != end.owner) ) - def simultaneous(self, *others: TransactionOrMethod) -> None: + def simultaneous(self, *others: _T) -> None: """Adds simultaneity relations. The given `Transaction`\\s or `Method``\\s will execute simultaneously @@ -122,7 +103,7 @@ def simultaneous(self, *others: TransactionOrMethod) -> None: """ self.simultaneous_list += others - def simultaneous_alternatives(self, *others: TransactionOrMethod) -> None: + def simultaneous_alternatives(self, *others: _T) -> None: """Adds exclusive simultaneity relations. Each of the given `Transaction`\\s or `Method``\\s will execute @@ -139,7 +120,7 @@ def simultaneous_alternatives(self, *others: TransactionOrMethod) -> None: self.simultaneous(*others) others[0]._independent(*others[1:]) - def _independent(self, *others: TransactionOrMethod) -> None: + def _independent(self, *others: _T) -> None: """Adds independence relations. This `Transaction` or `Method`, together with all the given @@ -157,54 +138,3 @@ def _independent(self, *others: TransactionOrMethod) -> None: for execution. """ self.independent_list += others - - @contextmanager - def context(self: TransactionOrMethodBound, m: TModule) -> Iterator[TransactionOrMethodBound]: - self.ctrl_path = m.ctrl_path - - parent = TransactionBase.peek() - if parent is not None: - parent.schedule_before(self) - - TransactionBase.stack.append(self) - - try: - yield self - finally: - TransactionBase.stack.pop() - self.defined = True - - def _set_method_uses(self, m: ModuleLike): - for method, calls in self.method_calls.items(): - arg_rec, enable_sig = self.method_uses[method] - if len(calls) == 1: - m.d.comb += arg_rec.eq(calls[0][1]) - m.d.comb += enable_sig.eq(calls[0][2]) - else: - call_ens = Cat([en for _, _, en in calls]) - - for i in OneHotSwitchDynamic(m, call_ens): - m.d.comb += arg_rec.eq(calls[i][1]) - m.d.comb += enable_sig.eq(1) - - @classmethod - def get(cls) -> Self: - ret = cls.peek() - if ret is None: - raise RuntimeError("No current body") - return ret - - @classmethod - def peek(cls) -> Optional[Self]: - if not TransactionBase.stack: - return None - if not isinstance(TransactionBase.stack[-1], cls): - raise RuntimeError(f"Current body not a {cls.__name__}") - return TransactionBase.stack[-1] - - @property - def owned_name(self): - if self.owner is not None and self.owner.__class__.__name__ != self.name: - return f"{self.owner.__class__.__name__}_{self.name}" - else: - return self.name diff --git a/transactron/graph.py b/transactron/graph.py index 709ba87..c93e28a 100644 --- a/transactron/graph.py +++ b/transactron/graph.py @@ -14,6 +14,13 @@ class Owned(Protocol): name: str owner: Optional[Elaboratable] + @property + def owned_name(self): + if self.owner is not None and self.owner.__class__.__name__ != self.name: + return f"{self.owner.__class__.__name__}_{self.name}" + else: + return self.name + class Direction(IntFlag): NONE = 0 diff --git 
a/transactron/lib/adapters.py b/transactron/lib/adapters.py index 81816b3..be8b5a8 100644 --- a/transactron/lib/adapters.py +++ b/transactron/lib/adapters.py @@ -1,12 +1,12 @@ from abc import abstractmethod -from typing import Optional +from typing import Optional, Unpack from amaranth import * from amaranth.lib.wiring import Component, In, Out from amaranth.lib.data import StructLayout, View from ..utils import SrcLoc, get_src_loc, MethodStruct from ..core import * -from ..utils._typing import type_self_kwargs_as, SignalBundle +from ..utils._typing import SignalBundle, MethodLayout __all__ = [ "AdapterBase", @@ -100,8 +100,7 @@ class Adapter(AdapterBase): Hooks for `validate_arguments`. """ - @type_self_kwargs_as(Method.__init__) - def __init__(self, **kwargs): + def __init__(self, method: Method, /, **kwargs: Unpack[AdapterBodyParams]): """ Parameters ---------- @@ -110,12 +109,25 @@ def __init__(self, **kwargs): See transactron.core.Method.__init__ for parameters description. """ - kwargs["src_loc"] = get_src_loc(kwargs.setdefault("src_loc", 0)) - - iface = Method(**kwargs) - super().__init__(iface, iface.layout_out, iface.layout_in) + super().__init__(method, method.layout_out, method.layout_in) self.validators: list[tuple[View[StructLayout], Signal]] = [] self.with_validate_arguments: bool = False + self.kwargs = kwargs + + @staticmethod + def create( + name: Optional[str] = None, + i: MethodLayout = [], + o: MethodLayout = [], + src_loc: int | SrcLoc = 0, + **kwargs: Unpack[AdapterBodyParams], + ): + method = Method(name=name, i=i, o=o, src_loc=get_src_loc(src_loc)) + return Adapter(method, **kwargs) + + def update_args(self, **kwargs: Unpack[AdapterBodyParams]): + self.kwargs.update(kwargs) + return self def set(self, with_validate_arguments: Optional[bool]): if with_validate_arguments is not None: @@ -129,7 +141,7 @@ def elaborate(self, platform): data_in = Signal.like(self.data_in) m.d.comb += data_in.eq(self.data_in) - kwargs = {} + kwargs: BodyParams = self.kwargs # type: ignore (pyright complains about optional attribute) if self.with_validate_arguments: diff --git a/transactron/lib/allocators.py b/transactron/lib/allocators.py index 93b1845..1bdf3ae 100644 --- a/transactron/lib/allocators.py +++ b/transactron/lib/allocators.py @@ -93,7 +93,6 @@ def __init__(self, entries: int): self.free_idx = Method(i=[("idx", range(entries))]) self.order = Method( o=[("used", range(entries + 1)), ("order", ArrayLayout(range(self.entries), self.entries))], - nonexclusive=True, ) def elaborate(self, platform) -> TModule: @@ -125,7 +124,7 @@ def _(ident): m.d.comb += idx.eq(i) self.free_idx(m, idx=idx) - @def_method(m, self.order) + @def_method(m, self.order, nonexclusive=True) def _(): return {"used": used, "order": order} diff --git a/transactron/lib/fifo.py b/transactron/lib/fifo.py index 21cccff..17f7264 100644 --- a/transactron/lib/fifo.py +++ b/transactron/lib/fifo.py @@ -51,7 +51,7 @@ def __init__(self, layout: MethodLayout, depth: int, *, src_loc: int | SrcLoc = src_loc = get_src_loc(src_loc) self.read = Method(o=self.layout, src_loc=src_loc) - self.peek = Method(o=self.layout, nonexclusive=True, src_loc=src_loc) + self.peek = Method(o=self.layout, src_loc=src_loc) self.write = Method(i=self.layout, src_loc=src_loc) self.clear = Method(src_loc=src_loc) self.head = Signal(from_method_layout(layout)) @@ -106,7 +106,7 @@ def _() -> ValueLike: m.d.sync += self.read_idx.eq(next_read_idx) return self.head - @def_method(m, self.peek, self.read_ready) + @def_method(m, self.peek, 
self.read_ready, nonexclusive=True) def _() -> ValueLike: return self.head @@ -179,7 +179,7 @@ def __init__( self.depth = depth self.read = Method(i=[("count", range(read_width + 1))], o=self.read_layout, src_loc=src_loc) - self.peek = Method(o=self.read_layout, nonexclusive=True, src_loc=src_loc) + self.peek = Method(o=self.read_layout, src_loc=src_loc) self.write = Method(i=self.write_layout, src_loc=src_loc) self.clear = Method(src_loc=src_loc) @@ -275,7 +275,7 @@ def _(count): m.d.comb += incr_row_col(next_read_row, next_read_col, read_row, read_col, incr_read_row, read_count) return {"count": read_count, "data": head} - @def_method(m, self.peek, level != 0) + @def_method(m, self.peek, level != 0, nonexclusive=True) def _(): return {"count": read_available, "data": head} diff --git a/transactron/lib/metrics.py b/transactron/lib/metrics.py index 78f5c5e..fc1a4f6 100644 --- a/transactron/lib/metrics.py +++ b/transactron/lib/metrics.py @@ -11,6 +11,7 @@ from transactron import Method, def_method, TModule from transactron.lib import FIFO, AsyncMemoryBank, logging from transactron.utils.dependencies import ListKey, DependencyContext, SimpleKey +from transactron.utils.transactron_helpers import make_layout __all__ = [ "MetricRegisterModel", @@ -661,7 +662,7 @@ def elaborate(self, platform): epoch_width = bits_for(self.max_latency) m.submodules.slots = self.slots = AsyncMemoryBank( - data_layout=[("epoch", epoch_width)], elem_count=self.slots_number + shape=make_layout(("epoch", epoch_width)), depth=self.slots_number ) m.submodules.histogram = self.histogram @@ -690,7 +691,7 @@ def _(slot: Value): ret = self.slots.read(m, addr=slot) # The result of substracting two unsigned n-bit is a signed (n+1)-bit value, # so we need to cast the result and discard the most significant bit. - duration = (epoch - ret.epoch).as_unsigned()[:-1] + duration = (epoch - ret.data.epoch).as_unsigned()[:-1] self.histogram.add(m, duration) return m diff --git a/transactron/lib/reqres.py b/transactron/lib/reqres.py index d471125..c05bd69 100644 --- a/transactron/lib/reqres.py +++ b/transactron/lib/reqres.py @@ -67,7 +67,7 @@ def __init__(self, args_layout: MethodLayout, results_layout: MethodLayout, src_ self.args_layout = args_layout self.output_layout = [("args", self.args_layout), ("results", results_layout)] - self.peek_arg = Method(o=self.args_layout, nonexclusive=True, src_loc=self.src_loc) + self.peek_arg = Method(o=self.args_layout, src_loc=self.src_loc) self.write_args = Method(i=self.args_layout, src_loc=self.src_loc) self.write_results = Method(i=self.results_layout, src_loc=self.src_loc) self.read = Method(o=self.output_layout, src_loc=self.src_loc) diff --git a/transactron/lib/simultaneous.py b/transactron/lib/simultaneous.py index 7b00f93..35eca3b 100644 --- a/transactron/lib/simultaneous.py +++ b/transactron/lib/simultaneous.py @@ -1,8 +1,7 @@ from amaranth import * from ..utils import SrcLoc -from ..core import * -from ..core import TransactionBase +from ..core import TModule, Transaction, Body from contextlib import contextmanager from typing import Optional from transactron.utils import ValueLike @@ -56,8 +55,8 @@ def condition(m: TModule, *, nonblocking: bool = False, priority: bool = False): with branch(): # default, optional ... 
""" - this = TransactionBase.get() - transactions = list[Transaction]() + this = Body.get() + transactions = list[Body]() last = False conds = list[Signal]() @@ -73,10 +72,10 @@ def branch(cond: Optional[ValueLike] = None, *, src_loc: int | SrcLoc = 2): with (transaction := Transaction(name=name, src_loc=src_loc)).body(m, request=req): yield if transactions and priority: - transactions[-1].schedule_before(transaction) + transactions[-1].schedule_before(transaction._body) if cond is None: last = True - transactions.append(transaction) + transactions.append(transaction._body) yield branch diff --git a/transactron/lib/storage.py b/transactron/lib/storage.py index 46079d7..d4c7d89 100644 --- a/transactron/lib/storage.py +++ b/transactron/lib/storage.py @@ -1,6 +1,7 @@ from amaranth import * from amaranth.utils import * import amaranth.lib.memory as memory +from amaranth_types import ShapeLike import amaranth_types.memory as amemory from transactron.utils.transactron_helpers import from_method_layout, make_layout @@ -35,25 +36,26 @@ class MemoryBank(Elaboratable): def __init__( self, *, - data_layout: LayoutList, - elem_count: int, + shape: ShapeLike, + depth: int, granularity: Optional[int] = None, transparent: bool = False, read_ports: int = 1, write_ports: int = 1, - memory_type: amemory.AbstractMemoryConstructor[int, Value] = memory.Memory, + memory_type: amemory.AbstractMemoryConstructor[ShapeLike, Value] = memory.Memory, src_loc: int | SrcLoc = 0, ): """ Parameters ---------- - data_layout: method layout + shape: ShapeLike The format of structures stored in the Memory. - elem_count: int + depth: int Number of elements stored in Memory. granularity: Optional[int] - Granularity of write, forwarded to Amaranth. If `None` the whole structure is always saved at once. - If not, the width of `data_layout` is split into `granularity` parts, which can be saved independently. + Granularity of write. If `None` the whole structure is always saved at once. + If not, shape is split into `granularity` parts, which can be saved independently (according to + `amaranth.lib.memory` granularity logic). transparent: bool Read port transparency, false by default. When a read port is transparent, if a given memory address is read and written in the same clock cycle, the read returns the written value instead of the value @@ -67,37 +69,43 @@ def __init__( Alternatively, the source location to use instead of the default. 
""" self.src_loc = get_src_loc(src_loc) - self.data_layout = make_layout(*data_layout) - self.elem_count = elem_count + self.shape = shape + self.depth = depth self.granularity = granularity - self.width = from_method_layout(self.data_layout).size - self.addr_width = bits_for(self.elem_count - 1) + self.addr_width = bits_for(self.depth - 1) self.transparent = transparent self.reads_ports = read_ports self.writes_ports = write_ports self.memory_type = memory_type self.read_reqs_layout: LayoutList = [("addr", self.addr_width)] - write_layout = [("addr", self.addr_width), ("data", self.data_layout)] + self.read_resps_layout = make_layout(("data", self.shape)) + write_layout = [("addr", self.addr_width), ("data", self.shape)] if self.granularity is not None: - write_layout.append(("mask", self.width // self.granularity)) + # use Amaranth lib.memory granularity rule checks and width + amaranth_write_port_sig = memory.WritePort.Signature( + addr_width=0, + shape=self.shape, # type: ignore + granularity=granularity, + ) + write_layout.append(("mask", amaranth_write_port_sig.members["en"].shape)) self.writes_layout = make_layout(*write_layout) self.read_req = Methods(read_ports, i=self.read_reqs_layout, src_loc=self.src_loc) - self.read_resp = Methods(read_ports, o=self.data_layout, src_loc=self.src_loc) + self.read_resp = Methods(read_ports, o=self.read_resps_layout, src_loc=self.src_loc) self.write = Methods(write_ports, i=self.writes_layout, src_loc=self.src_loc) def elaborate(self, platform) -> TModule: m = TModule() - m.submodules.mem = self.mem = mem = self.memory_type(shape=self.width, depth=self.elem_count, init=[]) - write_port = [mem.write_port() for _ in range(self.writes_ports)] + m.submodules.mem = self.mem = mem = self.memory_type(shape=self.shape, depth=self.depth, init=[]) + write_port = [mem.write_port(granularity=self.granularity) for _ in range(self.writes_ports)] read_port = [ mem.read_port(transparent_for=write_port if self.transparent else []) for _ in range(self.reads_ports) ] read_output_valid = [Signal() for _ in range(self.reads_ports)] overflow_valid = [Signal() for _ in range(self.reads_ports)] - overflow_data = [Signal(self.width) for _ in range(self.reads_ports)] + overflow_data = [Signal(self.shape) for _ in range(self.reads_ports)] # The read request method can be called at most twice when not reading the response. # The first result is stored in the overflow buffer, the second - in the read value buffer of the memory. 
@@ -114,7 +122,13 @@ def _(i: int): m.d.sync += overflow_valid[i].eq(0) with m.Else(): m.d.sync += read_output_valid[i].eq(0) - return Mux(overflow_valid[i], overflow_data[i], read_port[i].data) + + ret = Signal(self.shape) + with m.If(overflow_valid[i]): + m.d.av_comb += ret.eq(overflow_data[i]) + with m.Else(): + m.d.av_comb += ret.eq(read_port[i].data) + return {"data": ret} for i in range(self.reads_ports): m.d.comb += read_port[i].en.eq(0) # because the init value is 1 @@ -123,12 +137,12 @@ def _(i: int): def _(i: int, addr): m.d.sync += read_output_valid[i].eq(1) m.d.comb += read_port[i].en.eq(1) - m.d.comb += read_port[i].addr.eq(addr) + m.d.av_comb += read_port[i].addr.eq(addr) @def_methods(m, self.write) def _(i: int, arg): - m.d.comb += write_port[i].addr.eq(arg.addr) - m.d.comb += write_port[i].data.eq(arg.data) + m.d.av_comb += write_port[i].addr.eq(arg.addr) + m.d.av_comb += write_port[i].data.eq(arg.data) if self.granularity is None: m.d.comb += write_port[i].en.eq(1) else: @@ -254,24 +268,25 @@ class AsyncMemoryBank(Elaboratable): def __init__( self, *, - data_layout: LayoutList, - elem_count: int, + shape: ShapeLike, + depth: int, granularity: Optional[int] = None, read_ports: int = 1, write_ports: int = 1, - memory_type: amemory.AbstractMemoryConstructor[int, Value] = memory.Memory, + memory_type: amemory.AbstractMemoryConstructor[ShapeLike, Value] = memory.Memory, src_loc: int | SrcLoc = 0, ): """ Parameters ---------- - data_layout: method layout + shape: ShapeLike The format of structures stored in the Memory. - elem_count: int + depth: int Number of elements stored in Memory. granularity: Optional[int] - Granularity of write, forwarded to Amaranth. If `None` the whole structure is always saved at once. - If not, the width of `data_layout` is split into `granularity` parts, which can be saved independently. + Granularity of write. If `None` the whole structure is always saved at once. + If not, shape is split into `granularity` parts, which can be saved independently (according to + `amaranth.lib.memory` granularity logic). read_ports: int Number of read ports. write_ports: int @@ -281,36 +296,42 @@ def __init__( Alternatively, the source location to use instead of the default. 
""" self.src_loc = get_src_loc(src_loc) - self.data_layout = make_layout(*data_layout) - self.elem_count = elem_count + self.shape = shape + self.depth = depth self.granularity = granularity - self.width = from_method_layout(self.data_layout).size - self.addr_width = bits_for(self.elem_count - 1) + self.addr_width = bits_for(self.depth - 1) self.reads_ports = read_ports self.writes_ports = write_ports self.memory_type = memory_type self.read_reqs_layout: LayoutList = [("addr", self.addr_width)] - write_layout = [("addr", self.addr_width), ("data", self.data_layout)] + self.read_resps_layout: LayoutList = [("data", self.shape)] + write_layout = [("addr", self.addr_width), ("data", self.shape)] if self.granularity is not None: - write_layout.append(("mask", self.width // self.granularity)) + # use Amaranth lib.memory granularity rule checks and width + amaranth_write_port_sig = memory.WritePort.Signature( + addr_width=0, + shape=shape, # type: ignore + granularity=granularity, + ) + write_layout.append(("mask", amaranth_write_port_sig.members["en"].shape)) self.writes_layout = make_layout(*write_layout) - self.read = Methods(read_ports, i=self.read_reqs_layout, o=self.data_layout, src_loc=self.src_loc) + self.read = Methods(read_ports, i=self.read_reqs_layout, o=self.read_resps_layout, src_loc=self.src_loc) self.write = Methods(write_ports, i=self.writes_layout, src_loc=self.src_loc) def elaborate(self, platform) -> TModule: m = TModule() - mem = self.memory_type(shape=self.width, depth=self.elem_count, init=[]) + mem = self.memory_type(shape=self.shape, depth=self.depth, init=[]) m.submodules.mem = self.mem = mem - write_port = [mem.write_port() for _ in range(self.writes_ports)] + write_port = [mem.write_port(granularity=self.granularity) for _ in range(self.writes_ports)] read_port = [mem.read_port(domain="comb") for _ in range(self.reads_ports)] @def_methods(m, self.read) def _(i: int, addr): m.d.comb += read_port[i].addr.eq(addr) - return read_port[i].data + return {"data": read_port[i].data} @def_methods(m, self.write) def _(i: int, arg): diff --git a/transactron/lib/transformers.py b/transactron/lib/transformers.py index d3198d7..0a75cef 100644 --- a/transactron/lib/transformers.py +++ b/transactron/lib/transformers.py @@ -177,7 +177,7 @@ def __init__( self.target = target self.use_condition = use_condition src_loc = get_src_loc(src_loc) - self.method = Method(i=target.layout_in, o=target.layout_out, single_caller=self.use_condition, src_loc=src_loc) + self.method = Method(i=target.layout_in, o=target.layout_out, src_loc=src_loc) self.condition = condition self.default = default @@ -189,7 +189,7 @@ def elaborate(self, platform): ret = Signal.like(self.target.data_out) m.d.comb += assign(ret, self.default, fields=AssignType.ALL) - @def_method(m, self.method) + @def_method(m, self.method, single_caller=self.use_condition) def _(arg): if self.use_condition: cond = Signal() diff --git a/transactron/testing/infrastructure.py b/transactron/testing/infrastructure.py index b65ffd7..43edd97 100644 --- a/transactron/testing/infrastructure.py +++ b/transactron/testing/infrastructure.py @@ -10,6 +10,8 @@ from amaranth import * from amaranth.sim import * from amaranth.sim._async import SimulatorContext +from transactron.core.method import MethodDir +from transactron.lib.adapters import Adapter from transactron.utils.dependencies import DependencyContext, DependencyManager from .testbenchio import TestbenchIO @@ -58,6 +60,7 @@ def __getattr__(self, name: str) -> Any: def elaborate(self, platform): 
def transform_methods_to_testbenchios( + adapter_type: type[Adapter] | type[AdapterTrans], container: _T_nested_collection[Method | Methods], ) -> tuple[ _T_nested_collection["TestbenchIO"], @@ -67,7 +70,7 @@ def transform_methods_to_testbenchios( tb_list = [] mc_list = [] for elem in container: - tb, mc = transform_methods_to_testbenchios(elem) + tb, mc = transform_methods_to_testbenchios(adapter_type, elem) tb_list.append(tb) mc_list.append(mc) return tb_list, ModuleConnector(*mc_list) @@ -75,24 +78,29 @@ def transform_methods_to_testbenchios( tb_dict = {} mc_dict = {} for name, elem in container.items(): - tb, mc = transform_methods_to_testbenchios(elem) + tb, mc = transform_methods_to_testbenchios(adapter_type, elem) tb_dict[name] = tb mc_dict[name] = mc return tb_dict, ModuleConnector(*mc_dict) elif isinstance(container, Methods): - tb_list = [TestbenchIO(AdapterTrans(method)) for method in container] + tb_list = [TestbenchIO(adapter_type(method)) for method in container] return list(tb_list), ModuleConnector(*tb_list) else: - tb = TestbenchIO(AdapterTrans(container)) + tb = TestbenchIO(adapter_type(container)) return tb, tb m = Module() m.submodules.dut = self._dut + hints = self._dut.__class__.__annotations__ for name, attr in vars(self._dut).items(): if guard_nested_collection(attr, Method | Methods) and attr: - tb_cont, mc = transform_methods_to_testbenchios(attr) + if name in hints and MethodDir.REQUIRED in hints[name].__metadata__: + adapter_type = Adapter + else: # PROVIDED is the default + adapter_type = AdapterTrans + tb_cont, mc = transform_methods_to_testbenchios(adapter_type, attr) self._io[name] = tb_cont m.submodules[name] = mc @@ -166,20 +174,16 @@ def __init__( else: self.ctx = nullcontext() - self.timeouted = False - async def timeout_testbench(sim: SimulatorContext): await sim.delay(clk_period * max_cycles) - self.timeouted = True + assert False, "simulation timed out" self.add_testbench(timeout_testbench, background=True) - def run(self) -> bool: + def run(self) -> None: with self.ctx: super().run() - return not self.timeouted - class TestCaseWithSimulator: dependency_manager: DependencyManager @@ -316,8 +320,7 @@ def run_simulation(self, module: HasElaborate, max_cycles: float = 10e4, add_tra if ret is not None: sim.add_process(ret) - res = sim.run() - assert res, "Simulation time limit exceeded" + sim.run() async def tick(self, sim: SimulatorContext, cycle_cnt: int = 1): """ diff --git a/transactron/testing/method_mock.py b/transactron/testing/method_mock.py index f3fdfe0..b6dc0a4 100644 --- a/transactron/testing/method_mock.py +++ b/transactron/testing/method_mock.py @@ -1,8 +1,9 @@ from contextlib import contextmanager import functools -from typing import Callable, Any, Optional +from typing import Callable, Any, Optional, Unpack from amaranth.sim._async import SimulatorContext +from transactron.core.body import AdapterBodyParams from transactron.lib.adapters import Adapter, AdapterBase from transactron.utils.transactron_helpers import async_mock_def_helper from .testbenchio import TestbenchIO @@ -21,7 +22,13 @@ def __init__( validate_arguments: Optional[Callable[..., bool]] = None, enable: Callable[[], bool] = lambda: True, delay: float = 0, + **kwargs: Unpack[AdapterBodyParams], ): + if isinstance(adapter, Adapter): + adapter.set(with_validate_arguments=validate_arguments is not None).update_args(**kwargs) + else: + assert validate_arguments is None + assert kwargs == {} self.adapter = adapter self.function = function self.validate_arguments = 
validate_arguments diff --git a/transactron/testing/profiler.py b/transactron/testing/profiler.py index ace2b63..d699c15 100644 --- a/transactron/testing/profiler.py +++ b/transactron/testing/profiler.py @@ -20,7 +20,7 @@ async def process(sim: ProcessContext) -> None: sim.tick() .sample( *( - View(transaction_sample_layout, Cat(transaction.request, transaction.runnable, transaction.grant)) + View(transaction_sample_layout, Cat(transaction.ready, transaction.runnable, transaction.run)) for transaction in method_map.transactions ) ) diff --git a/transactron/testing/testbenchio.py b/transactron/testing/testbenchio.py index 0553184..2a769cf 100644 --- a/transactron/testing/testbenchio.py +++ b/transactron/testing/testbenchio.py @@ -88,6 +88,12 @@ async def until_done(self) -> Any: if any(res is not None for res in results): return results + async def until_all_done(self) -> Any: + """Same as `until_done` but wait for all results instead of any result.""" + async for results in self: + if all(res is not None for res in results): + return results + def __await__(self) -> Generator: only_calls = [t for t in self.calls_and_values if isinstance(t, tuple)] only_values = [t for t in self.calls_and_values if not isinstance(t, tuple)] diff --git a/transactron/utils/amaranth_ext/functions.py b/transactron/utils/amaranth_ext/functions.py index d046240..b273c0a 100644 --- a/transactron/utils/amaranth_ext/functions.py +++ b/transactron/utils/amaranth_ext/functions.py @@ -13,6 +13,7 @@ "popcount", "count_leading_zeros", "count_trailing_zeros", + "cyclic_mask", "flatten_signals", "shape_of", "const_of", @@ -76,6 +77,26 @@ def count_trailing_zeros(s: Value) -> Value: return count_leading_zeros(s[::-1]) +def cyclic_mask(bits: int, start: Value, end: Value): + """ + Generate `bits` bit-wide mask with ones from `start` to `end` position, including both ends. + If `end` value is < than `start` the mask wraps around. + """ + start = start.as_unsigned() + end = end.as_unsigned() + + # start <= end + length = (end - start + 1).as_unsigned() + mask_se = ((1 << length) - 1) << start + + # start > end + left = (1 << (end + 1)) - 1 + right = (1 << ((bits - start).as_unsigned())) - 1 + mask_es = left | (right << start) + + return Mux(start <= end, mask_se, mask_es) + + def flatten_signals(signals: SignalBundle) -> Iterable[Signal]: """ Flattens input data, which can be either a signal, a record, a list (or a dict) of SignalBundle items. 
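The `cyclic_mask` helper added above builds an Amaranth expression. A short sketch of its intended semantics, with widths and bit positions chosen only for illustration:

```python
from amaranth import C

from transactron.utils.amaranth_ext.functions import cyclic_mask

# start <= end: a contiguous run of ones from bit `start` to bit `end`.
# In simulation, cyclic_mask(8, C(2), C(5)) evaluates to 0b00111100.
plain = cyclic_mask(8, C(2), C(5))

# start > end: the run of ones wraps around the word boundary.
# In simulation, cyclic_mask(8, C(6), C(1)) evaluates to 0b11000011.
wrapped = cyclic_mask(8, C(6), C(1))
```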
diff --git a/transactron/utils/gen.py b/transactron/utils/gen.py index 780e151..cc3ad6d 100644 --- a/transactron/utils/gen.py +++ b/transactron/utils/gen.py @@ -193,9 +193,9 @@ def collect_transaction_method_signals( get_id = IdGenerator() for transaction in method_map.transactions: - request_loc = get_signal_location(transaction.request, name_map) + request_loc = get_signal_location(transaction.ready, name_map) runnable_loc = get_signal_location(transaction.runnable, name_map) - grant_loc = get_signal_location(transaction.grant, name_map) + grant_loc = get_signal_location(transaction.run, name_map) transaction_signals_location[get_id(transaction)] = TransactionSignalsLocation( request_loc, runnable_loc, grant_loc ) diff --git a/transactron/utils/transactron_helpers.py b/transactron/utils/transactron_helpers.py index 9cb23cd..b438825 100644 --- a/transactron/utils/transactron_helpers.py +++ b/transactron/utils/transactron_helpers.py @@ -9,6 +9,7 @@ from amaranth import tracer from amaranth.lib.data import StructLayout import amaranth.lib.data as data +import dataclasses __all__ = [ @@ -23,6 +24,7 @@ "from_method_layout", "make_layout", "extend_layout", + "dataclass_asdict", ] T = TypeVar("T") @@ -167,3 +169,9 @@ def from_method_layout(layout: MethodLayout) -> StructLayout: return layout else: return StructLayout({k: from_layout_field(v) for k, v in layout}) + + +def dataclass_asdict(obj: Any) -> dict[str, Any]: + # Workaround for dataclass.asdict calling deepcopy without a reason, see: + # https://github.com/python/cpython/issues/88071 + return {field.name: getattr(obj, field.name) for field in dataclasses.fields(obj)}
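The new `dataclass_asdict` helper performs a shallow field-by-field conversion instead of the deep copy done by `dataclasses.asdict`. A minimal sketch of the difference, using a made-up `Point` dataclass:

```python
from dataclasses import asdict, dataclass

from transactron.utils.transactron_helpers import dataclass_asdict


@dataclass
class Point:
    x: int
    y: int


p = Point(1, 2)

# dataclasses.asdict recursively deep-copies every field value
# (see the CPython issue referenced in the comment above).
deep = asdict(p)

# dataclass_asdict only reads the attributes, so field values are
# shared with the original object rather than copied.
shallow = dataclass_asdict(p)
assert shallow == {"x": 1, "y": 2} == deep
```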