diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index b9af6482b..fc4f73052 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -36,9 +36,6 @@ jobs: # https://github.com/actions/runner/issues/2033 chown -R $(id -u):$(id -g) $PWD - - name: Checkout submodules - run: git submodule update --init --recursive amaranth-stubs - - name: Set up Python uses: actions/setup-python@v5 with: @@ -114,9 +111,6 @@ jobs: # https://github.com/actions/runner/issues/2033 chown -R $(id -u):$(id -g) $PWD - - name: Checkout submodules - run: git submodule update --init --recursive amaranth-stubs - - name: Set up Python uses: actions/setup-python@v5 with: @@ -165,9 +159,6 @@ jobs: # https://github.com/actions/runner/issues/2033 chown -R $(id -u):$(id -g) $PWD - - name: Checkout submodules - run: git submodule update --init --recursive amaranth-stubs - - name: Set up Python uses: actions/setup-python@v5 with: diff --git a/.github/workflows/deploy_gh_pages.yml b/.github/workflows/deploy_gh_pages.yml index 07dad5c22..eaf35d90a 100644 --- a/.github/workflows/deploy_gh_pages.yml +++ b/.github/workflows/deploy_gh_pages.yml @@ -23,9 +23,6 @@ jobs: - name: Checkout uses: actions/checkout@v4 - - name: Checkout submodules - run: git submodule update --init --recursive amaranth-stubs - - name: Set up Python uses: actions/setup-python@v5 with: diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 6609cac4d..3d3650c57 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -20,9 +20,6 @@ jobs: - name: Checkout uses: actions/checkout@v4 - - name: Checkout submodules - run: git submodule update --init --recursive amaranth-stubs - - name: Set up Python uses: actions/setup-python@v5 with: @@ -153,9 +150,6 @@ jobs: git config --global --add safe.directory /__w/coreblocks/coreblocks git submodule > .gitmodules-hash - - name: Checkout submodules - run: git submodule update --init --recursive amaranth-stubs - - name: Set up Python uses: actions/setup-python@v5 with: @@ -263,9 +257,6 @@ jobs: git config --global --add safe.directory /__w/coreblocks/coreblocks git submodule > .gitmodules-hash - - name: Checkout submodules - run: git submodule update --init --recursive amaranth-stubs - - name: Set up Python uses: actions/setup-python@v5 with: @@ -318,9 +309,6 @@ jobs: - name: Checkout uses: actions/checkout@v4 - - name: Checkout submodules - run: git submodule update --init --recursive amaranth-stubs - - name: Set up Python uses: actions/setup-python@v5 with: @@ -353,9 +341,6 @@ jobs: - name: Checkout uses: actions/checkout@v4 - - name: Checkout submodules - run: git submodule update --init --recursive amaranth-stubs - - name: Set up Python uses: actions/setup-python@v5 with: diff --git a/.gitmodules b/.gitmodules index e1e6ec15d..8dea05eb8 100644 --- a/.gitmodules +++ b/.gitmodules @@ -8,6 +8,3 @@ [submodule "test/external/riscof/riscv-arch-test"] path = test/external/riscof/riscv-arch-test url = https://github.com/riscv-non-isa/riscv-arch-test.git -[submodule "amaranth-stubs"] - path = amaranth-stubs - url = https://github.com/kuznia-rdzeni/amaranth-stubs.git diff --git a/README.md b/README.md index dcd7fb056..5a3b4d0ac 100644 --- a/README.md +++ b/README.md @@ -9,7 +9,7 @@ Coreblocks is an experimental, modular out-of-order [RISC-V](https://riscv.org/s * Simplicity. Coreblocks is an academic project, accessible to students. It should be suitable for teaching essentials of out-of-order architectures. * Modularity. 
We want to be able to easily experiment with the core by adding, replacing and modifying modules without changing the source too much.
-  For this goal, we designed a [transaction system](https://kuznia-rdzeni.github.io/coreblocks/Transactions.html) inspired by [Bluespec](http://wiki.bluespec.com/).
+  To this end, we designed a transaction system called [Transactron](https://github.com/kuznia-rdzeni/transactron), which is inspired by [Bluespec](http://wiki.bluespec.com/).
 * Fine-grained testing.
   Outside of the integration tests for the full core, modules are tested individually.
   This is to support an agile style of development.
@@ -25,9 +25,6 @@ The core currently supports the full RV32I instruction set and several extension
 Exceptions and some machine-mode CSRs are supported; the support for interrupts is currently rudimentary and incompatible with the RISC-V spec.
 Coreblocks can be used with [LiteX](https://github.com/enjoy-digital/litex) (currently using a [patched version](https://github.com/kuznia-rdzeni/litex/tree/coreblocks)).
 
-The transaction system we use as the foundation for the core is well-tested and usable.
-We plan to make it available as a separate Python package.
-
 ## Documentation
 
 The [documentation for our project](https://kuznia-rdzeni.github.io/coreblocks/) is automatically generated using [Sphinx](https://www.sphinx-doc.org/).
diff --git a/amaranth-stubs b/amaranth-stubs
deleted file mode 160000
index edb302b00..000000000
--- a/amaranth-stubs
+++ /dev/null
@@ -1 +0,0 @@
-Subproject commit edb302b001433edf4c8568190adc9bd0c0039f45
diff --git a/coreblocks/core.py b/coreblocks/core.py
index 913654f1c..308b835a4 100644
--- a/coreblocks/core.py
+++ b/coreblocks/core.py
@@ -136,8 +136,9 @@ def elaborate(self, platform):
 
         m.submodules.exception_information_register = self.exception_information_register
 
-        if self.connections.dependency_provided(FetchResumeKey()):
-            fetch_resume_fb, fetch_resume_unifiers = self.connections.get_dependency(FetchResumeKey())
+        fetch_resume = self.connections.get_optional_dependency(FetchResumeKey())
+        if fetch_resume is not None:
+            fetch_resume_fb, fetch_resume_unifiers = fetch_resume
             m.submodules.fetch_resume_unifiers = ModuleConnector(**fetch_resume_unifiers)
 
             m.submodules.fetch_resume_connector = ConnectTrans(fetch_resume_fb, self.frontend.resume_from_unsafe)
diff --git a/coreblocks/frontend/fetch/fetch.py b/coreblocks/frontend/fetch/fetch.py
index d04735f7b..d26906e52 100644
--- a/coreblocks/frontend/fetch/fetch.py
+++ b/coreblocks/frontend/fetch/fetch.py
@@ -403,8 +403,9 @@ def _():
         expect_unstall_unsafe = Signal()
         prev_stalled_unsafe = Signal()
         dependencies = DependencyContext.get()
-        if dependencies.dependency_provided(FetchResumeKey()):
-            unifier_ready = DependencyContext.get().get_dependency(FetchResumeKey())[0].ready
+        fetch_resume = dependencies.get_optional_dependency(FetchResumeKey())
+        if fetch_resume is not None:
+            unifier_ready = fetch_resume[0].ready
         else:
             unifier_ready = C(0)
 
diff --git a/docs/api.md b/docs/api.md
index 5daa246b7..226f38e51 100644
--- a/docs/api.md
+++ b/docs/api.md
@@ -2,5 +2,4 @@
 
 ```{eval-rst}
 .. include:: modules-coreblocks.rst
-.. 
include:: modules-transactron.rst
 ```
diff --git a/docs/index.md b/docs/index.md
index 0e16a25ec..6a9b5afba 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -8,7 +8,6 @@ maxdepth: 3
 home.md
 assumptions.md
 development-environment.md
-transactions.md
 scheduler/overview.md
 shared-structs/implementation/rs-impl.md
 shared-structs/rs.md
diff --git a/docs/transactions.md b/docs/transactions.md
deleted file mode 100644
index 41b5d5528..000000000
--- a/docs/transactions.md
+++ /dev/null
@@ -1,336 +0,0 @@
-# Documentation for Coreblocks transaction framework
-
-## Introduction
-
-Coreblocks utilizes a transaction framework for modularizing the design.
-It is inspired by the [Bluespec](http://bluespec.com/) programming language (see: [Bluespec wiki](http://wiki.bluespec.com/), [Bluespec compiler](https://github.com/B-Lang-org/bsc)).
-
-The basic idea is to interface hardware modules using _transactions_ and _methods_.
-A transaction is a state-changing operation performed by the hardware in a single clock cycle.
-Transactions are atomic: in a given clock cycle, a transaction either executes in its entirety, or not at all.
-A transaction is executed only if it is ready for execution and it does not _conflict_ with another transaction scheduled for execution in the same clock cycle.
-
-A transaction defined in a given hardware module can depend on other hardware modules via the use of methods.
-A method can be _called_ by a transaction or by other methods.
-Execution of methods is directly linked to the execution of transactions: a method only executes if some transaction which calls the method (directly or indirectly, via other methods) is executed.
-If multiple transactions try to call the same method in the same clock cycle, the transactions conflict, and only one of them is executed.
-In this way, access to methods is coordinated via the transaction system to avoid conflicts.
-
-Methods can communicate with their callers in both directions: from caller to method and back.
-The communication is structured using Amaranth records.
-
-## Basic usage
-
-### Implementing transactions
-
-The simplest way to implement a transaction as a part of Amaranth `Elaboratable` is by using a `with` block:
-
-```python
-class MyThing(Elaboratable):
-    ...
-
-    def elaborate(self, platform):
-        m = TModule()
-
-        ...
-
-        with Transaction().body(m):
-            # Operations conditioned on the transaction executing.
-            # Including Amaranth assignments, like:
-
-            m.d.comb += sig1.eq(expr1)
-            m.d.sync += sig2.eq(expr2)
-
-            # Method calls can also be used, like:
-
-            result = self.method(m, arg_expr)
-
-        ...
-
-        return m
-```
-
-The transaction body `with` block works analogously to Amaranth's `with m.If():` blocks: the Amaranth assignments and method calls only "work" in clock cycles when the transaction is executed.
-This is implemented in hardware via multiplexers.
-Please remember that this is not a Python `if` statement -- the *Python code* inside the `with` block is always executed once.
-
-### Implementing methods
-
-As methods are used as a way to communicate with other `Elaboratable`s, they are typically declared in the `Elaboratable`'s constructor, and then defined in the `elaborate` method:
-
-```python
-class MyOtherThing(Elaboratable):
-    def __init__(self):
-        ...
-
-        # Declaration of the method.
-        # The i/o parameters pass the format of method argument/result as Amaranth layouts.
-        # Both parameters are optional.
-
-        self.my_method = Method(i=input_layout, o=output_layout)
-
-        ...
-
-    def elaborate(self, platform):
-        # A TModule needs to be used instead of an Amaranth module
-
-        m = TModule()
-
-        ...
-
-        @def_method(m, self.my_method)
-        def _(arg):
-            # Operations conditioned on the method executing.
-            # Including Amaranth assignments, like:
-
-            m.d.comb += sig1.eq(expr1)
-            m.d.sync += sig2.eq(expr2)
-
-            # Method calls can also be used, like:
-
-            result = self.other_method(m, arg_expr)
-
-            # Method result should be returned:
-
-            return ret_expr
-
-        ...
-
-        return m
-```
-
-The `def_method` technique presented above is a convenience syntax, but it works just like other Amaranth `with` blocks.
-In particular, the *Python code* inside the unnamed `def` function is always executed once.
-
-A method defined in one `Elaboratable` is usually passed to other `Elaboratable`s via constructor parameters.
-For example, the `MyThing` constructor could be defined as follows.
-Only methods should be passed around, not entire `Elaboratable`s!
-
-```python
-class MyThing(Elaboratable):
-    def __init__(self, method: Method):
-        self.method = method
-
-        ...
-
-    ...
-```
-
-### Method or transaction?
-
-Sometimes, there might be two alternative ways to implement some functionality:
-
-* Using a transaction, which calls methods on other `Elaboratable`s.
-* Using a method, which is called from other `Elaboratable`s.
-
-Deciding on the best approach is not always easy.
-An important question to ask yourself is -- is this functionality something that runs independently from other things (not in lock-step)?
-If so, maybe it should be a transaction.
-Or is it something that depends on some external condition?
-If so, maybe it should be a method.
-
-If in doubt, methods are preferred.
-This is because if a functionality is implemented as a method, and a transaction is needed, one can use a transaction which calls this method and does nothing else.
-Such a transaction is included in the library -- it's named `AdapterTrans`.
-
-### Method argument passing conventions
-
-Even though method arguments are Amaranth records, their use can be avoided in many cases, which results in cleaner code.
-Suppose we have the following layout, which is an input layout for a method called `method`:
-
-```python
-layout = [("foo", 1), ("bar", 32)]
-method = Method(i=layout)
-```
-
-The method can be called in multiple ways.
-The cleanest and recommended way is to pass each record field using a keyword argument:
-
-```python
-method(m, foo=foo_expr, bar=bar_expr)
-```
-
-Another way is to pass the arguments using a `dict`:
-
-```python
-method(m, {'foo': foo_expr, 'bar': bar_expr})
-```
-
-Finally, one can directly pass an Amaranth record:
-
-```python
-rec = Record(layout)
-m.d.comb += rec.foo.eq(foo_expr)
-m.d.comb += rec.bar.eq(bar_expr)
-method(m, rec)
-```
-
-The `dict` convention can be used recursively when layouts are nested.
-Take the following definitions:
-
-```python
-layout2 = [("foobar", layout), ("baz", 42)]
-method2 = Method(i=layout2)
-```
-
-One can then pass the arguments using `dict`s in the following ways:
-
-```python
-# the preferred way
-method2(m, foobar={'foo': foo_expr, 'bar': bar_expr}, baz=baz_expr)
-
-# the alternative way
-method2(m, {'foobar': {'foo': foo_expr, 'bar': bar_expr}, 'baz': baz_expr})
-```
-
-### Method definition conventions
-
-When defining methods, two conventions can be used.
-The cleanest and recommended way is to create an argument for each record field:
-
-```python
-@def_method(m, method)
-def _(foo: Value, bar: Value):
-    ...
-```
-
-The other is to receive the argument record directly. The `arg` name is required:
-
-```python
-@def_method(m, method)
-def _(arg: Record):
-    ...
-```
-
-### Method return value conventions
-
-The `dict` syntax can be used for returning values from methods.
-Take the following method declaration:
-
-```python
-method3 = Method(i=layout, o=layout2)
-```
-
-One can then define this method as follows:
-
-```python
-@def_method(m, method3)
-def _(foo: Value, bar: Value):
-    return {'foobar': {'foo': foo, 'bar': foo + bar}, 'baz': foo - bar}
-```
-
-### Readiness signals
-
-If a transaction is not always ready for execution (for example, because of a dependence on some resource), a `request` parameter should be used.
-An Amaranth single-bit expression should be passed.
-When the `request` parameter is not passed, the transaction is always requesting execution.
-
-```python
-    with Transaction().body(m, request=expr):
-```
-
-Methods have a similar mechanism, which uses the `ready` parameter on `def_method`:
-
-```python
-    @def_method(m, self.my_method, ready=expr)
-    def _(arg):
-        ...
-```
-
-The `request` signal typically should only depend on the internal state of an `Elaboratable`.
-Other dependencies risk introducing combinational loops.
-In certain situations, it is possible to relax this requirement; see e.g. [Scheduling order](#scheduling-order).
-
-## The library
-
-The transaction framework is designed to facilitate code re-use.
-It includes a library, which contains `Elaboratable`s providing useful methods and transactions.
-The most useful ones are:
-
-* `ConnectTrans`, for connecting two methods together with a transaction.
-* `FIFO`, for queues accessed with two methods, `read` and `write`.
-* `Adapter` and `AdapterTrans`, for communicating with transactions and methods from plain Amaranth code.
-  These are very useful in testbenches.
-
-## Advanced concepts
-
-### Special combinational domains
-
-Transactron defines its own variant of Amaranth modules, called `TModule`.
-Its role is to improve circuit performance by omitting unneeded multiplexers in combinational circuits.
-This is done by adding two additional, special combinational domains, `av_comb` and `top_comb`.
-
-Statements added to the `av_comb` domain (the "avoiding" domain) are not executed when under a false `m.If`, but are executed when under a false `m.AvoidedIf`.
-Transaction and method bodies are internally guarded by an `m.AvoidedIf` with the transaction `grant` or method `run` signal.
-Therefore, combinational assignments added to `av_comb` work even if the transaction or method definition containing the assignments is not running.
-Because combinational signals usually don't induce state changes, this is often safe to do and improves performance.
-
-Statements added to the `top_comb` domain are always executed, even if the statement is under false conditions (including `m.If`, `m.Switch` etc.).
-This allows for cleaner code, as combinational assignments which logically belong to some case, but aren't actually required to be there, can be as performant as if they were manually moved to the top level.
-
-An important caveat of the special domains is that, just like with normal domains, a signal assigned in one of them cannot be assigned in others.
-
-### Scheduling order
-
-When writing multiple methods and transactions in the same `Elaboratable`, sometimes a dependency between them needs to exist.
-For example, in the `Forwarder` module in the library, forwarding can take place only if both `read` and `write` are executed simultaneously.
-This requirement is handled by making the `read` method's readiness depend on the execution of the `write` method.
-If the `read` method was considered for execution before `write`, this would introduce a combinational loop into the circuit.
-In order to avoid such issues, one can require a certain scheduling order between methods and transactions.
-
-`Method` and `Transaction` objects include a `schedule_before` method.
-Its only argument is another `Method` or `Transaction`, which will be scheduled after the first one:
-
-```python
-first_t_or_m.schedule_before(other_t_or_m)
-```
-
-Internally, scheduling orders exist only on transactions.
-If a scheduling order is added to a `Method`, it is lifted to the transaction level.
-For example, if `first_m` is scheduled before `other_t`, and is called by `t1` and `t2`, the added scheduling orderings will be the same as if the following calls were made:
-
-```python
-t1.schedule_before(other_t)
-t2.schedule_before(other_t)
-```
-
-### Conflicts
-
-In some situations it might be useful to make some methods or transactions mutually exclusive with others.
-Two conflicting transactions or methods can't execute simultaneously: only one or the other runs in a given clock cycle.
-
-Conflicts are defined similarly to scheduling orders:
-
-```python
-first_t_or_m.add_conflict(other_t_or_m)
-```
-
-Conflicts are lifted to the transaction level, just like scheduling orders.
-
-The `add_conflict` method has an optional argument `priority`, which allows defining a scheduling order between conflicting transactions or methods.
-Possible values are `Priority.LEFT`, `Priority.RIGHT` and `Priority.UNDEFINED` (the default).
-For example, the following code adds a conflict with a scheduling order, where `first_m` is scheduled before `other_m`:
-
-```python
-first_m.add_conflict(other_m, priority=Priority.LEFT)
-```
-
-Scheduling conflicts come with a possible cost.
-The conflicting transactions have a dependency in the transaction scheduler, which can increase the size and combinational delay of the scheduling circuit.
-Therefore, use of this feature requires consideration.
-
-### Transaction and method nesting
-
-Transaction and method bodies can be nested. For example:
-
-```python
-with Transaction().body(m):
-    # Transaction body.
-
-    with Transaction().body(m):
-        # Nested transaction body.
-```
-
-Nested transactions and methods can only run if the parent also runs.
-The converse is not true: it is possible that only the parent runs, but the nested transaction or method doesn't (because of other limitations).
-Nesting implies scheduling order: the nested transaction or method is considered for execution after the parent.
diff --git a/requirements.txt b/requirements.txt index 73a29a158..19b874c8c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,5 @@ -./amaranth-stubs/ # can't use -e -- pyright doesn't see the stubs then :( +amaranth-stubs @ git+https://github.com/kuznia-rdzeni/amaranth-stubs.git@edb302b001433edf4c8568190adc9bd0c0039f45 +transactron @ git+https://github.com/kuznia-rdzeni/transactron.git@972047b7bfac3d2e193a428de35c976f9b17c51a amaranth-yosys==0.40.0.0.post100 amaranth==0.5.3 dataclasses-json==0.6.3 diff --git a/scripts/build_docs.sh b/scripts/build_docs.sh index 6f58a5a6b..40e56ba89 100755 --- a/scripts/build_docs.sh +++ b/scripts/build_docs.sh @@ -60,5 +60,4 @@ $ROOT_PATH/scripts/core_graph.py -p -f mermaid $DOCS_DIR/auto_graph.rst sed -i -e '1i\.. mermaid::\n' -e 's/^/ /' $DOCS_DIR/auto_graph.rst sphinx-apidoc --tocfile modules-coreblocks -o $DOCS_DIR $ROOT_PATH/coreblocks/ -sphinx-apidoc --tocfile modules-transactron -o $DOCS_DIR $ROOT_PATH/transactron/ sphinx-build -b html -W $DOCS_DIR $BUILD_DIR diff --git a/test/transactron/__init__.py b/test/transactron/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/test/transactron/core/__init__.py b/test/transactron/core/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/test/transactron/core/test_transactions.py b/test/transactron/core/test_transactions.py deleted file mode 100644 index fd4f9e7e0..000000000 --- a/test/transactron/core/test_transactions.py +++ /dev/null @@ -1,458 +0,0 @@ -from abc import abstractmethod -from unittest.case import TestCase -from amaranth_types import HasElaborate -import pytest -from amaranth import * -from amaranth.sim import * - -import random -import contextlib - -from collections import deque -from typing import Iterable, Callable -from parameterized import parameterized, parameterized_class - -from transactron.testing import TestCaseWithSimulator, TestbenchIO, data_layout - -from transactron import * -from transactron.lib import Adapter, AdapterTrans -from transactron.utils import Scheduler - -from transactron.core import Priority -from transactron.core.schedulers import trivial_roundrobin_cc_scheduler, eager_deterministic_cc_scheduler -from transactron.core.manager import TransactionScheduler -from transactron.utils.dependencies import DependencyContext - - -class TestNames(TestCase): - def test_names(self): - mgr = TransactionManager() - mgr._MustUse__silence = True # type: ignore - - class T(Elaboratable): - def __init__(self): - self._MustUse__silence = True # type: ignore - Transaction(manager=mgr) - - T() - assert mgr.transactions[0].name == "T" - - t = Transaction(name="x", manager=mgr) - assert t.name == "x" - - t = Transaction(manager=mgr) - assert t.name == "t" - - m = Method(name="x") - assert m.name == "x" - - m = Method() - assert m.name == "m" - - -class TestScheduler(TestCaseWithSimulator): - def count_test(self, sched, cnt): - assert sched.count == cnt - assert len(sched.requests) == cnt - assert len(sched.grant) == cnt - assert len(sched.valid) == 1 - - async def sim_step(self, sim, sched: Scheduler, request: int, expected_grant: int): - sim.set(sched.requests, request) - _, _, valid, grant = await sim.tick().sample(sched.valid, sched.grant) - - if request == 0: - assert not valid - else: - assert grant == expected_grant - assert valid - - def test_single(self): - sched = Scheduler(1) - self.count_test(sched, 1) - - async def process(sim): - await self.sim_step(sim, sched, 0, 0) - await self.sim_step(sim, sched, 1, 1) - await 
self.sim_step(sim, sched, 1, 1) - await self.sim_step(sim, sched, 0, 0) - - with self.run_simulation(sched) as sim: - sim.add_testbench(process) - - def test_multi(self): - sched = Scheduler(4) - self.count_test(sched, 4) - - async def process(sim): - await self.sim_step(sim, sched, 0b0000, 0b0000) - await self.sim_step(sim, sched, 0b1010, 0b0010) - await self.sim_step(sim, sched, 0b1010, 0b1000) - await self.sim_step(sim, sched, 0b1010, 0b0010) - await self.sim_step(sim, sched, 0b1001, 0b1000) - await self.sim_step(sim, sched, 0b1001, 0b0001) - - await self.sim_step(sim, sched, 0b1111, 0b0010) - await self.sim_step(sim, sched, 0b1111, 0b0100) - await self.sim_step(sim, sched, 0b1111, 0b1000) - await self.sim_step(sim, sched, 0b1111, 0b0001) - - await self.sim_step(sim, sched, 0b0000, 0b0000) - await self.sim_step(sim, sched, 0b0010, 0b0010) - await self.sim_step(sim, sched, 0b0010, 0b0010) - - with self.run_simulation(sched) as sim: - sim.add_testbench(process) - - -class TransactionConflictTestCircuit(Elaboratable): - def __init__(self, scheduler): - self.scheduler = scheduler - - def elaborate(self, platform): - m = TModule() - tm = TransactionModule(m, DependencyContext.get(), TransactionManager(self.scheduler)) - adapter = Adapter(i=data_layout(32), o=data_layout(32)) - m.submodules.out = self.out = TestbenchIO(adapter) - m.submodules.in1 = self.in1 = TestbenchIO(AdapterTrans(adapter.iface)) - m.submodules.in2 = self.in2 = TestbenchIO(AdapterTrans(adapter.iface)) - return tm - - -@parameterized_class( - ("name", "scheduler"), - [ - ("trivial_roundrobin", trivial_roundrobin_cc_scheduler), - ("eager_deterministic", eager_deterministic_cc_scheduler), - ], -) -class TestTransactionConflict(TestCaseWithSimulator): - scheduler: TransactionScheduler - - def setup_method(self): - random.seed(42) - - def make_process( - self, - io: TestbenchIO, - prob: float, - src: Iterable[int], - tgt: Callable[[int], None], - chk: Callable[[int], None], - ): - async def process(sim): - for i in src: - while random.random() >= prob: - await sim.tick() - tgt(i) - r = await io.call(sim, data=i) - chk(r["data"]) - - return process - - def make_in1_process(self, prob: float): - def tgt(x: int): - self.out1_expected.append(x) - - def chk(x: int): - assert x == self.in_expected.popleft() - - return self.make_process(self.m.in1, prob, self.in1_stream, tgt, chk) - - def make_in2_process(self, prob: float): - def tgt(x: int): - self.out2_expected.append(x) - - def chk(x: int): - assert x == self.in_expected.popleft() - - return self.make_process(self.m.in2, prob, self.in2_stream, tgt, chk) - - def make_out_process(self, prob: float): - def tgt(x: int): - self.in_expected.append(x) - - def chk(x: int): - if self.out1_expected and x == self.out1_expected[0]: - self.out1_expected.popleft() - elif self.out2_expected and x == self.out2_expected[0]: - self.out2_expected.popleft() - else: - assert False, "%d not found in any of the queues" % x - - return self.make_process(self.m.out, prob, self.out_stream, tgt, chk) - - @parameterized.expand( - [ - ("fullcontention", 1, 1, 1), - ("highcontention", 0.5, 0.5, 0.75), - ("lowcontention", 0.1, 0.1, 0.5), - ] - ) - def test_calls(self, name, prob1, prob2, probout): - self.in1_stream = range(0, 100) - self.in2_stream = range(100, 200) - self.out_stream = range(200, 400) - self.in_expected = deque() - self.out1_expected = deque() - self.out2_expected = deque() - self.m = TransactionConflictTestCircuit(self.__class__.scheduler) - - with self.run_simulation(self.m, 
add_transaction_module=False) as sim: - sim.add_testbench(self.make_in1_process(prob1)) - sim.add_testbench(self.make_in2_process(prob2)) - sim.add_testbench(self.make_out_process(probout)) - - assert not self.in_expected - assert not self.out1_expected - assert not self.out2_expected - - -class SchedulingTestCircuit(Elaboratable): - def __init__(self): - self.r1 = Signal() - self.r2 = Signal() - self.t1 = Signal() - self.t2 = Signal() - - @abstractmethod - def elaborate(self, platform) -> HasElaborate: - raise NotImplementedError - - -class PriorityTestCircuit(SchedulingTestCircuit): - def __init__(self, priority: Priority, unsatisfiable=False): - super().__init__() - self.priority = priority - self.unsatisfiable = unsatisfiable - - def make_relations(self, t1: Transaction | Method, t2: Transaction | Method): - t1.add_conflict(t2, self.priority) - if self.unsatisfiable: - t2.add_conflict(t1, self.priority) - - -class TransactionPriorityTestCircuit(PriorityTestCircuit): - def elaborate(self, platform): - m = TModule() - - transaction1 = Transaction() - transaction2 = Transaction() - - with transaction1.body(m, request=self.r1): - m.d.comb += self.t1.eq(1) - - with transaction2.body(m, request=self.r2): - m.d.comb += self.t2.eq(1) - - self.make_relations(transaction1, transaction2) - - return m - - -class MethodPriorityTestCircuit(PriorityTestCircuit): - def elaborate(self, platform): - m = TModule() - - method1 = Method() - method2 = Method() - - @def_method(m, method1, ready=self.r1) - def _(): - m.d.comb += self.t1.eq(1) - - @def_method(m, method2, ready=self.r2) - def _(): - m.d.comb += self.t2.eq(1) - - with Transaction().body(m): - method1(m) - - with Transaction().body(m): - method2(m) - - self.make_relations(method1, method2) - - return m - - -@parameterized_class( - ("name", "circuit"), [("transaction", TransactionPriorityTestCircuit), ("method", MethodPriorityTestCircuit)] -) -class TestTransactionPriorities(TestCaseWithSimulator): - circuit: type[PriorityTestCircuit] - - def setup_method(self): - random.seed(42) - - @parameterized.expand([(Priority.UNDEFINED,), (Priority.LEFT,), (Priority.RIGHT,)]) - def test_priorities(self, priority: Priority): - m = self.circuit(priority) - - async def process(sim): - to_do = 5 * [(0, 1), (1, 0), (1, 1)] - random.shuffle(to_do) - for r1, r2 in to_do: - sim.set(m.r1, r1) - sim.set(m.r2, r2) - _, t1, t2 = await sim.delay(1e-9).sample(m.t1, m.t2) - assert t1 != t2 - if r1 == 1 and r2 == 1: - if priority == Priority.LEFT: - assert t1 - if priority == Priority.RIGHT: - assert t2 - - with self.run_simulation(m) as sim: - sim.add_testbench(process) - - @parameterized.expand([(Priority.UNDEFINED,), (Priority.LEFT,), (Priority.RIGHT,)]) - def test_unsatisfiable(self, priority: Priority): - m = self.circuit(priority, True) - - import graphlib - - if priority != Priority.UNDEFINED: - cm = pytest.raises(graphlib.CycleError) - else: - cm = contextlib.nullcontext() - - with cm: - with self.run_simulation(m): - pass - - -class NestedTransactionsTestCircuit(SchedulingTestCircuit): - def elaborate(self, platform): - m = TModule() - tm = TransactionModule(m) - - with tm.context(): - with Transaction().body(m, request=self.r1): - m.d.comb += self.t1.eq(1) - with Transaction().body(m, request=self.r2): - m.d.comb += self.t2.eq(1) - - return tm - - -class NestedMethodsTestCircuit(SchedulingTestCircuit): - def elaborate(self, platform): - m = TModule() - tm = TransactionModule(m) - - method1 = Method() - method2 = Method() - - @def_method(m, method1, ready=self.r1) 
- def _(): - m.d.comb += self.t1.eq(1) - - @def_method(m, method2, ready=self.r2) - def _(): - m.d.comb += self.t2.eq(1) - - with tm.context(): - with Transaction().body(m): - method1(m) - - with Transaction().body(m): - method2(m) - - return tm - - -@parameterized_class( - ("name", "circuit"), [("transaction", NestedTransactionsTestCircuit), ("method", NestedMethodsTestCircuit)] -) -class TestNested(TestCaseWithSimulator): - circuit: type[SchedulingTestCircuit] - - def setup_method(self): - random.seed(42) - - def test_scheduling(self): - m = self.circuit() - - async def process(sim): - to_do = 5 * [(0, 1), (1, 0), (1, 1)] - random.shuffle(to_do) - for r1, r2 in to_do: - sim.set(m.r1, r1) - sim.set(m.r2, r2) - *_, t1, t2 = await sim.tick().sample(m.t1, m.t2) - assert t1 == r1 - assert t2 == r1 * r2 - - with self.run_simulation(m) as sim: - sim.add_testbench(process) - - -class ScheduleBeforeTestCircuit(SchedulingTestCircuit): - def elaborate(self, platform): - m = TModule() - tm = TransactionModule(m) - - method = Method() - - @def_method(m, method) - def _(): - pass - - with tm.context(): - with (t1 := Transaction()).body(m, request=self.r1): - method(m) - m.d.comb += self.t1.eq(1) - - with (t2 := Transaction()).body(m, request=self.r2 & t1.grant): - method(m) - m.d.comb += self.t2.eq(1) - - t1.schedule_before(t2) - - return tm - - -class TestScheduleBefore(TestCaseWithSimulator): - def setup_method(self): - random.seed(42) - - def test_schedule_before(self): - m = ScheduleBeforeTestCircuit() - - async def process(sim): - to_do = 5 * [(0, 1), (1, 0), (1, 1)] - random.shuffle(to_do) - for r1, r2 in to_do: - sim.set(m.r1, r1) - sim.set(m.r2, r2) - *_, t1, t2 = await sim.tick().sample(m.t1, m.t2) - assert t1 == r1 - assert not t2 - - with self.run_simulation(m) as sim: - sim.add_testbench(process) - - -class SingleCallerTestCircuit(Elaboratable): - def elaborate(self, platform): - m = TModule() - - method = Method(single_caller=True) - - with Transaction().body(m): - method(m) - - with Transaction().body(m): - method(m) - - return m - - -class TestSingleCaller(TestCaseWithSimulator): - def test_single_caller(self): - m = SingleCallerTestCircuit() - - with pytest.raises(RuntimeError): - with self.run_simulation(m): - pass diff --git a/test/transactron/lib/__init__.py b/test/transactron/lib/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/test/transactron/lib/test_fifo.py b/test/transactron/lib/test_fifo.py deleted file mode 100644 index b9d0c5745..000000000 --- a/test/transactron/lib/test_fifo.py +++ /dev/null @@ -1,72 +0,0 @@ -from amaranth import * - -from transactron.lib import AdapterTrans, BasicFifo - -from transactron.testing import TestCaseWithSimulator, TestbenchIO, data_layout, TestbenchContext -from collections import deque -from parameterized import parameterized_class -import random - - -class BasicFifoTestCircuit(Elaboratable): - def __init__(self, depth): - self.depth = depth - - def elaborate(self, platform): - m = Module() - - m.submodules.fifo = self.fifo = BasicFifo(layout=data_layout(8), depth=self.depth) - - m.submodules.fifo_read = self.fifo_read = TestbenchIO(AdapterTrans(self.fifo.read)) - m.submodules.fifo_write = self.fifo_write = TestbenchIO(AdapterTrans(self.fifo.write)) - m.submodules.fifo_clear = self.fifo_clear = TestbenchIO(AdapterTrans(self.fifo.clear)) - - return m - - -@parameterized_class( - ("name", "depth"), - [ - ("notpower", 5), - ("power", 4), - ], -) -class TestBasicFifo(TestCaseWithSimulator): - depth: int - - def 
test_randomized(self): - fifoc = BasicFifoTestCircuit(depth=self.depth) - expq = deque() - - cycles = 256 - random.seed(42) - - self.done = False - - async def source(sim: TestbenchContext): - for _ in range(cycles): - await self.random_wait_geom(sim, 0.5) - - v = random.randint(0, (2**fifoc.fifo.width) - 1) - expq.appendleft(v) - await fifoc.fifo_write.call(sim, data=v) - - if random.random() < 0.005: - await fifoc.fifo_clear.call(sim) - await sim.delay(1e-9) - expq.clear() - - self.done = True - - async def target(sim: TestbenchContext): - while not self.done or expq: - await self.random_wait_geom(sim, 0.5) - - v = await fifoc.fifo_read.call_try(sim) - - if v is not None: - assert v.data == expq.pop() - - with self.run_simulation(fifoc) as sim: - sim.add_testbench(source) - sim.add_testbench(target) diff --git a/test/transactron/lib/test_transaction_lib.py b/test/transactron/lib/test_transaction_lib.py deleted file mode 100644 index 6932e4985..000000000 --- a/test/transactron/lib/test_transaction_lib.py +++ /dev/null @@ -1,786 +0,0 @@ -import pytest -from itertools import product -import random -from operator import and_ -from functools import reduce -from typing import Optional, TypeAlias -from parameterized import parameterized -from collections import deque - -from amaranth import * -from transactron import * -from transactron.lib import * -from transactron.testing.method_mock import MethodMock -from transactron.testing.testbenchio import CallTrigger -from transactron.utils._typing import ModuleLike, MethodStruct, RecordDict -from transactron.utils import ModuleConnector -from transactron.testing import ( - SimpleTestCircuit, - TestCaseWithSimulator, - data_layout, - def_method_mock, - TestbenchIO, - TestbenchContext, -) - - -class RevConnect(Elaboratable): - def __init__(self, layout: MethodLayout): - self.connect = Connect(rev_layout=layout) - self.read = self.connect.write - self.write = self.connect.read - - def elaborate(self, platform): - return self.connect - - -FIFO_Like: TypeAlias = FIFO | Forwarder | Connect | RevConnect | Pipe - - -class TestFifoBase(TestCaseWithSimulator): - def do_test_fifo( - self, fifo_class: type[FIFO_Like], writer_rand: int = 0, reader_rand: int = 0, fifo_kwargs: dict = {} - ): - iosize = 8 - - m = SimpleTestCircuit(fifo_class(data_layout(iosize), **fifo_kwargs)) - - random.seed(1337) - - async def writer(sim: TestbenchContext): - for i in range(2**iosize): - await m.write.call(sim, data=i) - await self.random_wait(sim, writer_rand) - - async def reader(sim: TestbenchContext): - for i in range(2**iosize): - assert (await m.read.call(sim)).data == i - await self.random_wait(sim, reader_rand) - - with self.run_simulation(m) as sim: - sim.add_testbench(reader) - sim.add_testbench(writer) - - -class TestFIFO(TestFifoBase): - @parameterized.expand([(0, 0), (2, 0), (0, 2), (1, 1)]) - def test_fifo(self, writer_rand, reader_rand): - self.do_test_fifo(FIFO, writer_rand=writer_rand, reader_rand=reader_rand, fifo_kwargs=dict(depth=4)) - - -class TestConnect(TestFifoBase): - @parameterized.expand([(0, 0), (2, 0), (0, 2), (1, 1)]) - def test_fifo(self, writer_rand, reader_rand): - self.do_test_fifo(Connect, writer_rand=writer_rand, reader_rand=reader_rand) - - @parameterized.expand([(0, 0), (2, 0), (0, 2), (1, 1)]) - def test_rev_fifo(self, writer_rand, reader_rand): - self.do_test_fifo(RevConnect, writer_rand=writer_rand, reader_rand=reader_rand) - - -class TestForwarder(TestFifoBase): - @parameterized.expand([(0, 0), (2, 0), (0, 2), (1, 1)]) - def 
test_fifo(self, writer_rand, reader_rand): - self.do_test_fifo(Forwarder, writer_rand=writer_rand, reader_rand=reader_rand) - - def test_forwarding(self): - iosize = 8 - - m = SimpleTestCircuit(Forwarder(data_layout(iosize))) - - async def forward_check(sim: TestbenchContext, x: int): - read_res, write_res = await CallTrigger(sim).call(m.read).call(m.write, data=x) - assert read_res is not None and read_res.data == x - assert write_res is not None - - async def process(sim: TestbenchContext): - # test forwarding behavior - for x in range(4): - await forward_check(sim, x) - - # load the overflow buffer - res = await m.write.call_try(sim, data=42) - assert res is not None - - # writes are not possible now - res = await m.write.call_try(sim, data=42) - assert res is None - - # read from the overflow buffer, writes still blocked - read_res, write_res = await CallTrigger(sim).call(m.read).call(m.write, data=111) - assert read_res is not None and read_res.data == 42 - assert write_res is None - - # forwarding now works again - for x in range(4): - await forward_check(sim, x) - - with self.run_simulation(m) as sim: - sim.add_testbench(process) - - -class TestPipe(TestFifoBase): - @parameterized.expand([(0, 0), (2, 0), (0, 2), (1, 1)]) - def test_fifo(self, writer_rand, reader_rand): - self.do_test_fifo(Pipe, writer_rand=writer_rand, reader_rand=reader_rand) - - def test_pipelining(self): - self.do_test_fifo(Pipe, writer_rand=0, reader_rand=0) - - -class TestMemoryBank(TestCaseWithSimulator): - test_conf = [(9, 3, 3, 3, 14), (16, 1, 1, 3, 15), (16, 1, 1, 1, 16), (12, 3, 1, 1, 17), (9, 0, 0, 0, 18)] - - @pytest.mark.parametrize("max_addr, writer_rand, reader_req_rand, reader_resp_rand, seed", test_conf) - @pytest.mark.parametrize("transparent", [False, True]) - @pytest.mark.parametrize("read_ports", [1, 2]) - @pytest.mark.parametrize("write_ports", [1, 2]) - def test_mem( - self, - max_addr: int, - writer_rand: int, - reader_req_rand: int, - reader_resp_rand: int, - seed: int, - transparent: bool, - read_ports: int, - write_ports: int, - ): - test_count = 200 - - data_width = 6 - m = SimpleTestCircuit( - MemoryBank( - data_layout=[("data", data_width)], - elem_count=max_addr, - transparent=transparent, - read_ports=read_ports, - write_ports=write_ports, - ), - ) - - data: list[int] = [0 for _ in range(max_addr)] - read_req_queues = [deque() for _ in range(read_ports)] - - random.seed(seed) - - def writer(i): - async def process(sim: TestbenchContext): - for cycle in range(test_count): - d = random.randrange(2**data_width) - a = random.randrange(max_addr) - await m.writes[i].call(sim, data={"data": d}, addr=a) - await sim.delay(1e-9 * (i + 2 if not transparent else i)) - data[a] = d - await self.random_wait(sim, writer_rand) - - return process - - def reader_req(i): - async def process(sim: TestbenchContext): - for cycle in range(test_count): - a = random.randrange(max_addr) - await m.read_reqs[i].call(sim, addr=a) - await sim.delay(1e-9 * (1 if not transparent else write_ports + 2)) - d = data[a] - read_req_queues[i].append(d) - await self.random_wait(sim, reader_req_rand) - - return process - - def reader_resp(i): - async def process(sim: TestbenchContext): - for cycle in range(test_count): - await sim.delay(1e-9 * (write_ports + 3)) - while not read_req_queues[i]: - await self.random_wait(sim, reader_resp_rand or 1, min_cycle_cnt=1) - await sim.delay(1e-9 * (write_ports + 3)) - d = read_req_queues[i].popleft() - assert (await m.read_resps[i].call(sim)).data == d - await self.random_wait(sim, 
reader_resp_rand)
-
-            return process
-
-        pipeline_test = writer_rand == 0 and reader_req_rand == 0 and reader_resp_rand == 0
-        max_cycles = test_count + 2 if pipeline_test else 100000
-
-        with self.run_simulation(m, max_cycles=max_cycles) as sim:
-            for i in range(read_ports):
-                sim.add_testbench(reader_req(i))
-                sim.add_testbench(reader_resp(i))
-            for i in range(write_ports):
-                sim.add_testbench(writer(i))
-
-
-class TestAsyncMemoryBank(TestCaseWithSimulator):
-    @pytest.mark.parametrize(
-        "max_addr, writer_rand, reader_rand, seed", [(9, 3, 3, 14), (16, 1, 1, 15), (16, 1, 1, 16), (12, 3, 1, 17)]
-    )
-    @pytest.mark.parametrize("read_ports", [1, 2])
-    @pytest.mark.parametrize("write_ports", [1, 2])
-    def test_mem(self, max_addr: int, writer_rand: int, reader_rand: int, seed: int, read_ports: int, write_ports: int):
-        test_count = 200
-
-        data_width = 6
-        m = SimpleTestCircuit(
-            AsyncMemoryBank(
-                data_layout=[("data", data_width)], elem_count=max_addr, read_ports=read_ports, write_ports=write_ports
-            ),
-        )
-
-        data: list[int] = list(0 for i in range(max_addr))
-
-        random.seed(seed)
-
-        def writer(i):
-            async def process(sim: TestbenchContext):
-                for cycle in range(test_count):
-                    d = random.randrange(2**data_width)
-                    a = random.randrange(max_addr)
-                    await m.writes[i].call(sim, data={"data": d}, addr=a)
-                    await sim.delay(1e-9 * (i + 2))
-                    data[a] = d
-                    await self.random_wait(sim, writer_rand, min_cycle_cnt=1)
-
-            return process
-
-        def reader(i):
-            async def process(sim: TestbenchContext):
-                for cycle in range(test_count):
-                    a = random.randrange(max_addr)
-                    d = await m.reads[i].call(sim, addr=a)
-                    await sim.delay(1e-9)
-                    expected_d = data[a]
-                    assert d["data"] == expected_d
-                    await self.random_wait(sim, reader_rand, min_cycle_cnt=1)
-
-            return process
-
-        with self.run_simulation(m) as sim:
-            for i in range(read_ports):
-                sim.add_testbench(reader(i))
-            for i in range(write_ports):
-                sim.add_testbench(writer(i))
-
-
-class ManyToOneConnectTransTestCircuit(Elaboratable):
-    def __init__(self, count: int, lay: MethodLayout):
-        self.count = count
-        self.lay = lay
-        self.inputs: list[TestbenchIO] = []
-
-    def elaborate(self, platform):
-        m = TModule()
-
-        get_results = []
-        for i in range(self.count):
-            input = TestbenchIO(Adapter(o=self.lay))
-            get_results.append(input.adapter.iface)
-            m.submodules[f"input_{i}"] = input
-            self.inputs.append(input)
-
-        # Create ManyToOneConnectTrans, which will serialize results from different inputs
-        output = TestbenchIO(Adapter(i=self.lay))
-        m.submodules.output = self.output = output
-        m.submodules.fu_arbitration = ManyToOneConnectTrans(get_results=get_results, put_result=output.adapter.iface)
-
-        return m
-
-
-class TestManyToOneConnectTrans(TestCaseWithSimulator):
-    def initialize(self):
-        f1_size = 14
-        f2_size = 3
-        self.lay = [("field1", f1_size), ("field2", f2_size)]
-
-        self.m = ManyToOneConnectTransTestCircuit(self.count, self.lay)
-        random.seed(14)
-
-        self.inputs = []
-        # List tracking whether each producer has sent all of its data
-        self.producer_end = [False for i in range(self.count)]
-        self.expected_output = {}
-        self.max_wait = 4
-
-        # Prepare random results for inputs
-        for i in range(self.count):
-            data = []
-            input_size = random.randint(20, 30)
-            for j in range(input_size):
-                t = (
-                    random.randint(0, 2**f1_size - 1),
-                    random.randint(0, 2**f2_size - 1),
-                )
-                data.append(t)
-                if t in self.expected_output:
-                    self.expected_output[t] += 1
-                else:
-                    self.expected_output[t] = 1
-            self.inputs.append(data)
-
-    def generate_producer(self, i: int):
-        """
-        This is a helper function which generates a producer process that simulates an FU.
-        The producer inserts new results into its output FIFO at random intervals.
-        These records are then serialized by `ManyToOneConnectTrans`.
-        """
-
-        async def producer(sim: TestbenchContext):
-            inputs = self.inputs[i]
-            for field1, field2 in inputs:
-                self.m.inputs[i].call_init(sim, field1=field1, field2=field2)
-                await self.random_wait(sim, self.max_wait)
-            self.producer_end[i] = True
-
-        return producer
-
-    async def consumer(self, sim: TestbenchContext):
-        # TODO: this test doesn't test anything, needs to be fixed!
-        while reduce(and_, self.producer_end, True):
-            result = await self.m.output.call_do(sim)
-
-            assert result is not None
-
-            t = (result["field1"], result["field2"])
-            assert t in self.expected_output
-            if self.expected_output[t] == 1:
-                del self.expected_output[t]
-            else:
-                self.expected_output[t] -= 1
-            await self.random_wait(sim, self.max_wait)
-
-    @pytest.mark.parametrize("count", [1, 4])
-    def test(self, count: int):
-        self.count = count
-        self.initialize()
-        with self.run_simulation(self.m) as sim:
-            sim.add_testbench(self.consumer)
-            for i in range(self.count):
-                sim.add_testbench(self.generate_producer(i))
-
-
-class MethodMapTestCircuit(Elaboratable):
-    def __init__(self, iosize: int, use_methods: bool, use_dicts: bool):
-        self.iosize = iosize
-        self.use_methods = use_methods
-        self.use_dicts = use_dicts
-
-    def elaborate(self, platform):
-        m = TModule()
-
-        layout = data_layout(self.iosize)
-
-        def itransform_rec(m: ModuleLike, v: MethodStruct) -> MethodStruct:
-            s = Signal.like(v)
-            m.d.comb += s.data.eq(v.data + 1)
-            return s
-
-        def otransform_rec(m: ModuleLike, v: MethodStruct) -> MethodStruct:
-            s = Signal.like(v)
-            m.d.comb += s.data.eq(v.data - 1)
-            return s
-
-        def itransform_dict(_, v: MethodStruct) -> RecordDict:
-            return {"data": v.data + 1}
-
-        def otransform_dict(_, v: MethodStruct) -> RecordDict:
-            return {"data": v.data - 1}
-
-        if self.use_dicts:
-            itransform = itransform_dict
-            otransform = otransform_dict
-        else:
-            itransform = itransform_rec
-            otransform = otransform_rec
-
-        m.submodules.target = self.target = TestbenchIO(Adapter(i=layout, o=layout))
-
-        if self.use_methods:
-            imeth = Method(i=layout, o=layout)
-            ometh = Method(i=layout, o=layout)
-
-            @def_method(m, imeth)
-            def _(arg: MethodStruct):
-                return itransform(m, arg)
-
-            @def_method(m, ometh)
-            def _(arg: MethodStruct):
-                return otransform(m, arg)
-
-            trans = MethodMap(self.target.adapter.iface, i_transform=(layout, imeth), o_transform=(layout, ometh))
-        else:
-            trans = MethodMap(
-                self.target.adapter.iface,
-                i_transform=(layout, itransform),
-                o_transform=(layout, otransform),
-            )
-
-        m.submodules.source = self.source = TestbenchIO(AdapterTrans(trans.use(m)))
-
-        return m
-
-
-class TestMethodTransformer(TestCaseWithSimulator):
-    m: MethodMapTestCircuit
-
-    async def source(self, sim: TestbenchContext):
-        for i in range(2**self.m.iosize):
-            v = await self.m.source.call(sim, data=i)
-            i1 = (i + 1) & ((1 << self.m.iosize) - 1)
-            assert v.data == (((i1 << 1) | (i1 >> (self.m.iosize - 1))) - 1) & ((1 << self.m.iosize) - 1)
-
-    @def_method_mock(lambda self: self.m.target)
-    def target(self, data):
-        return {"data": (data << 1) | (data >> (self.m.iosize - 1))}
-
-    def test_method_transformer(self):
-        self.m = MethodMapTestCircuit(4, False, False)
-        with self.run_simulation(self.m) as sim:
-            sim.add_testbench(self.source)
-
-    def test_method_transformer_dicts(self):
-        self.m = MethodMapTestCircuit(4, False, 
True) - with self.run_simulation(self.m) as sim: - sim.add_testbench(self.source) - - def test_method_transformer_with_methods(self): - self.m = MethodMapTestCircuit(4, True, True) - with self.run_simulation(self.m) as sim: - sim.add_testbench(self.source) - - -class TestMethodFilter(TestCaseWithSimulator): - def initialize(self): - self.iosize = 4 - self.layout = data_layout(self.iosize) - self.target = TestbenchIO(Adapter(i=self.layout, o=self.layout)) - self.cmeth = TestbenchIO(Adapter(i=self.layout, o=data_layout(1))) - - async def source(self, sim: TestbenchContext): - for i in range(2**self.iosize): - v = await self.tc.method.call(sim, data=i) - if i & 1: - assert v.data == (i + 1) & ((1 << self.iosize) - 1) - else: - assert v.data == 0 - - @def_method_mock(lambda self: self.target) - def target_mock(self, data): - return {"data": data + 1} - - @def_method_mock(lambda self: self.cmeth) - def cmeth_mock(self, data): - return {"data": data % 2} - - def test_method_filter_with_methods(self): - self.initialize() - self.tc = SimpleTestCircuit(MethodFilter(self.target.adapter.iface, self.cmeth.adapter.iface)) - m = ModuleConnector(test_circuit=self.tc, target=self.target, cmeth=self.cmeth) - with self.run_simulation(m) as sim: - sim.add_testbench(self.source) - - @parameterized.expand([(True,), (False,)]) - def test_method_filter_plain(self, use_condition): - self.initialize() - - def condition(_, v): - return v.data[0] - - self.tc = SimpleTestCircuit(MethodFilter(self.target.adapter.iface, condition, use_condition=use_condition)) - m = ModuleConnector(test_circuit=self.tc, target=self.target, cmeth=self.cmeth) - with self.run_simulation(m) as sim: - sim.add_testbench(self.source) - - -class MethodProductTestCircuit(Elaboratable): - def __init__(self, iosize: int, targets: int, add_combiner: bool): - self.iosize = iosize - self.targets = targets - self.add_combiner = add_combiner - self.target: list[TestbenchIO] = [] - - def elaborate(self, platform): - m = TModule() - - layout = data_layout(self.iosize) - - methods = [] - - for k in range(self.targets): - tgt = TestbenchIO(Adapter(i=layout, o=layout)) - methods.append(tgt.adapter.iface) - self.target.append(tgt) - m.submodules += tgt - - combiner = None - if self.add_combiner: - combiner = (layout, lambda _, vs: {"data": sum(x.data for x in vs)}) - - product = MethodProduct(methods, combiner) - - m.submodules.method = self.method = TestbenchIO(AdapterTrans(product.use(m))) - - return m - - -class TestMethodProduct(TestCaseWithSimulator): - @parameterized.expand([(1, False), (2, False), (5, True)]) - def test_method_product(self, targets: int, add_combiner: bool): - random.seed(14) - - iosize = 8 - m = MethodProductTestCircuit(iosize, targets, add_combiner) - - method_en = [False] * targets - - def target_process(k: int): - @def_method_mock(lambda: m.target[k], enable=lambda: method_en[k]) - def mock(data): - return {"data": data + k} - - return mock() - - async def method_process(sim: TestbenchContext): - # if any of the target methods is not enabled, call does not succeed - for i in range(2**targets - 1): - for k in range(targets): - method_en[k] = bool(i & (1 << k)) - - await sim.tick() - assert (await m.method.call_try(sim, data=0)) is None - - # otherwise, the call succeeds - for k in range(targets): - method_en[k] = True - await sim.tick() - - data = random.randint(0, (1 << iosize) - 1) - val = (await m.method.call(sim, data=data)).data - if add_combiner: - assert val == (targets * data + (targets - 1) * targets // 2) & ((1 << 
iosize) - 1) - else: - assert val == data - - with self.run_simulation(m) as sim: - sim.add_testbench(method_process) - for k in range(targets): - self.add_mock(sim, target_process(k)) - - -class TestSerializer(TestCaseWithSimulator): - def setup_method(self): - self.test_count = 100 - - self.port_count = 2 - self.data_width = 5 - - self.requestor_rand = 4 - - layout = [("field", self.data_width)] - - self.req_method = TestbenchIO(Adapter(i=layout)) - self.resp_method = TestbenchIO(Adapter(o=layout)) - - self.test_circuit = SimpleTestCircuit( - Serializer( - port_count=self.port_count, - serialized_req_method=self.req_method.adapter.iface, - serialized_resp_method=self.resp_method.adapter.iface, - ), - ) - self.m = ModuleConnector( - test_circuit=self.test_circuit, req_method=self.req_method, resp_method=self.resp_method - ) - - random.seed(14) - - self.serialized_data = deque() - self.port_data = [deque() for _ in range(self.port_count)] - - self.got_request = False - - @def_method_mock(lambda self: self.req_method, enable=lambda self: not self.got_request) - def serial_req_mock(self, field): - @MethodMock.effect - def eff(): - self.serialized_data.append(field) - self.got_request = True - - @def_method_mock(lambda self: self.resp_method, enable=lambda self: self.got_request) - def serial_resp_mock(self): - @MethodMock.effect - def eff(): - self.got_request = False - - if self.serialized_data: - return {"field": self.serialized_data[-1]} - - def requestor(self, i: int): - async def f(sim: TestbenchContext): - for _ in range(self.test_count): - d = random.randrange(2**self.data_width) - await self.test_circuit.serialize_in[i].call(sim, field=d) - self.port_data[i].append(d) - await self.random_wait(sim, self.requestor_rand, min_cycle_cnt=1) - - return f - - def responder(self, i: int): - async def f(sim: TestbenchContext): - for _ in range(self.test_count): - data_out = await self.test_circuit.serialize_out[i].call(sim) - assert self.port_data[i].popleft() == data_out.field - await self.random_wait(sim, self.requestor_rand, min_cycle_cnt=1) - - return f - - def test_serial(self): - with self.run_simulation(self.m) as sim: - for i in range(self.port_count): - sim.add_testbench(self.requestor(i)) - sim.add_testbench(self.responder(i)) - - -class TestMethodTryProduct(TestCaseWithSimulator): - @parameterized.expand([(1, False), (2, False), (5, True)]) - def test_method_try_product(self, targets: int, add_combiner: bool): - random.seed(14) - - iosize = 8 - m = MethodTryProductTestCircuit(iosize, targets, add_combiner) - - method_en = [False] * targets - - def target_process(k: int): - @def_method_mock(lambda: m.target[k], enable=lambda: method_en[k]) - def mock(data): - return {"data": data + k} - - return mock() - - async def method_process(sim: TestbenchContext): - for i in range(2**targets): - for k in range(targets): - method_en[k] = bool(i & (1 << k)) - - active_targets = sum(method_en) - - await sim.tick() - - data = random.randint(0, (1 << iosize) - 1) - val = await m.method.call(sim, data=data) - if add_combiner: - adds = sum(k * method_en[k] for k in range(targets)) - assert val.data == (active_targets * data + adds) & ((1 << iosize) - 1) - else: - assert val.shape().size == 0 - - with self.run_simulation(m) as sim: - sim.add_testbench(method_process) - for k in range(targets): - self.add_mock(sim, target_process(k)) - - -class MethodTryProductTestCircuit(Elaboratable): - def __init__(self, iosize: int, targets: int, add_combiner: bool): - self.iosize = iosize - self.targets = targets 
- self.add_combiner = add_combiner - self.target: list[TestbenchIO] = [] - - def elaborate(self, platform): - m = TModule() - - layout = data_layout(self.iosize) - - methods = [] - - for k in range(self.targets): - tgt = TestbenchIO(Adapter(i=layout, o=layout)) - methods.append(tgt.adapter.iface) - self.target.append(tgt) - m.submodules += tgt - - combiner = None - if self.add_combiner: - combiner = (layout, lambda _, vs: {"data": sum(Mux(s, r, 0) for (s, r) in vs)}) - - product = MethodTryProduct(methods, combiner) - - m.submodules.method = self.method = TestbenchIO(AdapterTrans(product.use(m))) - - return m - - -class ConditionTestCircuit(Elaboratable): - def __init__(self, target: Method, *, nonblocking: bool, priority: bool, catchall: bool): - self.target = target - self.source = Method(i=[("cond1", 1), ("cond2", 1), ("cond3", 1)], single_caller=True) - self.nonblocking = nonblocking - self.priority = priority - self.catchall = catchall - - def elaborate(self, platform): - m = TModule() - - @def_method(m, self.source) - def _(cond1, cond2, cond3): - with condition(m, nonblocking=self.nonblocking, priority=self.priority) as branch: - with branch(cond1): - self.target(m, cond=1) - with branch(cond2): - self.target(m, cond=2) - with branch(cond3): - self.target(m, cond=3) - if self.catchall: - with branch(): - self.target(m, cond=0) - - return m - - -class TestCondition(TestCaseWithSimulator): - @pytest.mark.parametrize("nonblocking", [False, True]) - @pytest.mark.parametrize("priority", [False, True]) - @pytest.mark.parametrize("catchall", [False, True]) - def test_condition(self, nonblocking: bool, priority: bool, catchall: bool): - target = TestbenchIO(Adapter(i=[("cond", 2)])) - - circ = SimpleTestCircuit( - ConditionTestCircuit(target.adapter.iface, nonblocking=nonblocking, priority=priority, catchall=catchall), - ) - m = ModuleConnector(test_circuit=circ, target=target) - - selection: Optional[int] - - @def_method_mock(lambda: target) - def target_process(cond): - @MethodMock.effect - def eff(): - nonlocal selection - selection = cond - - async def process(sim: TestbenchContext): - nonlocal selection - await sim.tick() # TODO workaround for mocks inactive in first cycle - for c1, c2, c3 in product([0, 1], [0, 1], [0, 1]): - selection = None - res = await circ.source.call_try(sim, cond1=c1, cond2=c2, cond3=c3) - - if catchall or nonblocking: - assert res is not None - - if res is None: - assert selection is None - assert not catchall or nonblocking - assert (c1, c2, c3) == (0, 0, 0) - elif selection is None: - assert nonblocking - assert (c1, c2, c3) == (0, 0, 0) - elif priority: - assert selection == c1 + 2 * c2 * (1 - c1) + 3 * c3 * (1 - c2) * (1 - c1) - else: - assert selection in [c1, 2 * c2, 3 * c3] - - with self.run_simulation(m) as sim: - sim.add_testbench(process) diff --git a/test/transactron/test_adapter.py b/test/transactron/test_adapter.py deleted file mode 100644 index 93d0611ae..000000000 --- a/test/transactron/test_adapter.py +++ /dev/null @@ -1,61 +0,0 @@ -from amaranth import * - -from transactron import Method, def_method, TModule - -from transactron.testing import TestCaseWithSimulator, data_layout, SimpleTestCircuit, TestbenchContext -from transactron.utils.amaranth_ext.elaboratables import ModuleConnector - - -class Echo(Elaboratable): - def __init__(self): - self.data_bits = 8 - - self.layout_in = data_layout(self.data_bits) - self.layout_out = data_layout(self.data_bits) - - self.action = Method(i=self.layout_in, o=self.layout_out) - - def elaborate(self, 
platform): - m = TModule() - - @def_method(m, self.action, ready=C(1)) - def _(arg): - return arg - - return m - - -class Consumer(Elaboratable): - def __init__(self): - self.data_bits = 8 - - self.layout_in = data_layout(self.data_bits) - self.layout_out = [] - - self.action = Method(i=self.layout_in, o=self.layout_out) - - def elaborate(self, platform): - m = TModule() - - @def_method(m, self.action, ready=C(1)) - def _(arg): - return None - - return m - - -class TestAdapterTrans(TestCaseWithSimulator): - async def proc(self, sim: TestbenchContext): - for _ in range(3): - await self.consumer.action.call(sim, data=0) - for expected in [4, 1, 0]: - obtained = (await self.echo.action.call(sim, data=expected)).data - assert expected == obtained - - def test_single(self): - self.echo = SimpleTestCircuit(Echo()) - self.consumer = SimpleTestCircuit(Consumer()) - self.m = ModuleConnector(echo=self.echo, consumer=self.consumer) - - with self.run_simulation(self.m, max_cycles=100) as sim: - sim.add_testbench(self.proc) diff --git a/test/transactron/test_assign.py b/test/transactron/test_assign.py deleted file mode 100644 index 7398570fa..000000000 --- a/test/transactron/test_assign.py +++ /dev/null @@ -1,160 +0,0 @@ -import pytest -from typing import Callable -from amaranth import * -from amaranth.lib import data -from amaranth.lib.enum import Enum -from amaranth.hdl._ast import ArrayProxy, SwitchValue, Slice - -from transactron.utils._typing import MethodLayout -from transactron.utils import AssignType, assign -from transactron.utils.assign import AssignArg, AssignFields - -from unittest import TestCase -from parameterized import parameterized_class, parameterized - - -class ExampleEnum(Enum, shape=1): - ZERO = 0 - ONE = 1 - - -def with_reversed(pairs: list[tuple[str, str]]): - return pairs + [(b, a) for (a, b) in pairs] - - -layout_a = [("a", 1)] -layout_ab = [("a", 1), ("b", 2)] -layout_ac = [("a", 1), ("c", 3)] -layout_a_alt = [("a", 2)] -layout_a_enum = [("a", ExampleEnum)] - -# Defines functions build, wrap, extr used in TestAssign -params_funs = { - "normal": (lambda mk, lay: mk(lay), lambda x: x, lambda r: r), - "rec": (lambda mk, lay: mk([("x", lay)]), lambda x: {"x": x}, lambda r: r.x), - "dict": (lambda mk, lay: {"x": mk(lay)}, lambda x: {"x": x}, lambda r: r["x"]), - "list": (lambda mk, lay: [mk(lay)], lambda x: {0: x}, lambda r: r[0]), - "union": ( - lambda mk, lay: Signal(data.UnionLayout({"x": reclayout2datalayout(lay)})), - lambda x: {"x": x}, - lambda r: r.x, - ), - "array": (lambda mk, lay: Signal(data.ArrayLayout(reclayout2datalayout(lay), 1)), lambda x: {0: x}, lambda r: r[0]), -} - - -params_pairs = [(k, k) for k in params_funs if k != "union"] + with_reversed( - [("rec", "dict"), ("list", "array"), ("union", "dict")] -) - - -def mkproxy(layout): - arr = Array([Signal(reclayout2datalayout(layout)) for _ in range(4)]) - sig = Signal(2) - return arr[sig] - - -def reclayout2datalayout(layout): - if not isinstance(layout, list): - return layout - return data.StructLayout({k: reclayout2datalayout(lay) for k, lay in layout}) - - -def mkstruct(layout): - return Signal(reclayout2datalayout(layout)) - - -params_mk = [ - ("proxy", mkproxy), - ("struct", mkstruct), -] - - -@parameterized_class( - ["name", "buildl", "wrapl", "extrl", "buildr", "wrapr", "extrr", "mk"], - [ - (f"{nl}_{nr}_{c}", *map(staticmethod, params_funs[nl] + params_funs[nr] + (m,))) - for nl, nr in params_pairs - for c, m in params_mk - ], -) -class TestAssign(TestCase): - # constructs `assign` arguments (views, 
proxies, dicts) which have an "inner" and "outer" part - # parameterized with a constructor and a layout of the inner part - buildl: Callable[[Callable[[MethodLayout], AssignArg], MethodLayout], AssignArg] - buildr: Callable[[Callable[[MethodLayout], AssignArg], MethodLayout], AssignArg] - # constructs field specifications for `assign`, takes field specifications for the inner part - wrapl: Callable[[AssignFields], AssignFields] - wrapr: Callable[[AssignFields], AssignFields] - # extracts the inner part of the structure - extrl: Callable[[AssignArg], ArrayProxy] - extrr: Callable[[AssignArg], ArrayProxy] - # constructor, takes a layout - mk: Callable[[MethodLayout], AssignArg] - - def test_wraps_eq(self): - assert self.wrapl({}) == self.wrapr({}) - - def test_rhs_exception(self): - with pytest.raises(KeyError): - list(assign(self.buildl(self.mk, layout_a), self.buildr(self.mk, layout_ab), fields=AssignType.RHS)) - with pytest.raises(KeyError): - list(assign(self.buildl(self.mk, layout_ab), self.buildr(self.mk, layout_ac), fields=AssignType.RHS)) - - def test_all_exception(self): - with pytest.raises(KeyError): - list(assign(self.buildl(self.mk, layout_a), self.buildr(self.mk, layout_ab), fields=AssignType.ALL)) - with pytest.raises(KeyError): - list(assign(self.buildl(self.mk, layout_ab), self.buildr(self.mk, layout_a), fields=AssignType.ALL)) - with pytest.raises(KeyError): - list(assign(self.buildl(self.mk, layout_ab), self.buildr(self.mk, layout_ac), fields=AssignType.ALL)) - - def test_missing_exception(self): - with pytest.raises(KeyError): - list(assign(self.buildl(self.mk, layout_a), self.buildr(self.mk, layout_ab), fields=self.wrapl({"b"}))) - with pytest.raises(KeyError): - list(assign(self.buildl(self.mk, layout_ab), self.buildr(self.mk, layout_a), fields=self.wrapl({"b"}))) - with pytest.raises(KeyError): - list(assign(self.buildl(self.mk, layout_a), self.buildr(self.mk, layout_a), fields=self.wrapl({"b"}))) - - def test_wrong_bits(self): - with pytest.raises(ValueError): - list(assign(self.buildl(self.mk, layout_a), self.buildr(self.mk, layout_a_alt))) - if self.mk != mkproxy: # Arrays are troublesome and defeat some checks - with pytest.raises(ValueError): - list(assign(self.buildl(self.mk, layout_a), self.buildr(self.mk, layout_a_enum))) - - @parameterized.expand( - [ - ("lhs", layout_a, layout_ab, AssignType.LHS), - ("rhs", layout_ab, layout_a, AssignType.RHS), - ("all", layout_a, layout_a, AssignType.ALL), - ("common", layout_ab, layout_ac, AssignType.COMMON), - ("set", layout_ab, layout_ab, {"a"}), - ("list", layout_ab, layout_ab, ["a", "a"]), - ] - ) - def test_assign_a(self, name, layout1: MethodLayout, layout2: MethodLayout, atype: AssignType): - lhs = self.buildl(self.mk, layout1) - rhs = self.buildr(self.mk, layout2) - alist = list(assign(lhs, rhs, fields=self.wrapl(atype))) - assert len(alist) == 1 - self.assertIs_AP(alist[0].lhs, self.extrl(lhs).a) - self.assertIs_AP(alist[0].rhs, self.extrr(rhs).a) - - def assertIs_AP(self, expr1, expr2): # noqa: N802 - expr1 = Value.cast(expr1) - expr2 = Value.cast(expr2) - if isinstance(expr1, SwitchValue) and isinstance(expr2, SwitchValue): - # new proxies are created on each index, structural equality is needed - self.assertIs(expr1.test, expr2.test) - assert len(expr1.cases) == len(expr2.cases) - for (px, x), (py, y) in zip(expr1.cases, expr2.cases): - self.assertEqual(px, py) - self.assertIs_AP(x, y) - elif isinstance(expr1, Slice) and isinstance(expr2, Slice): - self.assertIs_AP(expr1.value, expr2.value) - assert expr1.start 
== expr2.start - assert expr1.stop == expr2.stop - else: - self.assertIs(expr1, expr2) diff --git a/test/transactron/test_branches.py b/test/transactron/test_branches.py deleted file mode 100644 index bfb1d5842..000000000 --- a/test/transactron/test_branches.py +++ /dev/null @@ -1,99 +0,0 @@ -from amaranth import * -from itertools import product -from transactron.core import ( - TModule, - Method, - Transaction, - TransactionManager, - TransactionModule, - def_method, -) -from transactron.core.tmodule import CtrlPath -from transactron.core.manager import MethodMap -from unittest import TestCase -from transactron.testing import TestCaseWithSimulator -from transactron.utils.dependencies import DependencyContext - - -class TestExclusivePath(TestCase): - def test_exclusive_path(self): - m = TModule() - m._MustUse__silence = True # type: ignore - - with m.If(0): - cp0 = m.ctrl_path - with m.Switch(3): - with m.Case(0): - cp0a0 = m.ctrl_path - with m.Case(1): - cp0a1 = m.ctrl_path - with m.Default(): - cp0a2 = m.ctrl_path - with m.If(1): - cp0b0 = m.ctrl_path - with m.Else(): - cp0b1 = m.ctrl_path - with m.Elif(1): - cp1 = m.ctrl_path - with m.FSM(): - with m.State("start"): - cp10 = m.ctrl_path - with m.State("next"): - cp11 = m.ctrl_path - with m.Else(): - cp2 = m.ctrl_path - - def mutually_exclusive(*cps: CtrlPath): - return all(cpa.exclusive_with(cpb) for i, cpa in enumerate(cps) for cpb in cps[i + 1 :]) - - def pairwise_exclusive(cps1: list[CtrlPath], cps2: list[CtrlPath]): - return all(cpa.exclusive_with(cpb) for cpa, cpb in product(cps1, cps2)) - - def pairwise_not_exclusive(cps1: list[CtrlPath], cps2: list[CtrlPath]): - return all(not cpa.exclusive_with(cpb) for cpa, cpb in product(cps1, cps2)) - - assert mutually_exclusive(cp0, cp1, cp2) - assert mutually_exclusive(cp0a0, cp0a1, cp0a2) - assert mutually_exclusive(cp0b0, cp0b1) - assert mutually_exclusive(cp10, cp11) - assert pairwise_exclusive([cp0, cp0a0, cp0a1, cp0a2, cp0b0, cp0b1], [cp1, cp10, cp11]) - assert pairwise_not_exclusive([cp0, cp0a0, cp0a1, cp0a2], [cp0, cp0b0, cp0b1]) - - -class ExclusiveConflictRemovalCircuit(Elaboratable): - def __init__(self): - self.sel = Signal() - - def elaborate(self, platform): - m = TModule() - - called_method = Method(i=[], o=[]) - - @def_method(m, called_method) - def _(): - pass - - with m.If(self.sel): - with Transaction().body(m): - called_method(m) - with m.Else(): - with Transaction().body(m): - called_method(m) - - return m - - -class TestExclusiveConflictRemoval(TestCaseWithSimulator): - def test_conflict_removal(self): - circ = ExclusiveConflictRemovalCircuit() - - tm = TransactionManager() - dut = TransactionModule(circ, DependencyContext.get(), tm) - - with self.run_simulation(dut, add_transaction_module=False): - pass - - cgr, _ = tm._conflict_graph(MethodMap(tm.transactions)) - - for s in cgr.values(): - assert not s diff --git a/test/transactron/test_connectors.py b/test/transactron/test_connectors.py deleted file mode 100644 index e147a2fb6..000000000 --- a/test/transactron/test_connectors.py +++ /dev/null @@ -1,42 +0,0 @@ -import random -from parameterized import parameterized_class - -from transactron.lib import StableSelectingNetwork -from transactron.testing import TestCaseWithSimulator, TestbenchContext - - -@parameterized_class( - ("n"), - [(2,), (3,), (7,), (8,)], -) -class TestStableSelectingNetwork(TestCaseWithSimulator): - n: int - - def test(self): - m = StableSelectingNetwork(self.n, [("data", 8)]) - - random.seed(42) - - async def process(sim: TestbenchContext): - for 
_ in range(100): - inputs = [random.randrange(2**8) for _ in range(self.n)] - valids = [random.randrange(2) for _ in range(self.n)] - total = sum(valids) - - expected_output_prefix = [] - for i in range(self.n): - sim.set(m.valids[i], valids[i]) - sim.set(m.inputs[i].data, inputs[i]) - - if valids[i]: - expected_output_prefix.append(inputs[i]) - - for i in range(total): - out = sim.get(m.outputs[i].data) - assert out == expected_output_prefix[i] - - assert sim.get(m.output_cnt) == total - await sim.tick() - - with self.run_simulation(m) as sim: - sim.add_testbench(process) diff --git a/test/transactron/test_methods.py b/test/transactron/test_methods.py deleted file mode 100644 index e4a5ced78..000000000 --- a/test/transactron/test_methods.py +++ /dev/null @@ -1,790 +0,0 @@ -from collections.abc import Callable, Sequence -import pytest -import random -from amaranth import * -from amaranth.sim import * - -from transactron.testing import TestCaseWithSimulator, TestbenchIO, data_layout - -from transactron import * -from transactron.testing.infrastructure import SimpleTestCircuit -from transactron.utils import MethodStruct -from transactron.lib import * - -from parameterized import parameterized - -from unittest import TestCase - -from transactron.utils.assign import AssignArg - - -class TestDefMethod(TestCaseWithSimulator): - class CircuitTestModule(Elaboratable): - def __init__(self, method_definition): - self.method = Method( - i=[("foo1", 3), ("foo2", [("bar1", 4), ("bar2", 6)])], - o=[("foo1", 3), ("foo2", [("bar1", 4), ("bar2", 6)])], - ) - - self.method_definition = method_definition - - def elaborate(self, platform): - m = TModule() - m._MustUse__silence = True # type: ignore - - def_method(m, self.method)(self.method_definition) - - return m - - def do_test_definition(self, definer): - with self.run_simulation(TestDefMethod.CircuitTestModule(definer)): - pass - - def test_fields_valid1(self): - def definition(arg): - return {"foo1": Signal(3), "foo2": {"bar1": Signal(4), "bar2": Signal(6)}} - - self.do_test_definition(definition) - - def test_fields_valid2(self): - rec = Signal(from_method_layout([("bar1", 4), ("bar2", 6)])) - - def definition(arg): - return {"foo1": Signal(3), "foo2": rec} - - self.do_test_definition(definition) - - def test_fields_valid3(self): - def definition(arg): - return arg - - self.do_test_definition(definition) - - def test_fields_valid4(self): - def definition(arg: MethodStruct): - return arg - - self.do_test_definition(definition) - - def test_fields_valid5(self): - def definition(**arg): - return arg - - self.do_test_definition(definition) - - def test_fields_valid6(self): - def definition(foo1, foo2): - return {"foo1": foo1, "foo2": foo2} - - self.do_test_definition(definition) - - def test_fields_valid7(self): - def definition(foo1, **arg): - return {"foo1": foo1, "foo2": arg["foo2"]} - - self.do_test_definition(definition) - - def test_fields_invalid1(self): - def definition(arg): - return {"foo1": Signal(3), "baz": Signal(4)} - - with pytest.raises(KeyError): - self.do_test_definition(definition) - - def test_fields_invalid2(self): - def definition(arg): - return {"foo1": Signal(3)} - - with pytest.raises(KeyError): - self.do_test_definition(definition) - - def test_fields_invalid3(self): - def definition(arg): - return {"foo1": {"baz1": Signal(), "baz2": Signal()}, "foo2": {"bar1": Signal(4), "bar2": Signal(6)}} - - with pytest.raises(TypeError): - self.do_test_definition(definition) - - def test_fields_invalid4(self): - def definition(arg: Value): - 
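- # annotating the argument as a plain Value (instead of MethodStruct) is rejected with a TypeError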
return arg - - with pytest.raises(TypeError): - self.do_test_definition(definition) - - def test_fields_invalid5(self): - def definition(foo): - return foo - - with pytest.raises(TypeError): - self.do_test_definition(definition) - - def test_fields_invalid6(self): - def definition(foo1): - return {"foo1": foo1, "foo2": {"bar1": Signal(4), "bar2": Signal(6)}} - - with pytest.raises(TypeError): - self.do_test_definition(definition) - - -class TestDefMethods(TestCaseWithSimulator): - class CircuitTestModule(Elaboratable): - def __init__(self, method_definition): - self.methods = [ - Method( - i=[("foo", 3)], - o=[("foo", 3)], - ) - for _ in range(4) - ] - - self.method_definition = method_definition - - def elaborate(self, platform): - m = TModule() - m._MustUse__silence = True # type: ignore - - def_methods(m, self.methods)(self.method_definition) - - return m - - def test_basic_methods(self): - def definition(idx: int, foo: Value): - return {"foo": foo + idx} - - circuit = SimpleTestCircuit(TestDefMethods.CircuitTestModule(definition)) - - async def test_process(sim): - for k, method in enumerate(circuit.methods): - val = random.randrange(0, 2**3) - ret = await method.call(sim, foo=val) - assert ret["foo"] == (val + k) % 2**3 - - with self.run_simulation(circuit) as sim: - sim.add_testbench(test_process) - - -class AdapterCircuit(Elaboratable): - def __init__(self, module, methods): - self.module = module - self.methods = methods - - def elaborate(self, platform): - m = TModule() - - m.submodules += self.module - for method in self.methods: - m.submodules += AdapterTrans(method) - - return m - - -class TestInvalidMethods(TestCase): - def assert_re(self, msg, m): - with pytest.raises(RuntimeError, match=msg): - Fragment.get(TransactionModule(m), platform=None) - - def test_twice(self): - class Twice(Elaboratable): - def __init__(self): - self.meth1 = Method() - self.meth2 = Method() - - def elaborate(self, platform): - m = TModule() - m._MustUse__silence = True # type: ignore - - with self.meth1.body(m): - pass - - with self.meth2.body(m): - self.meth1(m) - self.meth1(m) - - return m - - self.assert_re("called twice", Twice()) - - def test_twice_cond(self): - class Twice(Elaboratable): - def __init__(self): - self.meth1 = Method() - self.meth2 = Method() - - def elaborate(self, platform): - m = TModule() - m._MustUse__silence = True # type: ignore - - with self.meth1.body(m): - pass - - with self.meth2.body(m): - with m.If(1): - self.meth1(m) - with m.Else(): - self.meth1(m) - - return m - - Fragment.get(TransactionModule(Twice()), platform=None) - - def test_diamond(self): - class Diamond(Elaboratable): - def __init__(self): - self.meth1 = Method() - self.meth2 = Method() - self.meth3 = Method() - self.meth4 = Method() - - def elaborate(self, platform): - m = TModule() - - with self.meth1.body(m): - pass - - with self.meth2.body(m): - self.meth1(m) - - with self.meth3.body(m): - self.meth1(m) - - with self.meth4.body(m): - self.meth2(m) - self.meth3(m) - - return m - - m = Diamond() - self.assert_re("called twice", AdapterCircuit(m, [m.meth4])) - - def test_loop(self): - class Loop(Elaboratable): - def __init__(self): - self.meth1 = Method() - - def elaborate(self, platform): - m = TModule() - - with self.meth1.body(m): - self.meth1(m) - - return m - - m = Loop() - self.assert_re("called twice", AdapterCircuit(m, [m.meth1])) - - def test_cycle(self): - class Cycle(Elaboratable): - def __init__(self): - self.meth1 = Method() - self.meth2 = Method() - - def elaborate(self, platform): - m = 
TModule() - - with self.meth1.body(m): - self.meth2(m) - - with self.meth2.body(m): - self.meth1(m) - - return m - - m = Cycle() - self.assert_re("called twice", AdapterCircuit(m, [m.meth1])) - - def test_redefine(self): - class Redefine(Elaboratable): - def elaborate(self, platform): - m = TModule() - m._MustUse__silence = True # type: ignore - - meth = Method() - - with meth.body(m): - pass - - with meth.body(m): - pass - - self.assert_re("already defined", Redefine()) - - def test_undefined_in_trans(self): - class Undefined(Elaboratable): - def __init__(self): - self.meth = Method(i=data_layout(1)) - - def elaborate(self, platform): - return TModule() - - class Circuit(Elaboratable): - def elaborate(self, platform): - m = TModule() - - m.submodules.undefined = undefined = Undefined() - m.submodules.adapter = AdapterTrans(undefined.meth) - - return m - - self.assert_re("not defined", Circuit()) - - -WIDTH = 8 - - -class Quadruple(Elaboratable): - def __init__(self): - layout = data_layout(WIDTH) - self.id = Method(i=layout, o=layout) - self.double = Method(i=layout, o=layout) - self.quadruple = Method(i=layout, o=layout) - - def elaborate(self, platform): - m = TModule() - - @def_method(m, self.id) - def _(arg): - return arg - - @def_method(m, self.double) - def _(arg): - return {"data": self.id(m, arg).data * 2} - - @def_method(m, self.quadruple) - def _(arg): - return {"data": self.double(m, arg).data * 2} - - return m - - -class QuadrupleCircuit(Elaboratable): - def __init__(self, quadruple): - self.quadruple = quadruple - - def elaborate(self, platform): - m = TModule() - - m.submodules.quadruple = self.quadruple - m.submodules.tb = self.tb = TestbenchIO(AdapterTrans(self.quadruple.quadruple)) - - return m - - -class Quadruple2(Elaboratable): - def __init__(self): - layout = data_layout(WIDTH) - self.quadruple = Method(i=layout, o=layout) - - def elaborate(self, platform): - m = TModule() - - m.submodules.sub = Quadruple() - - @def_method(m, self.quadruple) - def _(arg): - return {"data": 2 * m.submodules.sub.double(m, arg).data} - - return m - - -class TestQuadrupleCircuits(TestCaseWithSimulator): - @parameterized.expand([(Quadruple,), (Quadruple2,)]) - def test(self, quadruple): - circ = QuadrupleCircuit(quadruple()) - - async def process(sim): - for n in range(1 << (WIDTH - 2)): - out = await circ.tb.call(sim, data=n) - assert out["data"] == n * 4 - - with self.run_simulation(circ) as sim: - sim.add_testbench(process) - - -class ConditionalCallCircuit(Elaboratable): - def elaborate(self, platform): - m = TModule() - - meth = Method(i=data_layout(1)) - - m.submodules.tb = self.tb = TestbenchIO(AdapterTrans(meth)) - m.submodules.out = self.out = TestbenchIO(Adapter()) - - @def_method(m, meth) - def _(arg): - with m.If(arg): - self.out.adapter.iface(m) - - return m - - -class ConditionalMethodCircuit1(Elaboratable): - def elaborate(self, platform): - m = TModule() - - meth = Method() - - self.ready = Signal() - m.submodules.tb = self.tb = TestbenchIO(AdapterTrans(meth)) - - @def_method(m, meth, ready=self.ready) - def _(arg): - pass - - return m - - -class ConditionalMethodCircuit2(Elaboratable): - def elaborate(self, platform): - m = TModule() - - meth = Method() - - self.ready = Signal() - m.submodules.tb = self.tb = TestbenchIO(AdapterTrans(meth)) - - with m.If(self.ready): - - @def_method(m, meth) - def _(arg): - pass - - return m - - -class ConditionalTransactionCircuit1(Elaboratable): - def elaborate(self, platform): - m = TModule() - - self.ready = Signal() - m.submodules.tb = 
self.tb = TestbenchIO(Adapter()) - - with Transaction().body(m, request=self.ready): - self.tb.adapter.iface(m) - - return m - - -class ConditionalTransactionCircuit2(Elaboratable): - def elaborate(self, platform): - m = TModule() - - self.ready = Signal() - m.submodules.tb = self.tb = TestbenchIO(Adapter()) - - with m.If(self.ready): - with Transaction().body(m): - self.tb.adapter.iface(m) - - return m - - -class TestConditionals(TestCaseWithSimulator): - def test_conditional_call(self): - circ = ConditionalCallCircuit() - - async def process(sim): - circ.out.disable(sim) - circ.tb.call_init(sim, data=0) - *_, out_done, tb_done = await sim.tick().sample(circ.out.adapter.done, circ.tb.adapter.done) - assert not out_done and not tb_done - - circ.out.enable(sim) - *_, out_done, tb_done = await sim.tick().sample(circ.out.adapter.done, circ.tb.adapter.done) - assert not out_done and tb_done - - circ.tb.call_init(sim, data=1) - *_, out_done, tb_done = await sim.tick().sample(circ.out.adapter.done, circ.tb.adapter.done) - assert out_done and tb_done - - # the argument is still 1 but the method is not called - circ.tb.disable(sim) - *_, out_done, tb_done = await sim.tick().sample(circ.out.adapter.done, circ.tb.adapter.done) - assert not out_done and not tb_done - - with self.run_simulation(circ) as sim: - sim.add_testbench(process) - - @parameterized.expand( - [ - (ConditionalMethodCircuit1,), - (ConditionalMethodCircuit2,), - (ConditionalTransactionCircuit1,), - (ConditionalTransactionCircuit2,), - ] - ) - def test_conditional(self, elaboratable): - circ = elaboratable() - - async def process(sim): - circ.tb.enable(sim) - sim.set(circ.ready, 0) - *_, tb_done = await sim.tick().sample(circ.tb.adapter.done) - assert not tb_done - - sim.set(circ.ready, 1) - *_, tb_done = await sim.tick().sample(circ.tb.adapter.done) - assert tb_done - - with self.run_simulation(circ) as sim: - sim.add_testbench(process) - - -class NonexclusiveMethodCircuit(Elaboratable): - def elaborate(self, platform): - m = TModule() - - self.ready = Signal() - self.running = Signal() - self.data = Signal(WIDTH) - - method = Method(o=data_layout(WIDTH), nonexclusive=True) - - @def_method(m, method, self.ready) - def _(): - m.d.comb += self.running.eq(1) - return {"data": self.data} - - m.submodules.t1 = self.t1 = TestbenchIO(AdapterTrans(method)) - m.submodules.t2 = self.t2 = TestbenchIO(AdapterTrans(method)) - - return m - - -class TestNonexclusiveMethod(TestCaseWithSimulator): - def test_nonexclusive_method(self): - circ = NonexclusiveMethodCircuit() - - async def process(sim): - for x in range(8): - t1en = bool(x & 1) - t2en = bool(x & 2) - mrdy = bool(x & 4) - - circ.t1.set_enable(sim, t1en) - circ.t2.set_enable(sim, t2en) - sim.set(circ.ready, int(mrdy)) - sim.set(circ.data, x) - - *_, running, t1_done, t2_done, t1_outputs, t2_outputs = await sim.delay(1e-9).sample( - circ.running, circ.t1.done, circ.t2.done, circ.t1.outputs, circ.t2.outputs - ) - - assert bool(running) == ((t1en or t2en) and mrdy) - assert bool(t1_done) == (t1en and mrdy) - assert bool(t2_done) == (t2en and mrdy) - - if t1en and mrdy: - assert t1_outputs["data"] == x - - if t2en and mrdy: - assert t2_outputs["data"] == x - - with self.run_simulation(circ) as sim: - sim.add_testbench(process) - - -class TwoNonexclusiveConflictCircuit(Elaboratable): - def __init__(self, two_nonexclusive: bool): - self.two_nonexclusive = two_nonexclusive - - def elaborate(self, platform): - m = TModule() - - self.running1 = Signal() - self.running2 = Signal() - - method1 = 
Method(o=data_layout(WIDTH), nonexclusive=True) - method2 = Method(o=data_layout(WIDTH), nonexclusive=self.two_nonexclusive) - method_in = Method(o=data_layout(WIDTH)) - - @def_method(m, method_in) - def _(): - return {"data": 0} - - @def_method(m, method1) - def _(): - m.d.comb += self.running1.eq(1) - return method_in(m) - - @def_method(m, method2) - def _(): - m.d.comb += self.running2.eq(1) - return method_in(m) - - m.submodules.t1 = self.t1 = TestbenchIO(AdapterTrans(method1)) - m.submodules.t2 = self.t2 = TestbenchIO(AdapterTrans(method2)) - - return m - - -class TestConflicting(TestCaseWithSimulator): - @pytest.mark.parametrize( - "test_circuit", [lambda: TwoNonexclusiveConflictCircuit(False), lambda: TwoNonexclusiveConflictCircuit(True)] - ) - def test_conflicting(self, test_circuit: Callable[[], TwoNonexclusiveConflictCircuit]): - circ = test_circuit() - - async def process(sim): - circ.t1.enable(sim) - circ.t2.enable(sim) - *_, running1, running2 = await sim.delay(1e-9).sample(circ.running1, circ.running2) - - assert not running1 or not running2 - - with self.run_simulation(circ) as sim: - sim.add_testbench(process) - - -class CustomCombinerMethodCircuit(Elaboratable): - def elaborate(self, platform): - m = TModule() - - self.ready = Signal() - self.running = Signal() - - def combiner(m: Module, args: Sequence[MethodStruct], runs: Value) -> AssignArg: - result = C(0) - for i, v in enumerate(args): - result = result ^ Mux(runs[i], v.data, 0) - return {"data": result} - - method = Method(i=data_layout(WIDTH), o=data_layout(WIDTH), nonexclusive=True, combiner=combiner) - - @def_method(m, method, self.ready) - def _(data: Value): - m.d.comb += self.running.eq(1) - return {"data": data} - - m.submodules.t1 = self.t1 = TestbenchIO(AdapterTrans(method)) - m.submodules.t2 = self.t2 = TestbenchIO(AdapterTrans(method)) - - return m - - -class TestCustomCombinerMethod(TestCaseWithSimulator): - def test_custom_combiner_method(self): - circ = CustomCombinerMethodCircuit() - - async def process(sim): - for x in range(8): - t1en = bool(x & 1) - t2en = bool(x & 2) - mrdy = bool(x & 4) - - val1 = random.randrange(0, 2**WIDTH) - val2 = random.randrange(0, 2**WIDTH) - val1e = val1 if t1en else 0 - val2e = val2 if t2en else 0 - - circ.t1.call_init(sim, data=val1) - circ.t2.call_init(sim, data=val2) - - circ.t1.set_enable(sim, t1en) - circ.t2.set_enable(sim, t2en) - - sim.set(circ.ready, int(mrdy)) - - *_, running, t1_done, t2_done, t1_outputs, t2_outputs = await sim.delay(1e-9).sample( - circ.running, circ.t1.done, circ.t2.done, circ.t1.outputs, circ.t2.outputs - ) - - assert bool(running) == ((t1en or t2en) and mrdy) - assert bool(t1_done) == (t1en and mrdy) - assert bool(t2_done) == (t2en and mrdy) - - if t1en and mrdy: - assert t1_outputs["data"] == val1e ^ val2e - - if t2en and mrdy: - assert t2_outputs["data"] == val1e ^ val2e - - with self.run_simulation(circ) as sim: - sim.add_testbench(process) - - -class DataDependentConditionalCircuit(Elaboratable): - def __init__(self, n=2, ready_function=lambda arg: arg.data != 3): - self.method = Method(i=data_layout(n)) - self.ready_function = ready_function - - self.in_t1 = Signal(n) - self.in_t2 = Signal(n) - self.ready = Signal() - self.req_t1 = Signal() - self.req_t2 = Signal() - - self.out_m = Signal() - self.out_t1 = Signal() - self.out_t2 = Signal() - - def elaborate(self, platform): - m = TModule() - - @def_method(m, self.method, self.ready, validate_arguments=self.ready_function) - def _(data): - m.d.comb += self.out_m.eq(1) - - with 
Transaction().body(m, request=self.req_t1): - m.d.comb += self.out_t1.eq(1) - self.method(m, data=self.in_t1) - - with Transaction().body(m, request=self.req_t2): - m.d.comb += self.out_t2.eq(1) - self.method(m, data=self.in_t2) - - return m - - -class TestDataDependentConditionalMethod(TestCaseWithSimulator): - def setup_method(self): - self.test_number = 200 - self.bad_number = 3 - self.n = 2 - - def base_random(self, f): - random.seed(14) - self.circ = DataDependentConditionalCircuit(n=self.n, ready_function=f) - - async def process(sim): - for _ in range(self.test_number): - in1 = random.randrange(0, 2**self.n) - in2 = random.randrange(0, 2**self.n) - m_ready = random.randrange(2) - req_t1 = random.randrange(2) - req_t2 = random.randrange(2) - - sim.set(self.circ.in_t1, in1) - sim.set(self.circ.in_t2, in2) - sim.set(self.circ.req_t1, req_t1) - sim.set(self.circ.req_t2, req_t2) - sim.set(self.circ.ready, m_ready) - - *_, out_m, out_t1, out_t2 = await sim.delay(1e-9).sample( - self.circ.out_m, self.circ.out_t1, self.circ.out_t2 - ) - - if not m_ready or (not req_t1 or in1 == self.bad_number) and (not req_t2 or in2 == self.bad_number): - assert out_m == 0 - assert out_t1 == 0 - assert out_t2 == 0 - continue - # Here method global ready signal is high and we requested one of the transactions - # we also know that one of the transactions request correct input data - - assert out_m == 1 - assert out_t1 ^ out_t2 == 1 - # inX == self.bad_number implies out_tX==0 - assert in1 != self.bad_number or not out_t1 - assert in2 != self.bad_number or not out_t2 - - await sim.tick() - - with self.run_simulation(self.circ, 100) as sim: - sim.add_testbench(process) - - def test_random_arg(self): - self.base_random(lambda arg: arg.data != self.bad_number) - - def test_random_kwarg(self): - self.base_random(lambda data: data != self.bad_number) diff --git a/test/transactron/test_metrics.py b/test/transactron/test_metrics.py deleted file mode 100644 index 181dcc839..000000000 --- a/test/transactron/test_metrics.py +++ /dev/null @@ -1,526 +0,0 @@ -import json -import random -import queue -from typing import Type -from enum import IntFlag, IntEnum, auto, Enum - -from parameterized import parameterized_class - -from amaranth import * - -from transactron.lib.metrics import * -from transactron import * -from transactron.testing import TestCaseWithSimulator, data_layout, SimpleTestCircuit, TestbenchContext -from transactron.testing.tick_count import TicksKey -from transactron.utils.dependencies import DependencyContext - - -class CounterInMethodCircuit(Elaboratable): - def __init__(self): - self.method = Method() - self.counter = HwCounter("in_method") - - def elaborate(self, platform): - m = TModule() - - m.submodules.counter = self.counter - - @def_method(m, self.method) - def _(): - self.counter.incr(m) - - return m - - -class CounterWithConditionInMethodCircuit(Elaboratable): - def __init__(self): - self.method = Method(i=[("cond", 1)]) - self.counter = HwCounter("with_condition_in_method") - - def elaborate(self, platform): - m = TModule() - - m.submodules.counter = self.counter - - @def_method(m, self.method) - def _(cond): - self.counter.incr(m, cond=cond) - - return m - - -class CounterWithoutMethodCircuit(Elaboratable): - def __init__(self): - self.cond = Signal() - self.counter = HwCounter("with_condition_without_method") - - def elaborate(self, platform): - m = TModule() - - m.submodules.counter = self.counter - - with Transaction().body(m): - self.counter.incr(m, cond=self.cond) - - return m - - -class 
TestHwCounter(TestCaseWithSimulator): - def setup_method(self) -> None: - random.seed(42) - - def test_counter_in_method(self): - m = SimpleTestCircuit(CounterInMethodCircuit()) - DependencyContext.get().add_dependency(HwMetricsEnabledKey(), True) - - async def test_process(sim): - called_cnt = 0 - for _ in range(200): - call_now = random.randint(0, 1) == 0 - - if call_now: - await m.method.call(sim) - called_cnt += 1 - else: - await sim.tick() - - assert called_cnt == sim.get(m._dut.counter.count.value) - - with self.run_simulation(m) as sim: - sim.add_testbench(test_process) - - def test_counter_with_condition_in_method(self): - m = SimpleTestCircuit(CounterWithConditionInMethodCircuit()) - DependencyContext.get().add_dependency(HwMetricsEnabledKey(), True) - - async def test_process(sim): - called_cnt = 0 - for _ in range(200): - call_now = random.randint(0, 1) == 0 - condition = random.randint(0, 1) - - if call_now: - await m.method.call(sim, cond=condition) - called_cnt += condition - else: - await sim.tick() - - assert called_cnt == sim.get(m._dut.counter.count.value) - - with self.run_simulation(m) as sim: - sim.add_testbench(test_process) - - def test_counter_with_condition_without_method(self): - m = CounterWithoutMethodCircuit() - DependencyContext.get().add_dependency(HwMetricsEnabledKey(), True) - - async def test_process(sim): - called_cnt = 0 - for _ in range(200): - condition = random.randint(0, 1) - - sim.set(m.cond, condition) - await sim.tick() - - if condition == 1: - called_cnt += 1 - - assert called_cnt == sim.get(m.counter.count.value) - - with self.run_simulation(m) as sim: - sim.add_testbench(test_process) - - -class OneHotEnum(IntFlag): - ADD = auto() - XOR = auto() - OR = auto() - - -class PlainIntEnum(IntEnum): - TEST_1 = auto() - TEST_2 = auto() - TEST_3 = auto() - - -class TaggedCounterCircuit(Elaboratable): - def __init__(self, tags: range | Type[Enum] | list[int]): - self.counter = TaggedCounter("counter", "", tags=tags) - - self.cond = Signal() - self.tag = Signal(self.counter.tag_width) - - def elaborate(self, platform): - m = TModule() - - m.submodules.counter = self.counter - - with Transaction().body(m): - self.counter.incr(m, self.tag, cond=self.cond) - - return m - - -class TestTaggedCounter(TestCaseWithSimulator): - def setup_method(self) -> None: - random.seed(42) - - def do_test_enum(self, tags: range | Type[Enum] | list[int], tag_values: list[int]): - m = TaggedCounterCircuit(tags) - DependencyContext.get().add_dependency(HwMetricsEnabledKey(), True) - - counts: dict[int, int] = {} - for i in tag_values: - counts[i] = 0 - - async def test_process(sim): - for _ in range(200): - for i in tag_values: - assert counts[i] == sim.get(m.counter.counters[i].value) - - tag = random.choice(list(tag_values)) - - sim.set(m.cond, 1) - sim.set(m.tag, tag) - await sim.tick() - sim.set(m.cond, 0) - await sim.tick() - - counts[tag] += 1 - - with self.run_simulation(m) as sim: - sim.add_testbench(test_process) - - def test_one_hot_enum(self): - self.do_test_enum(OneHotEnum, [e.value for e in OneHotEnum]) - - def test_plain_int_enum(self): - self.do_test_enum(PlainIntEnum, [e.value for e in PlainIntEnum]) - - def test_negative_range(self): - r = range(-10, 15, 3) - self.do_test_enum(r, list(r)) - - def test_positive_range(self): - r = range(0, 30, 2) - self.do_test_enum(r, list(r)) - - def test_value_list(self): - values = [-2137, 2, 4, 8, 42] - self.do_test_enum(values, values) - - -class ExpHistogramCircuit(Elaboratable): - def __init__(self, bucket_cnt: int, 
sample_width: int): - self.sample_width = sample_width - - self.method = Method(i=data_layout(32)) - self.histogram = HwExpHistogram("histogram", bucket_count=bucket_cnt, sample_width=sample_width) - - def elaborate(self, platform): - m = TModule() - - m.submodules.histogram = self.histogram - - @def_method(m, self.method) - def _(data): - self.histogram.add(m, data[0 : self.sample_width]) - - return m - - -@parameterized_class( - ("bucket_count", "sample_width"), - [ - (5, 5), # last bucket is [8, inf), max sample=31 - (8, 5), # last bucket is [64, inf), max sample=31 - (8, 6), # last bucket is [64, inf), max sample=63 - (8, 20), # last bucket is [64, inf), max sample=big - ], -) -class TestHwHistogram(TestCaseWithSimulator): - bucket_count: int - sample_width: int - - def test_histogram(self): - random.seed(42) - - m = SimpleTestCircuit(ExpHistogramCircuit(bucket_cnt=self.bucket_count, sample_width=self.sample_width)) - DependencyContext.get().add_dependency(HwMetricsEnabledKey(), True) - - max_sample_value = 2**self.sample_width - 1 - - async def test_process(sim): - min = max_sample_value - max = 0 - sum = 0 - count = 0 - - buckets = [0] * self.bucket_count - - for _ in range(500): - if random.randrange(3) == 0: - value = random.randint(0, max_sample_value) - if value < min: - min = value - if value > max: - max = value - sum += value - count += 1 - for i in range(self.bucket_count): - if value < 2**i or i == self.bucket_count - 1: - buckets[i] += 1 - break - await m.method.call(sim, data=value) - else: - await sim.tick() - - histogram = m._dut.histogram - - assert min == sim.get(histogram.min.value) - assert max == sim.get(histogram.max.value) - assert sum == sim.get(histogram.sum.value) - assert count == sim.get(histogram.count.value) - - total_count = 0 - for i in range(self.bucket_count): - bucket_value = sim.get(histogram.buckets[i].value) - total_count += bucket_value - assert buckets[i] == bucket_value - - # Sanity check if all buckets sum up to the total count value - assert total_count == sim.get(histogram.count.value) - - with self.run_simulation(m) as sim: - sim.add_testbench(test_process) - - -class TestLatencyMeasurerBase(TestCaseWithSimulator): - def check_latencies(self, sim, m: SimpleTestCircuit, latencies: list[int]): - assert min(latencies) == sim.get(m._dut.histogram.min.value) - assert max(latencies) == sim.get(m._dut.histogram.max.value) - assert sum(latencies) == sim.get(m._dut.histogram.sum.value) - assert len(latencies) == sim.get(m._dut.histogram.count.value) - - for i in range(m._dut.histogram.bucket_count): - bucket_start = 0 if i == 0 else 2 ** (i - 1) - bucket_end = 1e10 if i == m._dut.histogram.bucket_count - 1 else 2**i - - count = sum(1 for x in latencies if bucket_start <= x < bucket_end) - assert count == sim.get(m._dut.histogram.buckets[i].value) - - -@parameterized_class( - ("slots_number", "expected_consumer_wait"), - [ - (2, 5), - (2, 10), - (5, 10), - (10, 1), - (10, 10), - (5, 5), - ], -) -class TestFIFOLatencyMeasurer(TestLatencyMeasurerBase): - slots_number: int - expected_consumer_wait: float - - def test_latency_measurer(self): - random.seed(42) - - m = SimpleTestCircuit(FIFOLatencyMeasurer("latency", slots_number=self.slots_number, max_latency=300)) - DependencyContext.get().add_dependency(HwMetricsEnabledKey(), True) - - latencies: list[int] = [] - - event_queue = queue.Queue() - - finish = False - - async def producer(sim: TestbenchContext): - nonlocal finish - ticks = DependencyContext.get().get_dependency(TicksKey()) - - for _ in 
range(200): - await m._start.call(sim) - - event_queue.put(sim.get(ticks)) - await self.random_wait_geom(sim, 0.8) - - finish = True - - async def consumer(sim: TestbenchContext): - ticks = DependencyContext.get().get_dependency(TicksKey()) - - while not finish: - await m._stop.call(sim) - - latencies.append(sim.get(ticks) - event_queue.get()) - - await self.random_wait_geom(sim, 1.0 / self.expected_consumer_wait) - - self.check_latencies(sim, m, latencies) - - with self.run_simulation(m) as sim: - sim.add_testbench(producer) - sim.add_testbench(consumer) - - -@parameterized_class( - ("slots_number", "expected_consumer_wait"), - [ - (2, 5), - (2, 10), - (5, 10), - (10, 1), - (10, 10), - (5, 5), - ], -) -class TestIndexedLatencyMeasurer(TestLatencyMeasurerBase): - slots_number: int - expected_consumer_wait: float - - def test_latency_measurer(self): - random.seed(42) - - m = SimpleTestCircuit(TaggedLatencyMeasurer("latency", slots_number=self.slots_number, max_latency=300)) - DependencyContext.get().add_dependency(HwMetricsEnabledKey(), True) - - latencies: list[int] = [] - - events = list(0 for _ in range(self.slots_number)) - free_slots = list(k for k in range(self.slots_number)) - used_slots: list[int] = [] - - finish = False - - async def producer(sim): - nonlocal finish - - tick_count = DependencyContext.get().get_dependency(TicksKey()) - - for _ in range(200): - while not free_slots: - await sim.tick() - await sim.delay(1e-9) - - slot_id = random.choice(free_slots) - await m._start.call(sim, slot=slot_id) - - events[slot_id] = sim.get(tick_count) - free_slots.remove(slot_id) - used_slots.append(slot_id) - - await self.random_wait_geom(sim, 0.8) - - finish = True - - async def consumer(sim): - tick_count = DependencyContext.get().get_dependency(TicksKey()) - - while not finish: - while not used_slots: - await sim.tick() - - slot_id = random.choice(used_slots) - - await m._stop.call(sim, slot=slot_id) - - await sim.delay(2e-9) - - latencies.append(sim.get(tick_count) - events[slot_id]) - used_slots.remove(slot_id) - free_slots.append(slot_id) - - await self.random_wait_geom(sim, 1.0 / self.expected_consumer_wait) - - self.check_latencies(sim, m, latencies) - - with self.run_simulation(m) as sim: - sim.add_testbench(producer) - sim.add_testbench(consumer) - - -class MetricManagerTestCircuit(Elaboratable): - def __init__(self): - self.incr_counters = Method(i=[("counter1", 1), ("counter2", 1), ("counter3", 1)]) - - self.counter1 = HwCounter("foo.counter1", "this is the description") - self.counter2 = HwCounter("bar.baz.counter2") - self.counter3 = HwCounter("bar.baz.counter3", "yet another description") - - def elaborate(self, platform): - m = TModule() - - m.submodules += [self.counter1, self.counter2, self.counter3] - - @def_method(m, self.incr_counters) - def _(counter1, counter2, counter3): - self.counter1.incr(m, cond=counter1) - self.counter2.incr(m, cond=counter2) - self.counter3.incr(m, cond=counter3) - - return m - - -class TestMetricsManager(TestCaseWithSimulator): - def test_metrics_metadata(self): - # We need to initialize the circuit to make sure that metrics are registered - # in the dependency manager. - m = MetricManagerTestCircuit() - metrics_manager = HardwareMetricsManager() - - # Run the simulation so Amaranth doesn't scream that we have unused elaboratables. 
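- # (no testbench is needed: the assertions below only inspect the registered metadata)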
- with self.run_simulation(m): - pass - - assert metrics_manager.get_metrics()["foo.counter1"].to_json() == json.dumps( # type: ignore - { - "fully_qualified_name": "foo.counter1", - "description": "this is the description", - "regs": {"count": {"name": "count", "description": "the value of the counter", "width": 32}}, - } - ) - - assert metrics_manager.get_metrics()["bar.baz.counter2"].to_json() == json.dumps( # type: ignore - { - "fully_qualified_name": "bar.baz.counter2", - "description": "", - "regs": {"count": {"name": "count", "description": "the value of the counter", "width": 32}}, - } - ) - - assert metrics_manager.get_metrics()["bar.baz.counter3"].to_json() == json.dumps( # type: ignore - { - "fully_qualified_name": "bar.baz.counter3", - "description": "yet another description", - "regs": {"count": {"name": "count", "description": "the value of the counter", "width": 32}}, - } - ) - - def test_returned_reg_values(self): - random.seed(42) - - m = SimpleTestCircuit(MetricManagerTestCircuit()) - metrics_manager = HardwareMetricsManager() - - DependencyContext.get().add_dependency(HwMetricsEnabledKey(), True) - - async def test_process(sim): - counters = [0] * 3 - for _ in range(200): - rand = [random.randint(0, 1) for _ in range(3)] - - await m.incr_counters.call(sim, counter1=rand[0], counter2=rand[1], counter3=rand[2]) - - for i in range(3): - if rand[i] == 1: - counters[i] += 1 - - assert counters[0] == sim.get(metrics_manager.get_register_value("foo.counter1", "count")) - assert counters[1] == sim.get(metrics_manager.get_register_value("bar.baz.counter2", "count")) - assert counters[2] == sim.get(metrics_manager.get_register_value("bar.baz.counter3", "count")) - - with self.run_simulation(m) as sim: - sim.add_testbench(test_process) diff --git a/test/transactron/test_simultaneous.py b/test/transactron/test_simultaneous.py deleted file mode 100644 index ad492e330..000000000 --- a/test/transactron/test_simultaneous.py +++ /dev/null @@ -1,176 +0,0 @@ -import pytest -from itertools import product -from typing import Optional -from amaranth import * -from amaranth.sim import * -from transactron.testing.method_mock import MethodMock, def_method_mock -from transactron.testing.testbenchio import TestbenchIO - -from transactron.utils import ModuleConnector - -from transactron.testing import SimpleTestCircuit, TestCaseWithSimulator, TestbenchContext - -from transactron import * -from transactron.lib import Adapter, Connect, ConnectTrans - - -def empty_method(m: TModule, method: Method): - @def_method(m, method) - def _(): - pass - - -class SimultaneousDiamondTestCircuit(Elaboratable): - def __init__(self): - self.method_l = Method() - self.method_r = Method() - self.method_u = Method() - self.method_d = Method() - - def elaborate(self, platform): - m = TModule() - - empty_method(m, self.method_l) - empty_method(m, self.method_r) - empty_method(m, self.method_u) - empty_method(m, self.method_d) - - # the only possibilities for the following are: (l, u, r) or (l, d, r) - self.method_l.simultaneous_alternatives(self.method_u, self.method_d) - self.method_r.simultaneous_alternatives(self.method_u, self.method_d) - - return m - - -class TestSimultaneousDiamond(TestCaseWithSimulator): - def test_diamond(self): - circ = SimpleTestCircuit(SimultaneousDiamondTestCircuit()) - - async def process(sim: TestbenchContext): - methods = {"l": circ.method_l, "r": circ.method_r, "u": circ.method_u, "d": circ.method_d} - for i in range(1 << len(methods)): - enables: dict[str, bool] = {} - for k, n in 
enumerate(methods): - enables[n] = bool(i & (1 << k)) - methods[n].set_enable(sim, enables[n]) - dones: dict[str, bool] = {} - for n in methods: - dones[n] = bool(methods[n].get_done(sim)) - await sim.tick() - for n in methods: - if not enables[n]: - assert not dones[n] - if enables["l"] and enables["r"] and (enables["u"] or enables["d"]): - assert dones["l"] - assert dones["r"] - assert dones["u"] or dones["d"] - else: - assert not any(dones.values()) - - with self.run_simulation(circ) as sim: - sim.add_testbench(process) - - -class UnsatisfiableTriangleTestCircuit(Elaboratable): - def __init__(self): - self.method_l = Method() - self.method_u = Method() - self.method_d = Method() - - def elaborate(self, platform): - m = TModule() - - empty_method(m, self.method_l) - empty_method(m, self.method_u) - empty_method(m, self.method_d) - - # the following is unsatisfiable - self.method_l.simultaneous_alternatives(self.method_u, self.method_d) - self.method_u.simultaneous(self.method_d) - - return m - - -class TestUnsatisfiableTriangle(TestCaseWithSimulator): - def test_unsatisfiable(self): - circ = SimpleTestCircuit(UnsatisfiableTriangleTestCircuit()) - - with pytest.raises(RuntimeError): - with self.run_simulation(circ) as _: - pass - - -class HelperConnect(Elaboratable): - def __init__(self, source: Method, target: Method, request: Signal, data: int): - self.source = source - self.target = target - self.request = request - self.data = data - - def elaborate(self, platform): - m = TModule() - - with Transaction().body(m, request=self.request): - self.target(m, self.data ^ self.source(m).data) - - return m - - -class TransitivityTestCircuit(Elaboratable): - def __init__(self, target: Method, req1: Signal, req2: Signal): - self.source1 = Method(i=[("data", 2)]) - self.source2 = Method(i=[("data", 2)]) - self.target = target - self.req1 = req1 - self.req2 = req2 - - def elaborate(self, platform): - m = TModule() - - m.submodules.c1 = c1 = Connect([("data", 2)]) - m.submodules.c2 = c2 = Connect([("data", 2)]) - self.source1.proxy(m, c1.write) - self.source2.proxy(m, c1.write) - m.submodules.ct = ConnectTrans(c2.read, self.target) - m.submodules.hc1 = HelperConnect(c1.read, c2.write, self.req1, 1) - m.submodules.hc2 = HelperConnect(c1.read, c2.write, self.req2, 2) - - return m - - -class TestTransitivity(TestCaseWithSimulator): - def test_transitivity(self): - target = TestbenchIO(Adapter(i=[("data", 2)])) - req1 = Signal() - req2 = Signal() - - circ = SimpleTestCircuit(TransitivityTestCircuit(target.adapter.iface, req1, req2)) - m = ModuleConnector(test_circuit=circ, target=target) - - result: Optional[int] - - @def_method_mock(lambda: target) - def target_process(data: int): - @MethodMock.effect - def eff(): - nonlocal result - result = data - - async def process(sim: TestbenchContext): - nonlocal result - for source, data, reqv1, reqv2 in product([circ.source1, circ.source2], [0, 1, 2, 3], [0, 1], [0, 1]): - result = None - sim.set(req1, reqv1) - sim.set(req2, reqv2) - call_result = await source.call_try(sim, data=data) - - if not reqv1 and not reqv2: - assert call_result is None - assert result is None - else: - assert call_result is not None - possibles = reqv1 * [data ^ 1] + reqv2 * [data ^ 2] - assert result in possibles - - with self.run_simulation(m) as sim: - sim.add_testbench(process) diff --git a/test/transactron/test_transactron_lib_storage.py b/test/transactron/test_transactron_lib_storage.py deleted file mode 100644 index d5513fe7c..000000000 --- 
a/test/transactron/test_transactron_lib_storage.py +++ /dev/null @@ -1,131 +0,0 @@ -from datetime import timedelta -from hypothesis import given, settings, Phase -from transactron.testing import * -from transactron.lib.storage import ContentAddressableMemory - - -class TestContentAddressableMemory(TestCaseWithSimulator): - addr_width = 4 - content_width = 5 - test_number = 30 - nop_number = 3 - addr_layout = data_layout(addr_width) - content_layout = data_layout(content_width) - - def setUp(self): - self.entries_count = 8 - - self.circ = SimpleTestCircuit( - ContentAddressableMemory(self.addr_layout, self.content_layout, self.entries_count) - ) - - self.memory = {} - - def generic_process( - self, - method, - input_lst, - behaviour_check=None, - state_change=None, - input_verification=None, - settle_count=0, - name="", - ): - async def f(sim: TestbenchContext): - while input_lst: - # wait till all processes will end the previous cycle - await sim.delay(1e-9) - elem = input_lst.pop() - if isinstance(elem, OpNOP): - await sim.tick() - continue - if input_verification is not None and not input_verification(elem): - await sim.tick() - continue - response = await method.call(sim, **elem) - await sim.delay(settle_count * 1e-9) - if behaviour_check is not None: - behaviour_check(elem, response) - if state_change is not None: - state_change(elem, response) - await sim.tick() - - return f - - def push_process(self, in_push): - def verify_in(elem): - return not (frozenset(elem["addr"].items()) in self.memory) - - def modify_state(elem, response): - self.memory[frozenset(elem["addr"].items())] = elem["data"] - - return self.generic_process( - self.circ.push, - in_push, - state_change=modify_state, - input_verification=verify_in, - settle_count=3, - name="push", - ) - - def read_process(self, in_read): - def check(elem, response): - addr = elem["addr"] - frozen_addr = frozenset(addr.items()) - if frozen_addr in self.memory: - assert response.not_found == 0 - assert data_const_to_dict(response.data) == self.memory[frozen_addr] - else: - assert response.not_found == 1 - - return self.generic_process(self.circ.read, in_read, behaviour_check=check, settle_count=0, name="read") - - def remove_process(self, in_remove): - def modify_state(elem, response): - if frozenset(elem["addr"].items()) in self.memory: - del self.memory[frozenset(elem["addr"].items())] - - return self.generic_process(self.circ.remove, in_remove, state_change=modify_state, settle_count=2, name="remv") - - def write_process(self, in_write): - def verify_in(elem): - ret = frozenset(elem["addr"].items()) in self.memory - return ret - - def check(elem, response): - assert response.not_found == int(frozenset(elem["addr"].items()) not in self.memory) - - def modify_state(elem, response): - if frozenset(elem["addr"].items()) in self.memory: - self.memory[frozenset(elem["addr"].items())] = elem["data"] - - return self.generic_process( - self.circ.write, - in_write, - behaviour_check=check, - state_change=modify_state, - input_verification=None, - settle_count=1, - name="writ", - ) - - @settings( - max_examples=10, - phases=(Phase.explicit, Phase.reuse, Phase.generate, Phase.shrink), - derandomize=True, - deadline=timedelta(milliseconds=500), - ) - @given( - generate_process_input(test_number, nop_number, [("addr", addr_layout), ("data", content_layout)]), - generate_process_input(test_number, nop_number, [("addr", addr_layout), ("data", content_layout)]), - generate_process_input(test_number, nop_number, [("addr", addr_layout)]), - 
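- # like the read stream above, the remove stream carries only addresses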
generate_process_input(test_number, nop_number, [("addr", addr_layout)]), - ) - def test_random(self, in_push, in_write, in_read, in_remove): - with self.reinitialize_fixtures(): - self.setUp() - with self.run_simulation(self.circ, max_cycles=500) as sim: - sim.add_testbench(self.push_process(in_push)) - sim.add_testbench(self.read_process(in_read)) - sim.add_testbench(self.write_process(in_write)) - sim.add_testbench(self.remove_process(in_remove)) diff --git a/test/transactron/testing/__init__.py b/test/transactron/testing/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/test/transactron/testing/test_log.py b/test/transactron/testing/test_log.py deleted file mode 100644 index e0cbd4af8..000000000 --- a/test/transactron/testing/test_log.py +++ /dev/null @@ -1,125 +0,0 @@ -import pytest -import re -from amaranth import * - -from transactron import * -from transactron.testing import TestCaseWithSimulator, TestbenchContext -from transactron.lib import logging - -LOGGER_NAME = "test_logger" - -log = logging.HardwareLogger(LOGGER_NAME) - - -class LogTest(Elaboratable): - def __init__(self): - self.input = Signal(range(100)) - self.counter = Signal(range(200)) - - def elaborate(self, platform): - m = TModule() - - with m.If(self.input == 42): - log.warning(m, True, "Log triggered under Amaranth If value+3=0x{:x}", self.input + 3) - - log.warning(m, self.input[0] == 0, "Input is even! input={}, counter={}", self.input, self.counter) - - m.d.sync += self.counter.eq(self.counter + 1) - - return m - - -class ErrorLogTest(Elaboratable): - def __init__(self): - self.input = Signal() - self.output = Signal() - - def elaborate(self, platform): - m = TModule() - - m.d.comb += self.output.eq(self.input & ~self.input) - - log.error( - m, - self.input != self.output, - "Input is different than output! input=0x{:x} output=0x{:x}", - self.input, - self.output, - ) - - return m - - -class AssertionTest(Elaboratable): - def __init__(self): - self.input = Signal() - self.output = Signal() - - def elaborate(self, platform): - m = TModule() - - m.d.comb += self.output.eq(self.input & ~self.input) - - log.assertion(m, self.input == self.output, "Output differs") - - return m - - -class TestLog(TestCaseWithSimulator): - def test_log(self, caplog): - m = LogTest() - - async def proc(sim: TestbenchContext): - for i in range(50): - await sim.tick() - sim.set(m.input, i) - - with self.run_simulation(m) as sim: - sim.add_testbench(proc) - - assert re.search( - r"WARNING test_logger:logging\.py:\d+ \[test/transactron/testing/test_log\.py:\d+\] " - + r"Log triggered under Amaranth If value\+3=0x2d", - caplog.text, - ) - for i in range(0, 50, 2): - assert re.search( - r"WARNING test_logger:logging\.py:\d+ \[test/transactron/testing/test_log\.py:\d+\] " - + f"Input is even! input={i}, counter={i + 1}", - caplog.text, - ) - - def test_error_log(self, caplog): - m = ErrorLogTest() - - async def proc(sim: TestbenchContext): - await sim.tick() - sim.set(m.input, 1) - await sim.tick() # A log after the last tick is not handled - - with pytest.raises(AssertionError): - with self.run_simulation(m) as sim: - sim.add_testbench(proc) - - assert re.search( - r"ERROR test_logger:logging\.py:\d+ \[test/transactron/testing/test_log\.py:\d+\] " - + "Input is different than output! 
input=0x1 output=0x0", - caplog.text, - ) - - def test_assertion(self, caplog): - m = AssertionTest() - - async def proc(sim: TestbenchContext): - await sim.tick() - sim.set(m.input, 1) - await sim.tick() # A log after the last tick is not handled - - with pytest.raises(AssertionError): - with self.run_simulation(m) as sim: - sim.add_testbench(proc) - - assert re.search( - r"ERROR test_logger:logging\.py:\d+ \[test/transactron/testing/test_log\.py:\d+\] Output differs", - caplog.text, - ) diff --git a/test/transactron/testing/test_validate_arguments.py b/test/transactron/testing/test_validate_arguments.py deleted file mode 100644 index 7e7036975..000000000 --- a/test/transactron/testing/test_validate_arguments.py +++ /dev/null @@ -1,61 +0,0 @@ -import random -from amaranth import * -from amaranth.sim import * - -from transactron.testing import TestCaseWithSimulator, TestbenchIO, data_layout, TestbenchContext - -from transactron import * -from transactron.testing.method_mock import def_method_mock -from transactron.lib import * -from transactron.testing.testbenchio import CallTrigger - - -class ValidateArgumentsTestCircuit(Elaboratable): - def elaborate(self, platform): - m = Module() - - self.method = TestbenchIO(Adapter(i=data_layout(1), o=data_layout(1)).set(with_validate_arguments=True)) - self.caller1 = TestbenchIO(AdapterTrans(self.method.adapter.iface)) - self.caller2 = TestbenchIO(AdapterTrans(self.method.adapter.iface)) - - m.submodules += [self.method, self.caller1, self.caller2] - - return m - - -class TestValidateArguments(TestCaseWithSimulator): - def control_caller(self, caller: TestbenchIO, method: TestbenchIO): - async def process(sim: TestbenchContext): - await sim.tick() - for _ in range(100): - val = random.randrange(2) - pre_accepted_val = self.accepted_val - caller_data, method_data = await CallTrigger(sim).call(caller, data=val).sample(method) - if caller_data is not None: - assert val == pre_accepted_val - assert caller_data.data == val - else: - assert val != pre_accepted_val or val == pre_accepted_val and method_data is not None - - return process - - def validate_arguments(self, data: int): - return data == self.accepted_val - - async def changer(self, sim: TestbenchContext): - for _ in range(50): - await sim.tick() - self.accepted_val = 1 - - @def_method_mock(tb_getter=lambda self: self.m.method, validate_arguments=validate_arguments) - def method_mock(self, data: int): - return {"data": data} - - def test_validate_arguments(self): - random.seed(42) - self.m = ValidateArgumentsTestCircuit() - self.accepted_val = 0 - with self.run_simulation(self.m) as sim: - sim.add_testbench(self.changer) - sim.add_testbench(self.control_caller(self.m.caller1, self.m.method)) - sim.add_testbench(self.control_caller(self.m.caller2, self.m.method)) diff --git a/test/transactron/utils/__init__.py b/test/transactron/utils/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/test/transactron/utils/test_amaranth_ext.py b/test/transactron/utils/test_amaranth_ext.py deleted file mode 100644 index 349fc0b87..000000000 --- a/test/transactron/utils/test_amaranth_ext.py +++ /dev/null @@ -1,135 +0,0 @@ -from transactron.testing import * -import random -import pytest -from transactron.utils.amaranth_ext import MultiPriorityEncoder, RingMultiPriorityEncoder - - -def get_expected_multi(input_width, output_count, input, *args): - places = [] - for i in range(input_width): - if input % 2: - places.append(i) - input //= 2 - places += [None] * output_count - return places - - 
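-# The ring variant scans a circular window [first, last): duplicating the input as
-# (input << input_width) + input models the wraparound, and hit indices are reduced
-# modulo input_width. E.g. input_width=4, input=0b0101, first=2, last=1 scans bit
-# positions 2, 3, 4 and yields [2, 0].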
-def get_expected_ring(input_width, output_count, input, first, last): - places = [] - input = (input << input_width) + input - if last < first: - last += input_width - for i in range(2 * input_width): - if i >= first and i < last and input % 2: - places.append(i % input_width) - input //= 2 - places += [None] * output_count - return places - - -@pytest.mark.parametrize( - "test_class, verif_f", - [(MultiPriorityEncoder, get_expected_multi), (RingMultiPriorityEncoder, get_expected_ring)], - ids=["MultiPriorityEncoder", "RingMultiPriorityEncoder"], -) -class TestPriorityEncoder(TestCaseWithSimulator): - def process(self, get_expected): - async def f(sim: TestbenchContext): - for _ in range(self.test_number): - input = random.randrange(2**self.input_width) - first = random.randrange(self.input_width) - last = random.randrange(self.input_width) - sim.set(self.circ.input, input) - try: - sim.set(self.circ.first, first) - sim.set(self.circ.last, last) - except AttributeError: - pass - expected_output = get_expected(self.input_width, self.output_count, input, first, last) - for ex, real, valid in zip(expected_output, self.circ.outputs, self.circ.valids): - if ex is None: - assert sim.get(valid) == 0 - else: - assert sim.get(valid) == 1 - assert sim.get(real) == ex - await sim.delay(1e-7) - - return f - - @pytest.mark.parametrize("input_width", [1, 5, 16, 23, 24]) - @pytest.mark.parametrize("output_count", [1, 3, 4]) - def test_random(self, test_class, verif_f, input_width, output_count): - random.seed(input_width + output_count) - self.test_number = 50 - self.input_width = input_width - self.output_count = output_count - self.circ = test_class(self.input_width, self.output_count) - - with self.run_simulation(self.circ) as sim: - sim.add_testbench(self.process(verif_f)) - - @pytest.mark.parametrize("name", ["prio_encoder", None]) - def test_static_create_simple(self, test_class, verif_f, name): - random.seed(14) - self.test_number = 50 - self.input_width = 7 - self.output_count = 1 - - class DUT(Elaboratable): - def __init__(self, input_width, output_count, name): - self.input = Signal(input_width) - self.first = Signal(range(input_width)) - self.last = Signal(range(input_width)) - self.output_count = output_count - self.input_width = input_width - self.name = name - - def elaborate(self, platform): - m = Module() - if test_class == MultiPriorityEncoder: - out, val = test_class.create_simple(m, self.input_width, self.input, name=self.name) - else: - out, val = test_class.create_simple( - m, self.input_width, self.input, self.first, self.last, name=self.name - ) - # Save as a list to use common interface in testing - self.outputs = [out] - self.valids = [val] - return m - - self.circ = DUT(self.input_width, self.output_count, name) - - with self.run_simulation(self.circ) as sim: - sim.add_testbench(self.process(verif_f)) - - @pytest.mark.parametrize("name", ["prio_encoder", None]) - def test_static_create(self, test_class, verif_f, name): - random.seed(14) - self.test_number = 50 - self.input_width = 7 - self.output_count = 2 - - class DUT(Elaboratable): - def __init__(self, input_width, output_count, name): - self.input = Signal(input_width) - self.first = Signal(range(input_width)) - self.last = Signal(range(input_width)) - self.output_count = output_count - self.input_width = input_width - self.name = name - - def elaborate(self, platform): - m = Module() - if test_class == MultiPriorityEncoder: - out = test_class.create(m, self.input_width, self.input, self.output_count, name=self.name) - 
else: - out = test_class.create( - m, self.input_width, self.input, self.first, self.last, self.output_count, name=self.name - ) - self.outputs, self.valids = list(zip(*out)) - return m - - self.circ = DUT(self.input_width, self.output_count, name) - - with self.run_simulation(self.circ) as sim: - sim.add_testbench(self.process(verif_f)) diff --git a/test/transactron/utils/test_onehotswitch.py b/test/transactron/utils/test_onehotswitch.py deleted file mode 100644 index b0620c0a9..000000000 --- a/test/transactron/utils/test_onehotswitch.py +++ /dev/null @@ -1,59 +0,0 @@ -from amaranth import * -from amaranth.sim import * - -from transactron.utils import OneHotSwitch - -from transactron.testing import TestCaseWithSimulator, TestbenchContext - -from parameterized import parameterized - - -class OneHotSwitchCircuit(Elaboratable): - def __init__(self, width: int, test_zero: bool): - self.input = Signal(1 << width) - self.output = Signal(width) - self.zero = Signal() - self.test_zero = test_zero - - def elaborate(self, platform): - m = Module() - - with OneHotSwitch(m, self.input) as OneHotCase: - for i in range(len(self.input)): - with OneHotCase(1 << i): - m.d.comb += self.output.eq(i) - - if self.test_zero: - with OneHotCase(): - m.d.comb += self.zero.eq(1) - - return m - - -class TestOneHotSwitch(TestCaseWithSimulator): - @parameterized.expand([(False,), (True,)]) - def test_onehotswitch(self, test_zero): - circuit = OneHotSwitchCircuit(4, test_zero) - - async def switch_test_proc(sim: TestbenchContext): - for i in range(len(circuit.input)): - sim.set(circuit.input, 1 << i) - assert sim.get(circuit.output) == i - - with self.run_simulation(circuit) as sim: - sim.add_testbench(switch_test_proc) - - def test_onehotswitch_zero(self): - circuit = OneHotSwitchCircuit(4, True) - - async def switch_test_proc_zero(sim: TestbenchContext): - for i in range(len(circuit.input)): - sim.set(circuit.input, 1 << i) - assert sim.get(circuit.output) == i - assert not sim.get(circuit.zero) - - sim.set(circuit.input, 0) - assert sim.get(circuit.zero) - - with self.run_simulation(circuit) as sim: - sim.add_testbench(switch_test_proc_zero) diff --git a/test/transactron/utils/test_utils.py b/test/transactron/utils/test_utils.py deleted file mode 100644 index abd28f420..000000000 --- a/test/transactron/utils/test_utils.py +++ /dev/null @@ -1,196 +0,0 @@ -import unittest -import random - -from amaranth import * -from transactron.testing import * -from transactron.utils import ( - align_to_power_of_two, - align_down_to_power_of_two, - popcount, - count_leading_zeros, - count_trailing_zeros, -) -from parameterized import parameterized_class - - -class TestAlignToPowerOfTwo(unittest.TestCase): - def test_align_to_power_of_two(self): - test_cases = [ - (2, 2, 4), - (2, 1, 2), - (3, 1, 4), - (7, 3, 8), - (8, 3, 8), - (14, 3, 16), - (17, 3, 24), - (33, 3, 40), - (33, 1, 34), - (33, 0, 33), - (33, 4, 48), - (33, 5, 64), - (33, 6, 64), - ] - - for num, power, expected in test_cases: - out = align_to_power_of_two(num, power) - assert expected == out - - def test_align_down_to_power_of_two(self): - test_cases = [ - (3, 1, 2), - (3, 0, 3), - (3, 3, 0), - (8, 3, 8), - (8, 2, 8), - (33, 5, 32), - (29, 5, 0), - (29, 1, 28), - (29, 3, 24), - ] - - for num, power, expected in test_cases: - out = align_down_to_power_of_two(num, power) - assert expected == out - - -class PopcountTestCircuit(Elaboratable): - def __init__(self, size: int): - self.sig_in = Signal(size) - self.sig_out = Signal(size) - - def elaborate(self, platform): - m 
= Module() - - m.d.comb += self.sig_out.eq(popcount(self.sig_in)) - - return m - - -@parameterized_class( - ("name", "size"), - [("size" + str(s), s) for s in [2, 3, 4, 5, 6, 8, 10, 16, 21, 32, 33, 64, 1025]], -) -class TestPopcount(TestCaseWithSimulator): - size: int - - def setup_method(self): - random.seed(14) - self.test_number = 40 - self.m = PopcountTestCircuit(self.size) - - def check(self, sim: TestbenchContext, n): - sim.set(self.m.sig_in, n) - out_popcount = sim.get(self.m.sig_out) - assert out_popcount == n.bit_count(), f"{n:x}" - - async def process(self, sim: TestbenchContext): - for i in range(self.test_number): - n = random.randrange(2**self.size) - self.check(sim, n) - sim.delay(1e-6) - self.check(sim, 2**self.size - 1) - - def test_popcount(self): - with self.run_simulation(self.m) as sim: - sim.add_testbench(self.process) - - -class CLZTestCircuit(Elaboratable): - def __init__(self, xlen_log: int): - self.sig_in = Signal(1 << xlen_log) - self.sig_out = Signal(xlen_log + 1) - self.xlen_log = xlen_log - - def elaborate(self, platform): - m = Module() - - m.d.comb += self.sig_out.eq(count_leading_zeros(self.sig_in)) - # dummy signal - s = Signal() - m.d.sync += s.eq(1) - - return m - - -@parameterized_class( - ("name", "size"), - [("size" + str(s), s) for s in range(1, 7)], -) -class TestCountLeadingZeros(TestCaseWithSimulator): - size: int - - def setup_method(self): - random.seed(14) - self.test_number = 40 - self.m = CLZTestCircuit(self.size) - - def check(self, sim: TestbenchContext, n): - sim.set(self.m.sig_in, n) - out_clz = sim.get(self.m.sig_out) - assert out_clz == (2**self.size) - n.bit_length(), f"{n:x}" - - async def process(self, sim: TestbenchContext): - for i in range(self.test_number): - n = random.randrange(2**self.size) - self.check(sim, n) - sim.delay(1e-6) - self.check(sim, 2**self.size - 1) - - def test_count_leading_zeros(self): - with self.run_simulation(self.m) as sim: - sim.add_testbench(self.process) - - -class CTZTestCircuit(Elaboratable): - def __init__(self, xlen_log: int): - self.sig_in = Signal(1 << xlen_log) - self.sig_out = Signal(xlen_log + 1) - self.xlen_log = xlen_log - - def elaborate(self, platform): - m = Module() - - m.d.comb += self.sig_out.eq(count_trailing_zeros(self.sig_in)) - # dummy signal - s = Signal() - m.d.sync += s.eq(1) - - return m - - -@parameterized_class( - ("name", "size"), - [("size" + str(s), s) for s in range(1, 7)], -) -class TestCountTrailingZeros(TestCaseWithSimulator): - size: int - - def setup_method(self): - random.seed(14) - self.test_number = 40 - self.m = CTZTestCircuit(self.size) - - def check(self, sim: TestbenchContext, n): - sim.set(self.m.sig_in, n) - out_ctz = sim.get(self.m.sig_out) - - expected = 0 - if n == 0: - expected = 2**self.size - else: - while (n & 1) == 0: - expected += 1 - n >>= 1 - - assert out_ctz == expected, f"{n:x}" - - async def process(self, sim: TestbenchContext): - for i in range(self.test_number): - n = random.randrange(2**self.size) - self.check(sim, n) - await sim.delay(1e-6) - self.check(sim, 2**self.size - 1) - - def test_count_trailing_zeros(self): - with self.run_simulation(self.m) as sim: - sim.add_testbench(self.process) diff --git a/transactron/__init__.py b/transactron/__init__.py deleted file mode 100644 index c162fe991..000000000 --- a/transactron/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .core import * # noqa: F401 diff --git a/transactron/core/__init__.py b/transactron/core/__init__.py deleted file mode 100644 index 6ead593f8..000000000 --- 
a/transactron/core/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -from .tmodule import * # noqa: F401 -from .transaction_base import * # noqa: F401 -from .method import * # noqa: F401 -from .transaction import * # noqa: F401 -from .manager import * # noqa: F401 -from .sugar import * # noqa: F401 diff --git a/transactron/core/keys.py b/transactron/core/keys.py deleted file mode 100644 index 9444dce34..000000000 --- a/transactron/core/keys.py +++ /dev/null @@ -1,13 +0,0 @@ -from transactron.utils import * -from typing import TYPE_CHECKING -from dataclasses import dataclass - -if TYPE_CHECKING: - from .manager import TransactionManager # noqa: F401 because of https://github.com/PyCQA/pyflakes/issues/571 - -__all__ = ["TransactionManagerKey"] - - -@dataclass(frozen=True) -class TransactionManagerKey(SimpleKey["TransactionManager"]): - pass diff --git a/transactron/core/manager.py b/transactron/core/manager.py deleted file mode 100644 index cfbc6b17d..000000000 --- a/transactron/core/manager.py +++ /dev/null @@ -1,537 +0,0 @@ -from collections import defaultdict, deque -from typing import Callable, Iterable, Sequence, TypeAlias, Tuple -from os import environ -from graphlib import TopologicalSorter -from amaranth import * -from amaranth.lib.wiring import Component, connect, flipped -from itertools import chain, filterfalse, product - -from amaranth_types import AbstractComponent - -from transactron.utils import * -from transactron.utils.transactron_helpers import _graph_ccs -from transactron.graph import OwnershipGraph, Direction - -from .transaction_base import TransactionBase, TransactionOrMethod, Priority, Relation -from .method import Method -from .transaction import Transaction, TransactionManagerKey -from .tmodule import TModule -from .schedulers import eager_deterministic_cc_scheduler - -__all__ = ["TransactionManager", "TransactionModule", "TransactionComponent"] - -TransactionGraph: TypeAlias = Graph["Transaction"] -TransactionGraphCC: TypeAlias = GraphCC["Transaction"] -PriorityOrder: TypeAlias = dict["Transaction", int] -TransactionScheduler: TypeAlias = Callable[["MethodMap", TransactionGraph, TransactionGraphCC, PriorityOrder], Module] - - -class MethodMap: - def __init__(self, transactions: Iterable["Transaction"]): - self.methods_by_transaction = dict[Transaction, list[Method]]() - self.transactions_by_method = defaultdict[Method, list[Transaction]](list) - self.readiness_by_call = dict[tuple[Transaction, Method], ValueLike]() - self.ancestors_by_call = dict[tuple[Transaction, Method], tuple[Method, ...]]() - self.method_parents = defaultdict[Method, list[TransactionBase]](list) - - def rec(transaction: Transaction, source: TransactionBase, ancestors: tuple[Method, ...]): - for method, (arg_rec, _) in source.method_uses.items(): - if not method.defined: - raise RuntimeError(f"Trying to use method '{method.name}' which is not defined yet") - if method in self.methods_by_transaction[transaction]: - raise RuntimeError(f"Method '{method.name}' can't be called twice from the same transaction") - self.methods_by_transaction[transaction].append(method) - self.transactions_by_method[method].append(transaction) - self.readiness_by_call[(transaction, method)] = method._validate_arguments(arg_rec) - self.ancestors_by_call[(transaction, method)] = new_ancestors = (method, *ancestors) - rec(transaction, method, new_ancestors) - - for transaction in transactions: - self.methods_by_transaction[transaction] = [] - rec(transaction, transaction, ()) - - for transaction_or_method in 
self.methods_and_transactions: - for method in transaction_or_method.method_uses.keys(): - self.method_parents[method].append(transaction_or_method) - - def transactions_for(self, elem: TransactionOrMethod) -> Collection["Transaction"]: - if isinstance(elem, Transaction): - return [elem] - else: - return self.transactions_by_method[elem] - - @property - def methods(self) -> Collection["Method"]: - return self.transactions_by_method.keys() - - @property - def transactions(self) -> Collection["Transaction"]: - return self.methods_by_transaction.keys() - - @property - def methods_and_transactions(self) -> Iterable[TransactionOrMethod]: - return chain(self.methods, self.transactions) - - -class TransactionManager(Elaboratable): - """Transaction manager - - This module is responsible for granting `Transaction`\\s and running - `Method`\\s. It takes care that two conflicting `Transaction`\\s - are never granted in the same clock cycle. - """ - - def __init__(self, cc_scheduler: TransactionScheduler = eager_deterministic_cc_scheduler): - self.transactions: list[Transaction] = [] - self.cc_scheduler = cc_scheduler - - def add_transaction(self, transaction: "Transaction"): - self.transactions.append(transaction) - - @staticmethod - def _conflict_graph(method_map: MethodMap) -> Tuple[TransactionGraph, PriorityOrder]: - """_conflict_graph - - This function generates the graph of transaction conflicts. Conflicts - between transactions can be explicit or implicit. Two transactions - conflict explicitly, if a conflict was added between the transactions - or the methods used by them via `add_conflict`. Two transactions - conflict implicitly if they are both using the same method. - - Created graph is undirected. Transactions are nodes in that graph - and conflict between two transactions is marked as an edge. In such - representation connected components are sets of transactions which can - potentially conflict so there is a need to arbitrate between them. - On the other hand when two transactions are in different connected - components, then they can be scheduled independently, because they - will have no conflicts. - - This function also computes a linear ordering of transactions - which is consistent with conflict priorities of methods and - transactions. When priority constraints cannot be satisfied, - an exception is thrown. - - Returns - ------- - cgr : TransactionGraph - Graph of conflicts between transactions, where vertices are transactions and edges are conflicts. - porder : PriorityOrder - Linear ordering of transactions which is consistent with priority constraints. 
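A sketch of the implicit-conflict case described above (assuming a `TModule` named `m` and a method `my_method` defined elsewhere): both transactions call the same method, so they end up in one connected component of `cgr` and the scheduler has to arbitrate between them.

    with Transaction().body(m):
        my_method(m)
    with Transaction().body(m):
        my_method(m)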
- """ - - def transactions_exclusive(trans1: Transaction, trans2: Transaction): - tms1 = [trans1] + method_map.methods_by_transaction[trans1] - tms2 = [trans2] + method_map.methods_by_transaction[trans2] - - # if first transaction is exclusive with the second transaction, or this is true for - # any called methods, the transactions will never run at the same time - for tm1, tm2 in product(tms1, tms2): - if tm1.ctrl_path.exclusive_with(tm2.ctrl_path): - return True - - return False - - def calls_nonexclusive(trans1: Transaction, trans2: Transaction, method: Method): - ancestors1 = method_map.ancestors_by_call[(trans1, method)] - ancestors2 = method_map.ancestors_by_call[(trans2, method)] - common_ancestors = longest_common_prefix(ancestors1, ancestors2) - return common_ancestors[-1].nonexclusive - - cgr: TransactionGraph = {} # Conflict graph - pgr: TransactionGraph = {} # Priority graph - - def add_edge(begin: Transaction, end: Transaction, priority: Priority, conflict: bool): - if conflict: - cgr[begin].add(end) - cgr[end].add(begin) - match priority: - case Priority.LEFT: - pgr[end].add(begin) - case Priority.RIGHT: - pgr[begin].add(end) - - for transaction in method_map.transactions: - cgr[transaction] = set() - pgr[transaction] = set() - - for method in method_map.methods: - for transaction1 in method_map.transactions_for(method): - for transaction2 in method_map.transactions_for(method): - if ( - transaction1 is not transaction2 - and not transactions_exclusive(transaction1, transaction2) - and not calls_nonexclusive(transaction1, transaction2, method) - ): - add_edge(transaction1, transaction2, Priority.UNDEFINED, True) - - relations = [ - Relation(**relation, start=elem) - for elem in method_map.methods_and_transactions - for relation in elem.relations - ] - - for relation in relations: - start = relation["start"] - end = relation["end"] - if not relation["conflict"]: # relation added with schedule_before - if end.def_order < start.def_order and not relation["silence_warning"]: - raise RuntimeError(f"{start.name!r} scheduled before {end.name!r}, but defined afterwards") - - for trans_start in method_map.transactions_for(start): - for trans_end in method_map.transactions_for(end): - conflict = relation["conflict"] and not transactions_exclusive(trans_start, trans_end) - add_edge(trans_start, trans_end, relation["priority"], conflict) - - porder: PriorityOrder = {} - - for k, transaction in enumerate(TopologicalSorter(pgr).static_order()): - porder[transaction] = k - - return cgr, porder - - @staticmethod - def _method_enables(method_map: MethodMap) -> Mapping["Transaction", Mapping["Method", ValueLike]]: - method_enables = defaultdict[Transaction, dict[Method, ValueLike]](dict) - enables: list[ValueLike] = [] - - def rec(transaction: Transaction, source: TransactionOrMethod): - for method, (_, enable) in source.method_uses.items(): - enables.append(enable) - rec(transaction, method) - method_enables[transaction][method] = Cat(*enables).all() - enables.pop() - - for transaction in method_map.transactions: - rec(transaction, transaction) - - return method_enables - - @staticmethod - def _method_calls( - m: Module, method_map: MethodMap - ) -> tuple[Mapping["Method", Sequence[MethodStruct]], Mapping["Method", Sequence[Value]]]: - args = defaultdict[Method, list[MethodStruct]](list) - runs = defaultdict[Method, list[Value]](list) - - for source in method_map.methods_and_transactions: - if isinstance(source, Method): - run_val = Cat(transaction.grant for transaction in 
method_map.transactions_by_method[source]).any() - run = Signal() - m.d.comb += run.eq(run_val) - else: - run = source.grant - for method, (arg, _) in source.method_uses.items(): - args[method].append(arg) - runs[method].append(run) - - return (args, runs) - - def _simultaneous(self): - method_map = MethodMap(self.transactions) - - # remove orderings between simultaneous methods/transactions - # TODO: can it be done after transitivity, possibly catching more cases? - for elem in method_map.methods_and_transactions: - all_sims = frozenset(elem.simultaneous_list) - elem.relations = list( - filterfalse( - lambda relation: not relation["conflict"] - and relation["priority"] != Priority.UNDEFINED - and relation["end"] in all_sims, - elem.relations, - ) - ) - - # step 1: simultaneous and independent sets generation - independents = defaultdict[Transaction, set[Transaction]](set) - - for elem in method_map.methods_and_transactions: - indeps = frozenset[Transaction]().union( - *(frozenset(method_map.transactions_for(ind)) for ind in chain([elem], elem.independent_list)) - ) - for transaction1, transaction2 in product(indeps, indeps): - independents[transaction1].add(transaction2) - - simultaneous = set[frozenset[Transaction]]() - - for elem in method_map.methods_and_transactions: - for sim_elem in elem.simultaneous_list: - for tr1, tr2 in product(method_map.transactions_for(elem), method_map.transactions_for(sim_elem)): - if tr1 in independents[tr2]: - raise RuntimeError( - f"Unsatisfiable simultaneity constraints for '{elem.name}' and '{sim_elem.name}'" - ) - simultaneous.add(frozenset({tr1, tr2})) - - # step 2: transitivity computation - tr_simultaneous = set[frozenset[Transaction]]() - - def conflicting(group: frozenset[Transaction]): - return any(tr1 != tr2 and tr1 in independents[tr2] for tr1 in group for tr2 in group) - - q = deque[frozenset[Transaction]](simultaneous) - - while q: - new_group = q.popleft() - if new_group in tr_simultaneous or conflicting(new_group): - continue - q.extend(new_group | other_group for other_group in simultaneous if new_group & other_group) - tr_simultaneous.add(new_group) - - # step 3: maximal group selection - def maximal(group: frozenset[Transaction]): - return not any(group.issubset(group2) and group != group2 for group2 in tr_simultaneous) - - final_simultaneous = set(filter(maximal, tr_simultaneous)) - - # step 4: convert transactions to methods - joined_transactions = set[Transaction]().union(*final_simultaneous) - - self.transactions = list(filter(lambda t: t not in joined_transactions, self.transactions)) - methods = dict[Transaction, Method]() - - for transaction in joined_transactions: - # TODO: some simpler way? 
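-            # The transaction is repackaged as a Method which shares all of its
-            # state (ready/run signals, relations, recorded method calls), so that
-            # the merged transactions built in step 5 below can call it.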
- method = Method(name=transaction.name) - method.owner = transaction.owner - method.src_loc = transaction.src_loc - method.ready = transaction.request - method.run = transaction.grant - method.defined = transaction.defined - method.method_calls = transaction.method_calls - method.method_uses = transaction.method_uses - method.relations = transaction.relations - method.def_order = transaction.def_order - method.ctrl_path = transaction.ctrl_path - methods[transaction] = method - - for elem in method_map.methods_and_transactions: - # I guess method/transaction unification is really needed - for relation in elem.relations: - if relation["end"] in methods: - relation["end"] = methods[relation["end"]] - - # step 5: construct merged transactions - m = TModule() - m._MustUse__silence = True # type: ignore - - for group in final_simultaneous: - name = "_".join([t.name for t in group]) - with Transaction(manager=self, name=name).body(m): - for transaction in group: - methods[transaction](m) - - return m - - def elaborate(self, platform): - # In the following, various problems in the transaction set-up are detected. - # The exception triggers an unused Elaboratable warning. - with silence_mustuse(self): - merge_manager = self._simultaneous() - - method_map = MethodMap(self.transactions) - cgr, porder = TransactionManager._conflict_graph(method_map) - - m = Module() - m.submodules.merge_manager = merge_manager - - for elem in method_map.methods_and_transactions: - elem._set_method_uses(m) - - for transaction in self.transactions: - ready = [ - method_map.readiness_by_call[transaction, method] - for method in method_map.methods_by_transaction[transaction] - ] - m.d.comb += transaction.runnable.eq(Cat(ready).all()) - - ccs = _graph_ccs(cgr) - m.submodules._transactron_schedulers = ModuleConnector( - *[self.cc_scheduler(method_map, cgr, cc, porder) for cc in ccs] - ) - - method_enables = self._method_enables(method_map) - - for method, transactions in method_map.transactions_by_method.items(): - granted = Cat(transaction.grant & method_enables[transaction][method] for transaction in transactions) - m.d.comb += method.run.eq(granted.any()) - - (method_args, method_runs) = self._method_calls(m, method_map) - - for method in method_map.methods: - if len(method_args[method]) == 1: - m.d.comb += method.data_in.eq(method_args[method][0]) - else: - if method.single_caller: - raise RuntimeError(f"Single-caller method '{method.name}' called more than once") - - runs = Cat(method_runs[method]) - m.d.comb += assign(method.data_in, method.combiner(m, method_args[method], runs), fields=AssignType.ALL) - - if "TRANSACTRON_VERBOSE" in environ: - self.print_info(cgr, porder, ccs, method_map) - - return m - - def print_info( - self, cgr: TransactionGraph, porder: PriorityOrder, ccs: list[GraphCC["Transaction"]], method_map: MethodMap - ): - print("Transactron statistics") - print(f"\tMethods: {len(method_map.methods)}") - print(f"\tTransactions: {len(method_map.transactions)}") - print(f"\tIndependent subgraphs: {len(ccs)}") - print(f"\tAvg callers per method: {average_dict_of_lists(method_map.transactions_by_method):.2f}") - print(f"\tAvg conflicts per transaction: {average_dict_of_lists(cgr):.2f}") - print("") - print("Transaction subgraphs") - for cc in ccs: - ccl = list(cc) - ccl.sort(key=lambda t: porder[t]) - for t in ccl: - print(f"\t{t.name}") - print("") - print("Calling transactions per method") - for m, ts in method_map.transactions_by_method.items(): - print(f"\t{m.owned_name}: 
{m.src_loc[0]}:{m.src_loc[1]}") - for t in ts: - print(f"\t\t{t.name}: {t.src_loc[0]}:{t.src_loc[1]}") - print("") - print("Called methods per transaction") - for t, ms in method_map.methods_by_transaction.items(): - print(f"\t{t.name}: {t.src_loc[0]}:{t.src_loc[1]}") - for m in ms: - print(f"\t\t{m.owned_name}: {m.src_loc[0]}:{m.src_loc[1]}") - print("") - - def visual_graph(self, fragment): - graph = OwnershipGraph(fragment) - method_map = MethodMap(self.transactions) - for method, transactions in method_map.transactions_by_method.items(): - if len(method.data_in.as_value()) > len(method.data_out.as_value()): - direction = Direction.IN - elif method.data_in.shape().size < method.data_out.shape().size: - direction = Direction.OUT - else: - direction = Direction.INOUT - graph.insert_node(method) - for transaction in transactions: - graph.insert_node(transaction) - graph.insert_edge(transaction, method, direction) - - return graph - - def debug_signals(self) -> SignalBundle: - method_map = MethodMap(self.transactions) - cgr, _ = TransactionManager._conflict_graph(method_map) - - def transaction_debug(t: Transaction): - return ( - [t.request, t.grant] - + [m.ready for m in method_map.methods_by_transaction[t]] - + [t2.grant for t2 in cgr[t]] - ) - - def method_debug(m: Method): - return [m.ready, m.run, {t.name: transaction_debug(t) for t in method_map.transactions_by_method[m]}] - - return { - "transactions": {t.name: transaction_debug(t) for t in method_map.transactions}, - "methods": {m.owned_name: method_debug(m) for m in method_map.methods}, - } - - -class TransactionModule(Elaboratable): - """ - `TransactionModule` is used as wrapper on `Elaboratable` classes, - which adds support for transactions. It creates a - `TransactionManager` which will handle transaction scheduling - and can be used in definition of `Method`\\s and `Transaction`\\s. - The `TransactionManager` is stored in a `DependencyManager`. - """ - - def __init__( - self, - elaboratable: HasElaborate, - dependency_manager: Optional[DependencyManager] = None, - transaction_manager: Optional[TransactionManager] = None, - ): - """ - Parameters - ---------- - elaboratable: HasElaborate - The `Elaboratable` which should be wrapped to add support for - transactions and methods. - dependency_manager: DependencyManager, optional - The `DependencyManager` to use inside the transaction module. - If omitted, a new one is created. - transaction_manager: TransactionManager, optional - The `TransactionManager` to use inside the transaction module. - If omitted, a new one is created. - """ - if transaction_manager is None: - transaction_manager = TransactionManager() - if dependency_manager is None: - dependency_manager = DependencyManager() - self.manager = dependency_manager - self.manager.add_dependency(TransactionManagerKey(), transaction_manager) - self.elaboratable = elaboratable - - def context(self) -> DependencyContext: - return DependencyContext(self.manager) - - def elaborate(self, platform): - with silence_mustuse(self.manager.get_dependency(TransactionManagerKey())): - with self.context(): - elaboratable = Fragment.get(self.elaboratable, platform) - - m = Module() - - m.submodules.main_module = elaboratable - m.submodules.transactionManager = self.transaction_manager = self.manager.get_dependency( - TransactionManagerKey() - ) - - return m - - -class TransactionComponent(TransactionModule, Component): - """Top-level component for Transactron projects. 
-
-    The `TransactionComponent` is a wrapper on `Component` classes,
-    which adds Transactron support for the wrapped class. The use
-    case is to wrap a top-level module of the project, and pass the
-    wrapped module for simulation, HDL generation or synthesis.
-    The ports of the wrapped component are forwarded to the wrapper.
-
-    It extends the functionality of `TransactionModule`.
-    """
-
-    def __init__(
-        self,
-        component: AbstractComponent,
-        dependency_manager: Optional[DependencyManager] = None,
-        transaction_manager: Optional[TransactionManager] = None,
-    ):
-        """
-        Parameters
-        ----------
-        component: Component
-            The `Component` which should be wrapped to add support for
-            transactions and methods.
-        dependency_manager: DependencyManager, optional
-            The `DependencyManager` to use inside the transaction component.
-            If omitted, a new one is created.
-        transaction_manager: TransactionManager, optional
-            The `TransactionManager` to use inside the transaction component.
-            If omitted, a new one is created.
-        """
-        TransactionModule.__init__(self, component, dependency_manager, transaction_manager)
-        Component.__init__(self, component.signature)
-
-    def elaborate(self, platform):
-        m = super().elaborate(platform)
-
-        assert isinstance(self.elaboratable, Component)  # for typing
-        connect(m, flipped(self), self.elaboratable)
-
-        return m
diff --git a/transactron/core/method.py b/transactron/core/method.py
deleted file mode 100644
index b5d573fcd..000000000
--- a/transactron/core/method.py
+++ /dev/null
@@ -1,315 +0,0 @@
-from collections.abc import Sequence
-from transactron.utils import *
-from amaranth import *
-from amaranth import tracer
-from typing import Optional, Callable, Iterator, TYPE_CHECKING
-from .transaction_base import *
-from .sugar import def_method
-from contextlib import contextmanager
-from transactron.utils.assign import AssignArg
-
-if TYPE_CHECKING:
-    from .tmodule import TModule
-
-__all__ = ["Method"]
-
-
-class Method(TransactionBase):
-    """Transactional method.
-
-    A `Method` serves to interface a module with external `Transaction`\\s
-    or `Method`\\s. It can be called at most once in a given clock cycle.
-    When a given `Method` is required by multiple `Transaction`\\s
-    (either directly, or indirectly via another `Method`) simultaneously,
-    at most one of them is granted by the `TransactionManager`, and the rest
-    of them must wait. (Non-exclusive methods are an exception to this
-    behavior.) Calling a `Method` always takes a single clock cycle.
-
-    Data is combinationally transferred to and from `Method`\\s
-    using Amaranth structures (`View` with a `StructLayout`). The transfer
-    can take place in both directions at the same time: from the called
-    `Method` to the caller (`data_out`) and from the caller to the called
-    `Method` (`data_in`).
-
-    A module which defines a `Method` should use `body` or `def_method`
-    to describe the method's effect on the module state.
-
-    Attributes
-    ----------
-    name: str
-        Name of this `Method`.
-    ready: Signal, in
-        Signals that the method is ready to run in the current cycle.
-        Typically defined by calling `body`.
-    run: Signal, out
-        Signals that the method is called in the current cycle by some
-        `Transaction`. Defined by the `TransactionManager`.
-    data_in: MethodStruct, out
-        Contains the data passed to the `Method` by the caller
-        (a `Transaction` or another `Method`).
-    data_out: MethodStruct, in
-        Contains the data passed from the `Method` to the caller
-        (a `Transaction` or another `Method`).
Typically defined by - calling `body`. - """ - - def __init__( - self, - *, - name: Optional[str] = None, - i: MethodLayout = (), - o: MethodLayout = (), - nonexclusive: bool = False, - combiner: Optional[Callable[[Module, Sequence[MethodStruct], Value], AssignArg]] = None, - single_caller: bool = False, - src_loc: int | SrcLoc = 0, - ): - """ - Parameters - ---------- - name: str or None - Name hint for this `Method`. If `None` (default) the name is - inferred from the variable name this `Method` is assigned to. - i: method layout - The format of `data_in`. - o: method layout - The format of `data_out`. - nonexclusive: bool - If true, the method is non-exclusive: it can be called by multiple - transactions in the same clock cycle. If such a situation happens, - the method still is executed only once, and each of the callers - receive its output. Nonexclusive methods cannot have inputs. - combiner: (Module, Sequence[MethodStruct], Value) -> AssignArg - If `nonexclusive` is true, the combiner function combines the - arguments from multiple calls to this method into a single - argument, which is passed to the method body. The third argument - is a bit vector, whose n-th bit is 1 if the n-th call is active - in a given cycle. - single_caller: bool - If true, this method is intended to be called from a single - transaction. An error will be thrown if called from multiple - transactions. - src_loc: int | SrcLoc - How many stack frames deep the source location is taken from. - Alternatively, the source location to use instead of the default. - """ - super().__init__(src_loc=get_src_loc(src_loc)) - - def default_combiner(m: Module, args: Sequence[MethodStruct], runs: Value) -> AssignArg: - ret = Signal(from_method_layout(i)) - for k in OneHotSwitchDynamic(m, runs): - m.d.comb += ret.eq(args[k]) - return ret - - self.owner, owner_name = get_caller_class_name(default="$method") - self.name = name or tracer.get_var_name(depth=2, default=owner_name) - self.ready = Signal(name=self.owned_name + "_ready") - self.run = Signal(name=self.owned_name + "_run") - self.data_in: MethodStruct = Signal(from_method_layout(i)) - self.data_out: MethodStruct = Signal(from_method_layout(o)) - self.nonexclusive = nonexclusive - self.combiner: Callable[[Module, Sequence[MethodStruct], Value], AssignArg] = combiner or default_combiner - self.single_caller = single_caller - self.validate_arguments: Optional[Callable[..., ValueLike]] = None - if nonexclusive: - assert len(self.data_in.as_value()) == 0 or combiner is not None - - @property - def layout_in(self): - return self.data_in.shape() - - @property - def layout_out(self): - return self.data_out.shape() - - @staticmethod - def like(other: "Method", *, name: Optional[str] = None, src_loc: int | SrcLoc = 0) -> "Method": - """Constructs a new `Method` based on another. - - The returned `Method` has the same input/output data layouts as the - `other` `Method`. - - Parameters - ---------- - other : Method - The `Method` which serves as a blueprint for the new `Method`. - name : str, optional - Name of the new `Method`. - src_loc: int | SrcLoc - How many stack frames deep the source location is taken from. - Alternatively, the source location to use instead of the default. - - Returns - ------- - Method - The freshly constructed `Method`. - """ - return Method(name=name, i=other.layout_in, o=other.layout_out, src_loc=get_src_loc(src_loc)) - - def proxy(self, m: "TModule", method: "Method"): - """Define as a proxy for another method. 
- - The calls to this method will be forwarded to `method`. - - Parameters - ---------- - m : TModule - Module in which operations on signals should be executed, - `proxy` uses the combinational domain only. - method : Method - Method for which this method is a proxy for. - """ - - @def_method(m, self, ready=method.ready) - def _(arg): - return method(m, arg) - - @contextmanager - def body( - self, - m: "TModule", - *, - ready: ValueLike = C(1), - out: ValueLike = C(0, 0), - validate_arguments: Optional[Callable[..., ValueLike]] = None, - ) -> Iterator[MethodStruct]: - """Define method body - - The `body` context manager can be used to define the actions - performed by a `Method` when it's run. Each assignment added to - a domain under `body` is guarded by the `run` signal. - Combinational assignments which do not need to be guarded by `run` - can be added to `m.d.av_comb` or `m.d.top_comb` instead of `m.d.comb`. - `Method` calls can be performed under `body`. - - Parameters - ---------- - m : TModule - Module in which operations on signals should be executed, - `body` uses the combinational domain only. - ready : Signal, in - Signal to indicate if the method is ready to be run. By - default it is `Const(1)`, so the method is always ready. - Assigned combinationially to the `ready` attribute. - out : Value, in - Data generated by the `Method`, which will be passed to - the caller (a `Transaction` or another `Method`). Assigned - combinationally to the `data_out` attribute. - validate_arguments: Optional[Callable[..., ValueLike]] - Function that takes input arguments used to call the method - and checks whether the method can be called with those arguments. - It instantiates a combinational circuit for each - method caller. By default, there is no function, so all arguments - are accepted. - - Returns - ------- - data_in : Record, out - Data passed from the caller (a `Transaction` or another - `Method`) to this `Method`. - - Examples - -------- - .. highlight:: python - .. code-block:: python - - m = Module() - my_sum_method = Method(i = Layout([("arg1",8),("arg2",8)])) - sum = Signal(16) - with my_sum_method.body(m, out = sum) as data_in: - m.d.comb += sum.eq(data_in.arg1 + data_in.arg2) - """ - if self.defined: - raise RuntimeError(f"Method '{self.name}' already defined") - self.def_order = next(TransactionBase.def_counter) - self.validate_arguments = validate_arguments - - m.d.av_comb += self.ready.eq(ready) - m.d.top_comb += self.data_out.eq(out) - with self.context(m): - with m.AvoidedIf(self.run): - yield self.data_in - - def _validate_arguments(self, arg_rec: MethodStruct) -> ValueLike: - if self.validate_arguments is not None: - return self.ready & method_def_helper(self, self.validate_arguments, arg_rec) - return self.ready - - def __call__( - self, m: "TModule", arg: Optional[AssignArg] = None, enable: ValueLike = C(1), /, **kwargs: AssignArg - ) -> MethodStruct: - """Call a method. - - Methods can only be called from transaction and method bodies. - Calling a `Method` marks, for the purpose of transaction scheduling, - the dependency between the calling context and the called `Method`. - It also connects the method's inputs to the parameters and the - method's outputs to the return value. - - Parameters - ---------- - m : TModule - Module in which operations on signals should be executed, - arg : Value or dict of Values - Call argument. Can be passed as a `View` of the method's - input layout or as a dictionary. Alternative syntax uses - keyword arguments. 
- enable : Value - Configures the call as enabled in the current clock cycle. - Disabled calls still lock the called method in transaction - scheduling. Calls are by default enabled. - **kwargs : Value or dict of Values - Allows to pass method arguments using keyword argument - syntax. Equivalent to passing a dict as the argument. - - Returns - ------- - data_out : MethodStruct - The result of the method call. - - Examples - -------- - .. highlight:: python - .. code-block:: python - - m = Module() - with Transaction().body(m): - ret = my_sum_method(m, arg1=2, arg2=3) - - Alternative syntax: - - .. highlight:: python - .. code-block:: python - - with Transaction().body(m): - ret = my_sum_method(m, {"arg1": 2, "arg2": 3}) - """ - arg_rec = Signal.like(self.data_in) - - if arg is not None and kwargs: - raise ValueError(f"Method '{self.name}' call with both keyword arguments and legacy record argument") - - if arg is None: - arg = kwargs - - enable_sig = Signal(name=self.owned_name + "_enable") - m.d.av_comb += enable_sig.eq(enable) - m.d.top_comb += assign(arg_rec, arg, fields=AssignType.ALL) - - caller = TransactionBase.get() - if not all(ctrl_path.exclusive_with(m.ctrl_path) for ctrl_path, _, _ in caller.method_calls[self]): - raise RuntimeError(f"Method '{self.name}' can't be called twice from the same caller '{caller.name}'") - caller.method_calls[self].append((m.ctrl_path, arg_rec, enable_sig)) - - if self not in caller.method_uses: - arg_rec_use = Signal(self.layout_in) - arg_rec_enable_sig = Signal() - caller.method_uses[self] = (arg_rec_use, arg_rec_enable_sig) - - return self.data_out - - def __repr__(self) -> str: - return "(method {})".format(self.name) - - def debug_signals(self) -> SignalBundle: - return [self.ready, self.run, self.data_in, self.data_out] diff --git a/transactron/core/schedulers.py b/transactron/core/schedulers.py deleted file mode 100644 index 856d4450b..000000000 --- a/transactron/core/schedulers.py +++ /dev/null @@ -1,77 +0,0 @@ -from amaranth import * -from typing import TYPE_CHECKING -from transactron.utils import * - -if TYPE_CHECKING: - from .manager import MethodMap, TransactionGraph, TransactionGraphCC, PriorityOrder - -__all__ = ["eager_deterministic_cc_scheduler", "trivial_roundrobin_cc_scheduler"] - - -def eager_deterministic_cc_scheduler( - method_map: "MethodMap", gr: "TransactionGraph", cc: "TransactionGraphCC", porder: "PriorityOrder" -) -> Module: - """eager_deterministic_cc_scheduler - - This function generates an eager scheduler for the transaction - subsystem. It isn't fair, because it starts transactions using - transaction index in `cc` as a priority. Transaction with the lowest - index has the highest priority. - - If there are two different transactions which have no conflicts then - they will be started concurrently. - - Parameters - ---------- - manager : TransactionManager - TransactionManager which uses this instance of scheduler for - arbitrating which agent should get a grant signal. - gr : TransactionGraph - Graph of conflicts between transactions, where vertices are transactions and edges are conflicts. - cc : Set[Transaction] - Connected components of the graph `gr` for which scheduler - should be generated. - porder : PriorityOrder - Linear ordering of transactions which is consistent with priority constraints. 
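As a usage sketch: this scheduler is the default, and an alternative can be passed to the `TransactionManager` constructor (see `manager.py` above), e.g.

    manager = TransactionManager(cc_scheduler=trivial_roundrobin_cc_scheduler)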
-    """
-    m = Module()
-    ccl = list(cc)
-    ccl.sort(key=lambda transaction: porder[transaction])
-    for k, transaction in enumerate(ccl):
-        conflicts = [ccl[j].grant for j in range(k) if ccl[j] in gr[transaction]]
-        noconflict = ~Cat(conflicts).any()
-        m.d.comb += transaction.grant.eq(transaction.request & transaction.runnable & noconflict)
-    return m
-
-
-def trivial_roundrobin_cc_scheduler(
-    method_map: "MethodMap", gr: "TransactionGraph", cc: "TransactionGraphCC", porder: "PriorityOrder"
-) -> Module:
-    """trivial_roundrobin_cc_scheduler
-
-    This function generates a simple round-robin scheduler for the transaction
-    subsystem. In one cycle there will be at most one transaction granted
-    (in a given connected component of the conflict graph), even if there is
-    another ready, non-conflicting transaction. It is mainly for testing
-    purposes.
-
-    Parameters
-    ----------
-    manager : TransactionManager
-        TransactionManager which uses this instance of scheduler for
-        arbitrating which agent should get the grant signal.
-    gr : TransactionGraph
-        Graph of conflicts between transactions, where vertices are transactions and edges are conflicts.
-    cc : Set[Transaction]
-        Connected components of the graph `gr` for which scheduler
-        should be generated.
-    porder : PriorityOrder
-        Linear ordering of transactions which is consistent with priority constraints.
-    """
-    m = Module()
-    sched = Scheduler(len(cc))
-    m.submodules.scheduler = sched
-    for k, transaction in enumerate(cc):
-        m.d.comb += sched.requests[k].eq(transaction.request & transaction.runnable)
-        m.d.comb += transaction.grant.eq(sched.grant[k] & sched.valid)
-    return m
diff --git a/transactron/core/sugar.py b/transactron/core/sugar.py
deleted file mode 100644
index 640cddbb5..000000000
--- a/transactron/core/sugar.py
+++ /dev/null
@@ -1,180 +0,0 @@
-from collections.abc import Sequence, Callable
-from amaranth import *
-from typing import TYPE_CHECKING, Optional, Concatenate, ParamSpec
-from transactron.utils import *
-from transactron.utils.assign import AssignArg
-from functools import partial
-
-if TYPE_CHECKING:
-    from .tmodule import TModule
-    from .method import Method
-
-__all__ = ["def_method", "def_methods"]
-
-
-P = ParamSpec("P")
-
-
-def def_method(
-    m: "TModule",
-    method: "Method",
-    ready: ValueLike = C(1),
-    validate_arguments: Optional[Callable[..., ValueLike]] = None,
-):
-    """Define a method.
-
-    This decorator allows to define transactional methods in an
-    elegant way using Python's `def` syntax. Internally, `def_method`
-    uses `Method.body`.
-
-    The decorated function should take keyword arguments corresponding to the
-    fields of the method's input layout. The `**kwargs` syntax is supported.
-    Alternatively, it can take one argument named `arg`, which will be a
-    structure with input signals.
-
-    The returned value can be either a structure with the method's output layout
-    or a dictionary of outputs.
-
-    Parameters
-    ----------
-    m: TModule
-        Module in which operations on signals should be executed.
-    method: Method
-        The method whose body is going to be defined.
-    ready: Signal
-        Signal to indicate if the method is ready to be run. By
-        default it is `Const(1)`, so the method is always ready.
-        Assigned combinationally to the `ready` attribute.
-    validate_arguments: Optional[Callable[..., ValueLike]]
-        Function that takes input arguments used to call the method
-        and checks whether the method can be called with those arguments.
-        It instantiates a combinational circuit for each
-        method caller.
By default, there is no function, so all arguments - are accepted. - - Examples - -------- - .. highlight:: python - .. code-block:: python - - m = Module() - my_sum_method = Method(i=[("arg1",8),("arg2",8)], o=[("res",8)]) - @def_method(m, my_sum_method) - def _(arg1, arg2): - return arg1 + arg2 - - Alternative syntax (keyword args in dictionary): - - .. highlight:: python - .. code-block:: python - - @def_method(m, my_sum_method) - def _(**args): - return args["arg1"] + args["arg2"] - - Alternative syntax (arg structure): - - .. highlight:: python - .. code-block:: python - - @def_method(m, my_sum_method) - def _(arg): - return {"res": arg.arg1 + arg.arg2} - """ - - def decorator(func: Callable[..., Optional[AssignArg]]): - out = Signal(method.layout_out) - ret_out = None - - with method.body(m, ready=ready, out=out, validate_arguments=validate_arguments) as arg: - ret_out = method_def_helper(method, func, arg) - - if ret_out is not None: - m.d.top_comb += assign(out, ret_out, fields=AssignType.ALL) - - return decorator - - -def def_methods( - m: "TModule", - methods: Sequence["Method"], - ready: Callable[[int], ValueLike] = lambda _: C(1), - validate_arguments: Optional[Callable[..., ValueLike]] = None, -): - """Decorator for defining similar methods - - This decorator is a wrapper over `def_method`, which allows you to easily - define multiple similar methods in a loop. - - The function over which this decorator is applied, should always expect - at least one argument, as the index of the method will be passed as the - first argument to the function. - - This is a syntax sugar equivalent to: - - .. highlight:: python - .. code-block:: python - - for i in range(len(my_methods)): - @def_method(m, my_methods[i]) - def _(arg): - ... - - Parameters - ---------- - m: TModule - Module in which operations on signals should be executed. - methods: Sequence[Method] - The methods whose body is going to be defined. - ready: Callable[[int], Value] - A `Callable` that takes the index in the form of an `int` of the currently defined method - and produces a `Value` describing whether the method is ready to be run. - When omitted, each defined method is always ready. Assigned combinationally to the `ready` attribute. - validate_arguments: Optional[Callable[Concatenate[int, ...], ValueLike]] - Function that takes input arguments used to call the method - and checks whether the method can be called with those arguments. - It instantiates a combinational circuit for each - method caller. By default, there is no function, so all arguments - are accepted. - - Examples - -------- - Define three methods with the same body: - - .. highlight:: python - .. code-block:: python - - m = TModule() - my_sum_methods = [Method(i=[("arg1",8),("arg2",8)], o=[("res",8)]) for _ in range(3)] - @def_methods(m, my_sum_methods) - def _(_, arg1, arg2): - return arg1 + arg2 - - Define three methods with different bodies parametrized with the index of the method: - - .. highlight:: python - .. code-block:: python - - m = TModule() - my_sum_methods = [Method(i=[("arg1",8),("arg2",8)], o=[("res",8)]) for _ in range(3)] - @def_methods(m, my_sum_methods) - def _(index : int, arg1, arg2): - return arg1 + arg2 + index - - Define three methods with different ready signals: - - .. highlight:: python - .. 
code-block:: python - - @def_methods(m, my_filter_read_methods, ready_list=lambda i: fifo.head == i) - def _(_): - return fifo.read(m) - """ - - def decorator(func: Callable[Concatenate[int, P], Optional[RecordDict]]): - for i in range(len(methods)): - partial_f = partial(func, i) - partial_vargs = partial(validate_arguments, i) if validate_arguments is not None else None - def_method(m, methods[i], ready(i), partial_vargs)(partial_f) - - return decorator diff --git a/transactron/core/tmodule.py b/transactron/core/tmodule.py deleted file mode 100644 index d4276dce7..000000000 --- a/transactron/core/tmodule.py +++ /dev/null @@ -1,286 +0,0 @@ -from enum import Enum, auto -from dataclasses import dataclass, replace -from amaranth import * -from typing import Optional, Self, NoReturn -from contextlib import contextmanager -from amaranth.hdl._dsl import FSM -from transactron.utils import * - -__all__ = ["TModule"] - - -class _AvoidingModuleBuilderDomain: - """ - A wrapper over Amaranth domain to abstract away internal Amaranth implementation. - It is needed to allow for correctness check in `__setattr__` which uses `isinstance`. - """ - - def __init__(self, amaranth_module_domain): - self._domain = amaranth_module_domain - - def __iadd__(self, assigns: StatementLike) -> Self: - self._domain.__iadd__(assigns) - return self - - -class _AvoidingModuleBuilderDomains: - _m: "TModule" - - def __init__(self, m: "TModule"): - object.__setattr__(self, "_m", m) - - def __getattr__(self, name: str) -> _AvoidingModuleBuilderDomain: - if name == "av_comb": - return _AvoidingModuleBuilderDomain(self._m.avoiding_module.d["comb"]) - elif name == "top_comb": - return _AvoidingModuleBuilderDomain(self._m.top_module.d["comb"]) - else: - return _AvoidingModuleBuilderDomain(self._m.main_module.d[name]) - - def __getitem__(self, name: str) -> _AvoidingModuleBuilderDomain: - return self.__getattr__(name) - - def __setattr__(self, name: str, value): - if not isinstance(value, _AvoidingModuleBuilderDomain): - raise AttributeError(f"Cannot assign 'd.{name}' attribute; did you mean 'd.{name} +='?") - - def __setitem__(self, name: str, value): - return self.__setattr__(name, value) - - -class EnterType(Enum): - """Characterizes stack behavior of Amaranth's context managers for control structures.""" - - #: Used for `m.If`, `m.Switch` and `m.FSM`. - PUSH = auto() - #: Used for `m.Elif` and `m.Else`. - ADD = auto() - #: Used for `m.Case`, `m.Default` and `m.State`. - ENTRY = auto() - - -@dataclass(frozen=True) -class PathEdge: - """Describes an edge in Amaranth's control tree. - - Attributes - ---------- - alt : int - Which alternative (e.g. case of `m.If` or m.Switch`) is described. - par : int - Which parallel control structure (e.g. `m.If` at the same level) is described. - """ - - alt: int = 0 - par: int = 0 - - -@dataclass -class CtrlPath: - """Describes a path in Amaranth's control tree. - - Attributes - ---------- - module : int - Unique number of the module the path refers to. - path : list[PathEdge] - Path in the control tree, starting from the root. - """ - - module: int - path: list[PathEdge] - - def exclusive_with(self, other: "CtrlPath"): - """Decides if this path is mutually exclusive with some other path. - - Paths are mutually exclusive if they refer to the same module and - diverge on different alternatives of the same control structure. - - Arguments - --------- - other : CtrlPath - The other path this path is compared to. 
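A sketch of the intended behaviour (assuming a `TModule` named `m` and a condition signal `cond`; `TModule` is defined later in this file): paths captured in the two branches of one conditional diverge on alternatives of the same control structure and are therefore exclusive.

    with m.If(cond):
        path_a = m.ctrl_path
    with m.Else():
        path_b = m.ctrl_path
    assert path_a.exclusive_with(path_b)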
-        """
-        common_prefix = []
-        for a, b in zip(self.path, other.path):
-            if a == b:
-                common_prefix.append(a)
-            elif a.par != b.par:
-                return False
-            else:
-                break
-
-        return (
-            self.module == other.module
-            and len(common_prefix) != len(self.path)
-            and len(common_prefix) != len(other.path)
-        )
-
-
-class CtrlPathBuilder:
-    """Constructs control paths.
-
-    Used internally by `TModule`."""
-
-    def __init__(self, module: int):
-        """
-        Parameters
-        ----------
-        module: int
-            Unique module identifier.
-        """
-        self.module = module
-        self.ctrl_path: list[PathEdge] = []
-        self.previous: Optional[PathEdge] = None
-
-    @contextmanager
-    def enter(self, enter_type=EnterType.PUSH):
-        et = EnterType
-
-        match enter_type:
-            case et.ADD:
-                assert self.previous is not None
-                self.ctrl_path.append(replace(self.previous, alt=self.previous.alt + 1))
-            case et.ENTRY:
-                self.ctrl_path[-1] = replace(self.ctrl_path[-1], alt=self.ctrl_path[-1].alt + 1)
-            case et.PUSH:
-                if self.previous is not None:
-                    self.ctrl_path.append(PathEdge(par=self.previous.par + 1))
-                else:
-                    self.ctrl_path.append(PathEdge())
-                self.previous = None
-        try:
-            yield
-        finally:
-            if enter_type in [et.PUSH, et.ADD]:
-                self.previous = self.ctrl_path.pop()
-
-    def build_ctrl_path(self):
-        """Returns the current control path."""
-        return CtrlPath(self.module, self.ctrl_path[:])
-
-
-class TModule(ModuleLike, Elaboratable):
-    """Extended Amaranth module for use with transactions.
-
-    It includes three different combinational domains:
-
-    * The `comb` domain works like the `comb` domain in plain Amaranth modules.
-      Statements in `comb` are guarded by every condition, including
-      `AvoidedIf`. This means they are guarded by transaction and method
-      bodies: they don't execute if the given transaction/method is not run.
-    * `av_comb` is guarded by all conditions except `AvoidedIf`. This means
-      they are not guarded by transaction and method bodies. This allows to
-      reduce the number of useless multiplexers due to transaction use, while
-      still allowing the use of conditions in transaction/method bodies.
-    * `top_comb` is unguarded: statements added to this domain always
-      execute. It can be used to reduce combinational path length due to
-      multiplexers while keeping related combinational and synchronous
-      statements together.
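A sketch of how the three domains differ inside a method body (`my_method`, `a`, `b`, `c` and the `data` field are hypothetical):

    m = TModule()
    with my_method.body(m) as arg:
        m.d.comb += a.eq(arg.data)      # guarded: executes only when my_method runs
        m.d.av_comb += b.eq(arg.data)   # not guarded by the method body
        m.d.top_comb += c.eq(arg.data)  # unguarded: always executes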
- """ - - __next_uid = 0 - - def __init__(self): - self.main_module = Module() - self.avoiding_module = Module() - self.top_module = Module() - self.d = _AvoidingModuleBuilderDomains(self) - self.submodules = self.main_module.submodules - self.domains = self.main_module.domains - self.fsm: Optional[FSM] = None - self.uid = TModule.__next_uid - self.path_builder = CtrlPathBuilder(self.uid) - TModule.__next_uid += 1 - - @contextmanager - def AvoidedIf(self, cond: ValueLike): # noqa: N802 - with self.main_module.If(cond): - with self.path_builder.enter(EnterType.PUSH): - yield - - @contextmanager - def If(self, cond: ValueLike): # noqa: N802 - with self.main_module.If(cond): - with self.avoiding_module.If(cond): - with self.path_builder.enter(EnterType.PUSH): - yield - - @contextmanager - def Elif(self, cond): # noqa: N802 - with self.main_module.Elif(cond): - with self.avoiding_module.Elif(cond): - with self.path_builder.enter(EnterType.ADD): - yield - - @contextmanager - def Else(self): # noqa: N802 - with self.main_module.Else(): - with self.avoiding_module.Else(): - with self.path_builder.enter(EnterType.ADD): - yield - - @contextmanager - def Switch(self, test: ValueLike): # noqa: N802 - with self.main_module.Switch(test): - with self.avoiding_module.Switch(test): - with self.path_builder.enter(EnterType.PUSH): - yield - - @contextmanager - def Case(self, *patterns: SwitchKey): # noqa: N802 - with self.main_module.Case(*patterns): - with self.avoiding_module.Case(*patterns): - with self.path_builder.enter(EnterType.ENTRY): - yield - - @contextmanager - def Default(self): # noqa: N802 - with self.main_module.Default(): - with self.avoiding_module.Default(): - with self.path_builder.enter(EnterType.ENTRY): - yield - - @contextmanager - def FSM(self, init: Optional[str] = None, domain: str = "sync", name: str = "fsm"): # noqa: N802 - old_fsm = self.fsm - with self.main_module.FSM(init, domain, name) as fsm: - self.fsm = fsm - with self.path_builder.enter(EnterType.PUSH): - yield fsm - self.fsm = old_fsm - - @contextmanager - def State(self, name: str): # noqa: N802 - assert self.fsm is not None - with self.main_module.State(name): - with self.avoiding_module.If(self.fsm.ongoing(name)): - with self.path_builder.enter(EnterType.ENTRY): - yield - - @property - def next(self) -> NoReturn: - raise NotImplementedError - - @next.setter - def next(self, name: str): - self.main_module.next = name - - @property - def ctrl_path(self): - return self.path_builder.build_ctrl_path() - - @property - def _MustUse__silence(self): # noqa: N802 - return self.main_module._MustUse__silence - - @_MustUse__silence.setter - def _MustUse__silence(self, value): # noqa: N802 - self.main_module._MustUse__silence = value # type: ignore - self.avoiding_module._MustUse__silence = value # type: ignore - self.top_module._MustUse__silence = value # type: ignore - - def elaborate(self, platform): - self.main_module.submodules._avoiding_module = self.avoiding_module - self.main_module.submodules._top_module = self.top_module - return self.main_module diff --git a/transactron/core/transaction.py b/transactron/core/transaction.py deleted file mode 100644 index c6f4176ab..000000000 --- a/transactron/core/transaction.py +++ /dev/null @@ -1,115 +0,0 @@ -from transactron.utils import * -from amaranth import * -from amaranth import tracer -from typing import Optional, Iterator, TYPE_CHECKING -from .transaction_base import * -from .keys import * -from contextlib import contextmanager - -if TYPE_CHECKING: - from .tmodule import TModule 
- from .manager import TransactionManager - -__all__ = ["Transaction"] - - -class Transaction(TransactionBase): - """Transaction. - - A `Transaction` represents a task which needs to be regularly done. - Execution of a `Transaction` always lasts a single clock cycle. - A `Transaction` signals readiness for execution by setting the - `request` signal. If the conditions for its execution are met, it - can be granted by the `TransactionManager`. - - A `Transaction` can, as part of its execution, call a number of - `Method`\\s. A `Transaction` can be granted only if every `Method` - it runs is ready. - - A `Transaction` cannot execute concurrently with another, conflicting - `Transaction`. Conflicts between `Transaction`\\s are either explicit - or implicit. An explicit conflict is added using the `add_conflict` - method. Implicit conflicts arise between pairs of `Transaction`\\s - which use the same `Method`. - - A module which defines a `Transaction` should use `body` to - describe used methods and the transaction's effect on the module state. - The used methods should be called inside the `body`'s - `with` block. - - Attributes - ---------- - name: str - Name of this `Transaction`. - request: Signal, in - Signals that the transaction wants to run. If omitted, the transaction - is always ready. Defined in the constructor. - runnable: Signal, out - Signals that all used methods are ready. - grant: Signal, out - Signals that the transaction is granted by the `TransactionManager`, - and all used methods are called. - """ - - def __init__( - self, *, name: Optional[str] = None, manager: Optional["TransactionManager"] = None, src_loc: int | SrcLoc = 0 - ): - """ - Parameters - ---------- - name: str or None - Name hint for this `Transaction`. If `None` (default) the name is - inferred from the variable name this `Transaction` is assigned to. - If the `Transaction` was not assigned, the name is inferred from - the class name where the `Transaction` was constructed. - manager: TransactionManager - The `TransactionManager` controlling this `Transaction`. - If omitted, the manager is received from `TransactionContext`. - src_loc: int | SrcLoc - How many stack frames deep the source location is taken from. - Alternatively, the source location to use instead of the default. - """ - super().__init__(src_loc=get_src_loc(src_loc)) - self.owner, owner_name = get_caller_class_name(default="$transaction") - self.name = name or tracer.get_var_name(depth=2, default=owner_name) - if manager is None: - manager = DependencyContext.get().get_dependency(TransactionManagerKey()) - manager.add_transaction(self) - self.request = Signal(name=self.owned_name + "_request") - self.runnable = Signal(name=self.owned_name + "_runnable") - self.grant = Signal(name=self.owned_name + "_grant") - - @contextmanager - def body(self, m: "TModule", *, request: ValueLike = C(1)) -> Iterator["Transaction"]: - """Defines the `Transaction` body. - - This context manager allows to conveniently define the actions - performed by a `Transaction` when it's granted. Each assignment - added to a domain under `body` is guarded by the `grant` signal. - Combinational assignments which do not need to be guarded by - `grant` can be added to `m.d.top_comb` or `m.d.av_comb` instead of - `m.d.comb`. `Method` calls can be performed under `body`. - - Parameters - ---------- - m: TModule - The module where the `Transaction` is defined. - request: Signal - Indicates that the `Transaction` wants to be executed. 
By
-            default it is `Const(1)`, so it wants to be executed in
-            every clock cycle.
-        """
-        if self.defined:
-            raise RuntimeError(f"Transaction '{self.name}' already defined")
-        self.def_order = next(TransactionBase.def_counter)
-
-        m.d.av_comb += self.request.eq(request)
-        with self.context(m):
-            with m.AvoidedIf(self.grant):
-                yield self
-
-    def __repr__(self) -> str:
-        return "(transaction {})".format(self.name)
-
-    def debug_signals(self) -> SignalBundle:
-        return [self.request, self.runnable, self.grant]
diff --git a/transactron/core/transaction_base.py b/transactron/core/transaction_base.py
deleted file mode 100644
index be4fe6f93..000000000
--- a/transactron/core/transaction_base.py
+++ /dev/null
@@ -1,209 +0,0 @@
-from collections import defaultdict
-from contextlib import contextmanager
-from enum import Enum, auto
-from itertools import count
-from typing import (
-    ClassVar,
-    TypeAlias,
-    TypedDict,
-    Union,
-    TypeVar,
-    Protocol,
-    Self,
-    runtime_checkable,
-    TYPE_CHECKING,
-    Iterator,
-)
-from amaranth import *
-
-from .tmodule import TModule, CtrlPath
-from transactron.graph import Owned
-from transactron.utils import *
-
-if TYPE_CHECKING:
-    from .method import Method
-    from .transaction import Transaction
-
-__all__ = ["TransactionBase", "Priority"]
-
-TransactionOrMethod: TypeAlias = Union["Transaction", "Method"]
-TransactionOrMethodBound = TypeVar("TransactionOrMethodBound", "Transaction", "Method")
-
-
-class Priority(Enum):
-    #: Conflicting transactions/methods don't have a priority order.
-    UNDEFINED = auto()
-    #: Left transaction/method is prioritized over the right one.
-    LEFT = auto()
-    #: Right transaction/method is prioritized over the left one.
-    RIGHT = auto()
-
-
-class RelationBase(TypedDict):
-    end: TransactionOrMethod
-    priority: Priority
-    conflict: bool
-    silence_warning: bool
-
-
-class Relation(RelationBase):
-    start: TransactionOrMethod
-
-
-@runtime_checkable
-class TransactionBase(Owned, Protocol):
-    stack: ClassVar[list[Union["Transaction", "Method"]]] = []
-    def_counter: ClassVar[count] = count()
-    def_order: int
-    defined: bool = False
-    name: str
-    src_loc: SrcLoc
-    method_uses: dict["Method", tuple[MethodStruct, Signal]]
-    method_calls: defaultdict["Method", list[tuple[CtrlPath, MethodStruct, ValueLike]]]
-    relations: list[RelationBase]
-    simultaneous_list: list[TransactionOrMethod]
-    independent_list: list[TransactionOrMethod]
-    ctrl_path: CtrlPath = CtrlPath(-1, [])
-
-    def __init__(self, *, src_loc: int | SrcLoc):
-        self.src_loc = get_src_loc(src_loc)
-        self.method_uses = {}
-        self.method_calls = defaultdict(list)
-        self.relations = []
-        self.simultaneous_list = []
-        self.independent_list = []
-
-    def add_conflict(self, end: TransactionOrMethod, priority: Priority = Priority.UNDEFINED) -> None:
-        """Registers a conflict.
-
-        Record that the given `Transaction` or `Method` cannot execute
-        simultaneously with this `Method` or `Transaction`. The typical
-        reason is the use of a common resource (register write or memory
-        port).
-
-        Parameters
-        ----------
-        end: Transaction or Method
-            The conflicting `Transaction` or `Method`.
-        priority: Priority, optional
-            Specifies whether one of the conflicting `Transaction`\\s or
-            `Method`\\s takes priority. Defaults to an undefined priority
-            relation.
-        """
-        self.relations.append(
-            RelationBase(end=end, priority=priority, conflict=True, silence_warning=self.owner != end.owner)
-        )
-
-    def schedule_before(self, end: TransactionOrMethod) -> None:
-        """Adds a priority relation.
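# A minimal usage sketch of the `Transaction` API deleted above (assumed
# import paths; `src` and `dst` are hypothetical `Method`s with matching
# layouts). The transaction is granted, and the data moves, in every cycle
# in which both methods are ready.
from amaranth import Elaboratable
from transactron import Method, TModule, Transaction

class Mover(Elaboratable):
    def __init__(self, src: Method, dst: Method):
        self.src = src
        self.dst = dst

    def elaborate(self, platform):
        m = TModule()
        with Transaction().body(m):
            data = self.src(m)  # call a method and take its result...
            self.dst(m, data)   # ...then pass it to another method
        return m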
-
-        Record that the given `Transaction` or `Method` needs to be
-        scheduled before this `Method` or `Transaction`, without adding
-        a conflict. The typical reason is data forwarding.
-
-        Parameters
-        ----------
-        end: Transaction or Method
-            The other `Transaction` or `Method`.
-        """
-        self.relations.append(
-            RelationBase(end=end, priority=Priority.LEFT, conflict=False, silence_warning=self.owner != end.owner)
-        )
-
-    def simultaneous(self, *others: TransactionOrMethod) -> None:
-        """Adds simultaneity relations.
-
-        The given `Transaction`\\s or `Method`\\s will execute simultaneously
-        (in the same clock cycle) with this `Transaction` or `Method`.
-
-        Parameters
-        ----------
-        *others: Transaction or Method
-            The `Transaction`\\s or `Method`\\s to be executed simultaneously.
-        """
-        self.simultaneous_list += others
-
-    def simultaneous_alternatives(self, *others: TransactionOrMethod) -> None:
-        """Adds exclusive simultaneity relations.
-
-        Each of the given `Transaction`\\s or `Method`\\s will execute
-        simultaneously (in the same clock cycle) with this `Transaction` or
-        `Method`. However, each of the given `Transaction`\\s or `Method`\\s
-        will be separately considered for execution.
-
-        Parameters
-        ----------
-        *others: Transaction or Method
-            The `Transaction`\\s or `Method`\\s to be executed simultaneously,
-            but mutually exclusive, with this `Transaction` or `Method`.
-        """
-        self.simultaneous(*others)
-        others[0]._independent(*others[1:])
-
-    def _independent(self, *others: TransactionOrMethod) -> None:
-        """Adds independence relations.
-
-        This `Transaction` or `Method`, together with all the given
-        `Transaction`\\s or `Method`\\s, will never be considered (pairwise)
-        for simultaneous execution.
-
-        Warning: this function is an implementation detail, do not use in
-        user code.
-
-        Parameters
-        ----------
-        *others: Transaction or Method
-            The `Transaction`\\s or `Method`\\s which, together with this
-            `Transaction` or `Method`, need to be independently considered
-            for execution.
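# A sketch of the relation API above; `queue` is a hypothetical module whose
# `clear` method must win over its `read` and `write` methods.
from transactron import Priority

queue.clear.add_conflict(queue.read, Priority.LEFT)
queue.clear.add_conflict(queue.write, Priority.LEFT)
# Ordering without mutual exclusion, e.g. for data forwarding:
queue.write.schedule_before(queue.read)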
- """ - self.independent_list += others - - @contextmanager - def context(self: TransactionOrMethodBound, m: TModule) -> Iterator[TransactionOrMethodBound]: - self.ctrl_path = m.ctrl_path - - parent = TransactionBase.peek() - if parent is not None: - parent.schedule_before(self) - - TransactionBase.stack.append(self) - - try: - yield self - finally: - TransactionBase.stack.pop() - self.defined = True - - def _set_method_uses(self, m: ModuleLike): - for method, calls in self.method_calls.items(): - arg_rec, enable_sig = self.method_uses[method] - if len(calls) == 1: - m.d.comb += arg_rec.eq(calls[0][1]) - m.d.comb += enable_sig.eq(calls[0][2]) - else: - call_ens = Cat([en for _, _, en in calls]) - - for i in OneHotSwitchDynamic(m, call_ens): - m.d.comb += arg_rec.eq(calls[i][1]) - m.d.comb += enable_sig.eq(1) - - @classmethod - def get(cls) -> Self: - ret = cls.peek() - if ret is None: - raise RuntimeError("No current body") - return ret - - @classmethod - def peek(cls) -> Optional[Self]: - if not TransactionBase.stack: - return None - if not isinstance(TransactionBase.stack[-1], cls): - raise RuntimeError(f"Current body not a {cls.__name__}") - return TransactionBase.stack[-1] - - @property - def owned_name(self): - if self.owner is not None and self.owner.__class__.__name__ != self.name: - return f"{self.owner.__class__.__name__}_{self.name}" - else: - return self.name diff --git a/transactron/graph.py b/transactron/graph.py deleted file mode 100644 index 709ba8724..000000000 --- a/transactron/graph.py +++ /dev/null @@ -1,249 +0,0 @@ -""" -Utilities for extracting dependency graphs from Amaranth designs. -""" - -from enum import IntFlag -from collections import defaultdict -from typing import Literal, Optional, Protocol - -from amaranth import Elaboratable, Fragment -from .tracing import TracingFragment - - -class Owned(Protocol): - name: str - owner: Optional[Elaboratable] - - -class Direction(IntFlag): - NONE = 0 - IN = 1 - OUT = 2 - INOUT = 3 - - -class OwnershipGraph: - mermaid_direction = ["---", "-->", "<--", "<-->"] - - def __init__(self, root): - self.class_counters: defaultdict[type, int] = defaultdict(int) - self.owned_counters: defaultdict[tuple[int, str], int] = defaultdict(int) - self.names: dict[int, str] = {} - self.owned_names: dict[int, str] = {} - self.hier: dict[int, str] = {} - self.labels: dict[int, str] = {} - self.graph: dict[int, list[int]] = {} - self.edges: list[tuple[Owned, Owned, Direction]] = [] - self.owned: defaultdict[int, set[Owned]] = defaultdict(set) - self.stray: set[int] = set() - self.remember(root) - - def remember(self, owner: Elaboratable) -> int: - while hasattr(owner, "_tracing_original"): - owner = owner._tracing_original # type: ignore - owner_id = id(owner) - if owner_id not in self.names: - tp = type(owner) - count = self.class_counters[tp] - self.class_counters[tp] = count + 1 - - name = tp.__name__ - if count: - name += str(count) - self.names[owner_id] = name - self.graph[owner_id] = [] - while True: - for field, obj in vars(owner).items(): - if isinstance(obj, Elaboratable) and not field.startswith("_"): - self.remember_field(owner_id, field, obj) - if isinstance(owner, Fragment): - assert isinstance(owner, TracingFragment) - for obj, field, _ in owner.subfragments: - self.remember_field(owner_id, field, obj) - try: - owner = owner._elaborated # type: ignore - except AttributeError: - break - return owner_id - - def remember_field(self, owner_id: int, field: str, obj: Elaboratable): - while hasattr(obj, "_tracing_original"): - obj = 
obj._tracing_original # type: ignore - obj_id = id(obj) - if obj_id == owner_id or obj_id in self.labels: - return - self.labels[obj_id] = f"{field} {obj.__class__.__name__}" - self.graph[owner_id].append(obj_id) - self.remember(obj) - - def insert_node(self, obj: Owned): - assert obj.owner is not None - owner_id = self.remember(obj.owner) - self.owned[owner_id].add(obj) - - def insert_edge(self, fr: Owned, to: Owned, direction: Direction = Direction.OUT): - self.edges.append((fr, to, direction)) - - def get_name(self, obj: Owned) -> str: - assert obj.owner is not None - obj_id = id(obj) - name = self.owned_names.get(obj_id) - if name is not None: - return name - owner_id = self.remember(obj.owner) - count = self.owned_counters[(owner_id, obj.name)] - self.owned_counters[(owner_id, obj.name)] = count + 1 - suffix = str(count) if count else "" - name = self.owned_names[obj_id] = f"{self.names[owner_id]}_{obj.name}{suffix}" - return name - - def get_hier_name(self, obj: Owned) -> str: - """ - Get hierarchical name. - Might raise KeyError if not yet hierarchized. - """ - name = self.get_name(obj) - owner_id = id(obj.owner) - hier = self.hier[owner_id] - return f"{hier}.{name}" - - def prune(self, owner: Optional[int] = None): - """ - Mark all empty subgraphs. - """ - if owner is None: - backup = self.graph.copy() - for owner in self.names: - if owner not in self.labels: - self.prune(owner) - self.graph = backup - return - - subowners = self.graph.pop(owner) - flag = bool(self.owned[owner]) - for subowner in subowners: - if subowner in self.graph: - flag |= self.prune(subowner) - - if not flag: - self.stray.add(owner) - - return flag - - def dump(self, fp, format: Literal["dot", "elk", "mermaid"]): - dumper = getattr(self, "dump_" + format) - dumper(fp) - - def dump_dot(self, fp, owner: Optional[int] = None, indent: str = ""): - if owner is None: - fp.write("digraph G {\n") - for owner in self.names: - if owner not in self.labels: - self.dump_dot(fp, owner, indent) - for fr, to, direction in self.edges: - if direction == Direction.OUT: - fr, to = to, fr - - caller_name = self.get_name(fr) - callee_name = self.get_name(to) - fp.write(f"{caller_name} -> {callee_name}\n") - fp.write("}\n") - return - - subowners = self.graph.pop(owner) - if owner in self.stray: - return - indent += " " - owned = self.owned[owner] - fp.write(f"{indent}subgraph cluster_{self.names[owner]} {{\n") - fp.write(f'{indent} label="{self.labels.get(owner, self.names[owner])}";\n') - for x in owned: - fp.write(f'{indent} {self.get_name(x)} [label="{x.name}"];\n') - for subowner in subowners: - if subowner in self.graph: - self.dump_dot(fp, subowner, indent) - fp.write(f"{indent}}}\n") - - def dump_elk(self, fp, owner: Optional[int] = None, indent: str = ""): - if owner is None: - fp.write(f"{indent}hierarchyHandling: INCLUDE_CHILDREN\n") - fp.write(f"{indent}elk.direction: DOWN\n") - for owner in self.names: - if owner not in self.labels: - self.dump_elk(fp, owner, indent) - return - - hier = self.hier.setdefault(owner, self.names[owner]) - - subowners = self.graph.pop(owner) - if owner in self.stray: - return - owned = self.owned[owner] - fp.write(f"{indent}node {self.names[owner]} {{\n") - fp.write(f"{indent} considerModelOrder.components: INSIDE_PORT_SIDE_GROUPS\n") - fp.write(f'{indent} nodeSize.constraints: "[PORTS, PORT_LABELS, MINIMUM_SIZE]"\n') - fp.write(f'{indent} nodeLabels.placement: "[H_LEFT, V_TOP, OUTSIDE]"\n') - fp.write(f'{indent} portLabels.placement: "[INSIDE]"\n') - fp.write(f"{indent} feedbackEdges: 
true\n") - fp.write(f'{indent} label "{self.labels.get(owner, self.names[owner])}"\n') - for x in owned: - if x.__class__.__name__ == "Method": - fp.write(f'{indent} port {self.get_name(x)} {{ label "{x.name}" }}\n') - else: - fp.write(f"{indent} node {self.get_name(x)} {{\n") - fp.write(f'{indent} nodeSize.constraints: "[NODE_LABELS, MINIMUM_SIZE]"\n') - fp.write(f'{indent} nodeLabels.placement: "[H_CENTER, V_CENTER, INSIDE]"\n') - fp.write(f'{indent} label "{x.name}"\n') - fp.write(f"{indent} }}\n") - for subowner in subowners: - if subowner in self.graph: - self.hier[subowner] = f"{hier}.{self.names[subowner]}" - self.dump_elk(fp, subowner, indent + " ") - - # reverse iteration so that deleting works - for i, (fr, to, direction) in reversed(list(enumerate(self.edges))): - if direction == Direction.OUT: - fr, to = to, fr - - try: - caller_name = self.get_hier_name(fr) - callee_name = self.get_hier_name(to) - except KeyError: - continue - - # only output edges belonging here - if caller_name[: len(hier)] == callee_name[: len(hier)] == hier: - caller_name = caller_name[len(hier) + 1 :] - callee_name = callee_name[len(hier) + 1 :] - del self.edges[i] - fp.write(f"{indent} edge {caller_name} -> {callee_name}\n") - - fp.write(f"{indent}}}\n") - - def dump_mermaid(self, fp, owner: Optional[int] = None, indent: str = ""): - if owner is None: - fp.write("flowchart TB\n") - for owner in self.names: - if owner not in self.labels: - self.dump_mermaid(fp, owner, indent) - for fr, to, direction in self.edges: - if direction == Direction.OUT: - fr, to, direction = to, fr, Direction.IN - - caller_name = self.get_name(fr) - callee_name = self.get_name(to) - fp.write(f"{caller_name} {self.mermaid_direction[direction]} {callee_name}\n") - return - - subowners = self.graph.pop(owner) - if owner in self.stray: - return - indent += " " - owned = self.owned[owner] - fp.write(f'{indent}subgraph {self.names[owner]}["{self.labels.get(owner, self.names[owner])}"]\n') - for x in owned: - fp.write(f'{indent} {self.get_name(x)}["{x.name}"]\n') - for subowner in subowners: - if subowner in self.graph: - self.dump_mermaid(fp, subowner, indent) - fp.write(f"{indent}end\n") diff --git a/transactron/lib/__init__.py b/transactron/lib/__init__.py deleted file mode 100644 index f6dd3ef0a..000000000 --- a/transactron/lib/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -from .fifo import * # noqa: F401 -from .connectors import * # noqa: F401 -from .buttons import * # noqa: F401 -from .adapters import * # noqa: F401 -from .transformers import * # noqa: F401 -from .reqres import * # noqa: F401 -from .storage import * # noqa: F401 -from .simultaneous import * # noqa: F401 -from .metrics import * # noqa: F401 diff --git a/transactron/lib/adapters.py b/transactron/lib/adapters.py deleted file mode 100644 index 81816b3c4..000000000 --- a/transactron/lib/adapters.py +++ /dev/null @@ -1,149 +0,0 @@ -from abc import abstractmethod -from typing import Optional -from amaranth import * -from amaranth.lib.wiring import Component, In, Out -from amaranth.lib.data import StructLayout, View - -from ..utils import SrcLoc, get_src_loc, MethodStruct -from ..core import * -from ..utils._typing import type_self_kwargs_as, SignalBundle - -__all__ = [ - "AdapterBase", - "AdapterTrans", - "Adapter", -] - - -class AdapterBase(Component): - data_in: MethodStruct - data_out: MethodStruct - en: Signal - done: Signal - - def __init__(self, iface: Method, layout_in: StructLayout, layout_out: StructLayout): - super().__init__({"data_in": In(layout_in), 
"data_out": Out(layout_out), "en": In(1), "done": Out(1)}) - self.iface = iface - - def debug_signals(self) -> SignalBundle: - return [self.en, self.done, self.data_in, self.data_out] - - @abstractmethod - def elaborate(self, platform) -> TModule: - raise NotImplementedError() - - -class AdapterTrans(AdapterBase): - """Adapter transaction. - - Creates a transaction controlled by plain Amaranth signals. Allows to - expose a method to plain Amaranth code, including testbenches. - - Attributes - ---------- - en: Signal, in - Activates the transaction (sets the `request` signal). - done: Signal, out - Signals that the transaction is performed (returns the `grant` - signal). - data_in: View, in - Data passed to the `iface` method. - data_out: View, out - Data returned from the `iface` method. - """ - - def __init__(self, iface: Method, *, src_loc: int | SrcLoc = 0): - """ - Parameters - ---------- - iface: Method - The method to be called by the transaction. - src_loc: int | SrcLoc - How many stack frames deep the source location is taken from. - Alternatively, the source location to use instead of the default. - """ - super().__init__(iface, iface.layout_in, iface.layout_out) - self.src_loc = get_src_loc(src_loc) - - def elaborate(self, platform): - m = TModule() - - # this forces data_in signal to appear in VCD dumps - data_in = Signal.like(self.data_in) - m.d.comb += data_in.eq(self.data_in) - - with Transaction(name=f"AdapterTrans_{self.iface.name}", src_loc=self.src_loc).body(m, request=self.en): - data_out = self.iface(m, data_in) - m.d.top_comb += self.data_out.eq(data_out) - m.d.comb += self.done.eq(1) - - return m - - -class Adapter(AdapterBase): - """Adapter method. - - Creates a method controlled by plain Amaranth signals. One of the - possible uses is to mock a method in a testbench. - - Attributes - ---------- - en: Signal, in - Activates the method (sets the `ready` signal). - done: Signal, out - Signals that the method is called (returns the `run` signal). - data_in: View, in - Data returned from the defined method. - data_out: View, out - Data passed as argument to the defined method. - validators: list of tuples of View, out and Signal, in - Hooks for `validate_arguments`. - """ - - @type_self_kwargs_as(Method.__init__) - def __init__(self, **kwargs): - """ - Parameters - ---------- - **kwargs - Keyword arguments for Method that will be created. - See transactron.core.Method.__init__ for parameters description. 
-        """
-
-        kwargs["src_loc"] = get_src_loc(kwargs.setdefault("src_loc", 0))
-
-        iface = Method(**kwargs)
-        super().__init__(iface, iface.layout_out, iface.layout_in)
-        self.validators: list[tuple[View[StructLayout], Signal]] = []
-        self.with_validate_arguments: bool = False
-
-    def set(self, with_validate_arguments: Optional[bool]):
-        if with_validate_arguments is not None:
-            self.with_validate_arguments = with_validate_arguments
-        return self
-
-    def elaborate(self, platform):
-        m = TModule()
-
-        # this forces data_in signal to appear in VCD dumps
-        data_in = Signal.like(self.data_in)
-        m.d.comb += data_in.eq(self.data_in)
-
-        kwargs = {}
-
-        if self.with_validate_arguments:
-
-            def validate_arguments(arg: "View[StructLayout]"):
-                ret = Signal()
-                self.validators.append((arg, ret))
-                return ret
-
-            kwargs["validate_arguments"] = validate_arguments
-
-        @def_method(m, self.iface, ready=self.en, **kwargs)
-        def _(arg):
-            m.d.top_comb += self.data_out.eq(arg)
-            m.d.comb += self.done.eq(1)
-            return data_in
-
-        return m
diff --git a/transactron/lib/buttons.py b/transactron/lib/buttons.py
deleted file mode 100644
index d275cd25d..000000000
--- a/transactron/lib/buttons.py
+++ /dev/null
@@ -1,113 +0,0 @@
-from amaranth import *
-
-from transactron.utils.transactron_helpers import from_method_layout
-from ..core import *
-from ..utils import SrcLoc, get_src_loc, MethodLayout
-
-__all__ = ["ClickIn", "ClickOut"]
-
-
-class ClickIn(Elaboratable):
-    """Clicked input.
-
-    Useful for interactive simulations or FPGA button/switch interfaces.
-    On a rising edge (tested synchronously) of `btn`, the `get` method
-    is enabled, which returns the data present on `dat` at the time.
-    Inputs are synchronized.
-
-    Attributes
-    ----------
-    get: Method
-        The method for retrieving data from the input. Accepts an empty
-        argument, returns a structure.
-    btn: Signal, in
-        The button input.
-    dat: MethodStruct, in
-        The data input.
-    """
-
-    def __init__(self, layout: MethodLayout, src_loc: int | SrcLoc = 0):
-        """
-        Parameters
-        ----------
-        layout: method layout
-            The data format for the input.
-        src_loc: int | SrcLoc
-            How many stack frames deep the source location is taken from.
-            Alternatively, the source location to use instead of the default.
-        """
-        src_loc = get_src_loc(src_loc)
-        self.get = Method(o=layout, src_loc=src_loc)
-        self.btn = Signal()
-        self.dat = Signal(from_method_layout(layout))
-
-    def elaborate(self, platform):
-        m = TModule()
-
-        btn1 = Signal()
-        btn2 = Signal()
-        dat1 = Signal.like(self.dat)
-        m.d.sync += btn1.eq(self.btn)
-        m.d.sync += btn2.eq(btn1)
-        m.d.sync += dat1.eq(self.dat)
-        get_ready = Signal()
-        get_data = Signal.like(self.dat)
-
-        @def_method(m, self.get, ready=get_ready)
-        def _():
-            m.d.sync += get_ready.eq(0)
-            return get_data
-
-        with m.If(~btn2 & btn1):
-            m.d.sync += get_ready.eq(1)
-            m.d.sync += get_data.eq(dat1)
-
-        return m
-
-
-class ClickOut(Elaboratable):
-    """Clicked output.
-
-    Useful for interactive simulations or FPGA button/LED interfaces.
-    On a rising edge (tested synchronously) of `btn`, the `put` method
-    is enabled, which, when called, changes the value of the `dat` signal.
-
-    Attributes
-    ----------
-    put: Method
-        The method for sending data to the output. Accepts a structure,
-        returns empty result.
-    btn: Signal, in
-        The button input.
-    dat: MethodStruct, out
-        The data output.
-    """
-
-    def __init__(self, layout: MethodLayout, *, src_loc: int | SrcLoc = 0):
-        """
-        Parameters
-        ----------
-        layout: method layout
-            The data format for the output.
- src_loc: int | SrcLoc - How many stack frames deep the source location is taken from. - Alternatively, the source location to use instead of the default. - """ - src_loc = get_src_loc(src_loc) - self.put = Method(i=layout, src_loc=src_loc) - self.btn = Signal() - self.dat = Signal(from_method_layout(layout)) - - def elaborate(self, platform): - m = TModule() - - btn1 = Signal() - btn2 = Signal() - m.d.sync += btn1.eq(self.btn) - m.d.sync += btn2.eq(btn1) - - @def_method(m, self.put, ready=~btn2 & btn1) - def _(arg): - m.d.sync += self.dat.eq(arg) - - return m diff --git a/transactron/lib/connectors.py b/transactron/lib/connectors.py deleted file mode 100644 index 723660ff9..000000000 --- a/transactron/lib/connectors.py +++ /dev/null @@ -1,424 +0,0 @@ -from amaranth import * -from amaranth.lib.data import View -import amaranth.lib.fifo - -from transactron.utils.transactron_helpers import from_method_layout -from ..core import * -from ..utils import SrcLoc, get_src_loc, MethodLayout - -__all__ = [ - "FIFO", - "Forwarder", - "Connect", - "ConnectTrans", - "ManyToOneConnectTrans", - "StableSelectingNetwork", - "Pipe", -] - - -class FIFO(Elaboratable): - """FIFO module. - - Provides a transactional interface to Amaranth FIFOs. Exposes two methods: - `read`, and `write`. Both methods are ready only when they can - be executed -- i.e. the queue is respectively not empty / not full. - It is possible to simultaneously read and write in a single clock cycle, - but only if both readiness conditions are fulfilled. - - Attributes - ---------- - read: Method - The read method. Accepts an empty argument, returns a structure. - write: Method - The write method. Accepts a structure, returns empty result. - """ - - def __init__( - self, layout: MethodLayout, depth: int, fifo_type=amaranth.lib.fifo.SyncFIFO, *, src_loc: int | SrcLoc = 0 - ): - """ - Parameters - ---------- - layout: method layout - The format of structures stored in the FIFO. - depth: int - Size of the FIFO. - fifoType: Elaboratable - FIFO module conforming to Amaranth library FIFO interface. Defaults - to SyncFIFO. - src_loc: int | SrcLoc - How many stack frames deep the source location is taken from. - Alternatively, the source location to use instead of the default. - """ - self.layout = from_method_layout(layout) - self.width = self.layout.size - self.depth = depth - self.fifoType = fifo_type - - src_loc = get_src_loc(src_loc) - self.read = Method(o=layout, src_loc=src_loc) - self.write = Method(i=layout, src_loc=src_loc) - - def elaborate(self, platform): - m = TModule() - - m.submodules.fifo = fifo = self.fifoType(width=self.width, depth=self.depth) - - @def_method(m, self.write, ready=fifo.w_rdy) - def _(arg): - m.d.comb += fifo.w_en.eq(1) - m.d.top_comb += fifo.w_data.eq(arg) - - @def_method(m, self.read, ready=fifo.r_rdy) - def _(): - m.d.comb += fifo.r_en.eq(1) - return View(self.layout, fifo.r_data) # remove View after Amaranth upgrade - - return m - - -# Forwarding with overflow buffering - - -class Forwarder(Elaboratable): - """Forwarding with overflow buffering - - Provides a means to connect two transactions with forwarding. Exposes - two methods: `read`, and `write`. When both of these methods are - executed simultaneously, data is forwarded between them. If `write` - is executed, but `read` is not, the value cannot be forwarded, - but is stored into an overflow buffer. No further `write`\\s are - possible until the overflow buffer is cleared by `read`. - - The `write` method is scheduled before `read`. 
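# A sketch of same-cycle forwarding with the Forwarder described above;
# `layout`, `produce` and `consume` are assumed to exist in the design.
m.submodules.fwd = fwd = Forwarder(layout)

with Transaction().body(m):
    fwd.write(m, produce(m))  # forwarded directly if `read` also runs
with Transaction().body(m):
    consume(m, fwd.read(m))   # otherwise served from the overflow buffer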
-
-    Attributes
-    ----------
-    read: Method
-        The read method. Accepts an empty argument, returns a structure.
-    write: Method
-        The write method. Accepts a structure, returns empty result.
-    """
-
-    def __init__(self, layout: MethodLayout, *, src_loc: int | SrcLoc = 0):
-        """
-        Parameters
-        ----------
-        layout: method layout
-            The format of structures forwarded.
-        src_loc: int | SrcLoc
-            How many stack frames deep the source location is taken from.
-            Alternatively, the source location to use instead of the default.
-        """
-        src_loc = get_src_loc(src_loc)
-        self.read = Method(o=layout, src_loc=src_loc)
-        self.write = Method(i=layout, src_loc=src_loc)
-        self.clear = Method(src_loc=src_loc)
-        self.head = Signal.like(self.read.data_out)
-
-        self.clear.add_conflict(self.read, Priority.LEFT)
-        self.clear.add_conflict(self.write, Priority.LEFT)
-
-    def elaborate(self, platform):
-        m = TModule()
-
-        reg = Signal.like(self.read.data_out)
-        reg_valid = Signal()
-        read_value = Signal.like(self.read.data_out)
-        m.d.comb += self.head.eq(read_value)
-
-        self.write.schedule_before(self.read)  # to avoid combinational loops
-
-        @def_method(m, self.write, ready=~reg_valid)
-        def _(arg):
-            m.d.av_comb += read_value.eq(arg)  # for forwarding
-            m.d.sync += reg.eq(arg)
-            m.d.sync += reg_valid.eq(1)
-
-        with m.If(reg_valid):
-            m.d.av_comb += read_value.eq(reg)  # write method is not ready
-
-        @def_method(m, self.read, ready=reg_valid | self.write.run)
-        def _():
-            m.d.sync += reg_valid.eq(0)
-            return read_value
-
-        @def_method(m, self.clear)
-        def _():
-            m.d.sync += reg_valid.eq(0)
-
-        return m
-
-
-class Pipe(Elaboratable):
-    """
-    This module implements a `Pipe`. It is halfway between
-    `Forwarder` and `2-FIFO`. In the `Pipe` data is always
-    stored locally, so the critical path of the data is cut, but there is a
-    combinational path between the control signals of the `read` and
-    the `write` methods. For comparison:
-    - in `Forwarder` there is both a data and a control combinational path
-    - in `2-FIFO` there are no combinational paths
-
-    The `read` method is scheduled before the `write`.
-
-    Attributes
-    ----------
-    read: Method
-        Reads from the pipe. Accepts an empty argument, returns a structure.
-        Ready only if the pipe is not empty.
-    write: Method
-        Writes to the pipe. Accepts a structure, returns empty result.
-        Ready only if the pipe is not full.
-    clean: Method
-        Cleans the pipe. Has priority over `read` and `write` methods.
-    """
-
-    def __init__(self, layout: MethodLayout):
-        """
-        Parameters
-        ----------
-        layout: method layout
-            The format of structures forwarded.
-        """
-        self.read = Method(o=layout)
-        self.write = Method(i=layout)
-        self.clean = Method()
-        self.head = Signal.like(self.read.data_out)
-
-        self.clean.add_conflict(self.read, Priority.LEFT)
-        self.clean.add_conflict(self.write, Priority.LEFT)
-
-    def elaborate(self, platform):
-        m = TModule()
-
-        reg = Signal.like(self.read.data_out)
-        reg_valid = Signal()
-
-        self.read.schedule_before(self.write)  # to avoid combinational loops
-
-        @def_method(m, self.read, ready=reg_valid)
-        def _():
-            m.d.sync += reg_valid.eq(0)
-            return reg
-
-        @def_method(m, self.write, ready=~reg_valid | self.read.run)
-        def _(arg):
-            m.d.sync += reg.eq(arg)
-            m.d.sync += reg_valid.eq(1)
-
-        @def_method(m, self.clean)
-        def _():
-            m.d.sync += reg_valid.eq(0)
-
-        return m
-
-
-class Connect(Elaboratable):
-    """Forwarding by transaction simultaneity
-
-    Provides a means to connect two transactions with forwarding
-    using the transaction simultaneity mechanism.
It provides - two methods: `read`, and `write`, which always execute simultaneously. - Typical use case is for moving data from `write` to `read`, but - data flow in the reverse direction is also possible. - - Attributes - ---------- - read: Method - The read method. Accepts a (possibly empty) structure, returns - a structure. - write: Method - The write method. Accepts a structure, returns a (possibly empty) - structure. - """ - - def __init__(self, layout: MethodLayout = (), rev_layout: MethodLayout = (), *, src_loc: int | SrcLoc = 0): - """ - Parameters - ---------- - layout: method layout - The format of structures forwarded. - rev_layout: method layout - The format of structures forwarded in the reverse direction. - src_loc: int | SrcLoc - How many stack frames deep the source location is taken from. - Alternatively, the source location to use instead of the default. - """ - src_loc = get_src_loc(src_loc) - self.read = Method(o=layout, i=rev_layout, src_loc=src_loc) - self.write = Method(i=layout, o=rev_layout, src_loc=src_loc) - - def elaborate(self, platform): - m = TModule() - - read_value = Signal.like(self.read.data_out) - rev_read_value = Signal.like(self.write.data_out) - - self.write.simultaneous(self.read) - - @def_method(m, self.write) - def _(arg): - m.d.av_comb += read_value.eq(arg) - return rev_read_value - - @def_method(m, self.read) - def _(arg): - m.d.av_comb += rev_read_value.eq(arg) - return read_value - - return m - - -class ConnectTrans(Elaboratable): - """Simple connecting transaction. - - Takes two methods and creates a transaction which calls both of them. - Result of the first method is connected to the argument of the second, - and vice versa. Allows easily connecting methods with compatible - layouts. - """ - - def __init__(self, method1: Method, method2: Method, *, src_loc: int | SrcLoc = 0): - """ - Parameters - ---------- - method1: Method - First method. - method2: Method - Second method. - src_loc: int | SrcLoc - How many stack frames deep the source location is taken from. - Alternatively, the source location to use instead of the default. - """ - self.method1 = method1 - self.method2 = method2 - self.src_loc = get_src_loc(src_loc) - - def elaborate(self, platform): - m = TModule() - - with Transaction(src_loc=self.src_loc).body(m): - data1 = Signal.like(self.method1.data_out) - data2 = Signal.like(self.method2.data_out) - - m.d.top_comb += data1.eq(self.method1(m, data2)) - m.d.top_comb += data2.eq(self.method2(m, data1)) - - return m - - -class ManyToOneConnectTrans(Elaboratable): - """Many-to-one method connection. - - Connects each of a set of methods to another method using separate - transactions. Equivalent to a set of `ConnectTrans`. - """ - - def __init__(self, *, get_results: list[Method], put_result: Method, src_loc: int | SrcLoc = 0): - """ - Parameters - ---------- - get_results: list[Method] - Methods to be connected to the `put_result` method. - put_result: Method - Common method for each of the connections created. - src_loc: int | SrcLoc - How many stack frames deep the source location is taken from. - Alternatively, the source location to use instead of the default. 
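# A sketch of the connectors above; `fifo_a`, `fifo_b`, `funits` and
# `result_sink` are assumed modules and methods with matching layouts.
m.submodules.link = ConnectTrans(fifo_a.read, fifo_b.write)

# Funneling many producers into a single sink, one transaction each:
m.submodules.funnel = ManyToOneConnectTrans(
    get_results=[fu.get_result for fu in funits], put_result=result_sink
)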
- """ - self.get_results = get_results - self.m_put_result = put_result - - self.count = len(self.get_results) - self.src_loc = get_src_loc(src_loc) - - def elaborate(self, platform): - m = TModule() - - for i in range(self.count): - m.submodules[f"ManyToOneConnectTrans_input_{i}"] = ConnectTrans( - self.m_put_result, self.get_results[i], src_loc=self.src_loc - ) - - return m - - -class StableSelectingNetwork(Elaboratable): - """A network that groups inputs with a valid bit set. - - The circuit takes `n` inputs with a valid signal each and - on the output returns a grouped and consecutive sequence of the provided - input signals. The order of valid inputs is preserved. - - For example for input (0 is an invalid input): - 0, a, 0, d, 0, 0, e - - The circuit will return: - a, d, e, 0, 0, 0, 0 - - The circuit uses a divide and conquer algorithm. - The recursive call takes two bit vectors and each of them - is already properly sorted, for example: - v1 = [a, b, 0, 0]; v2 = [c, d, e, 0] - - Now by shifting left v2 and merging it with v1, we get the result: - v = [a, b, c, d, e, 0, 0, 0] - - Thus, the network has depth log_2(n). - - """ - - def __init__(self, n: int, layout: MethodLayout): - self.n = n - self.layout = from_method_layout(layout) - - self.inputs = [Signal(self.layout) for _ in range(n)] - self.valids = [Signal() for _ in range(n)] - - self.outputs = [Signal(self.layout) for _ in range(n)] - self.output_cnt = Signal(range(n + 1)) - - def elaborate(self, platform): - m = TModule() - - current_level = [] - for i in range(self.n): - current_level.append((Array([self.inputs[i]]), self.valids[i])) - - # Create the network using the bottom-up approach. - while len(current_level) >= 2: - next_level = [] - while len(current_level) >= 2: - a, cnt_a = current_level.pop(0) - b, cnt_b = current_level.pop(0) - - total_cnt = Signal(max(len(cnt_a), len(cnt_b)) + 1) - m.d.comb += total_cnt.eq(cnt_a + cnt_b) - - total_len = len(a) + len(b) - merged = Array(Signal(self.layout) for _ in range(total_len)) - - for i in range(len(a)): - m.d.comb += merged[i].eq(Mux(cnt_a <= i, b[i - cnt_a], a[i])) - for i in range(len(b)): - m.d.comb += merged[len(a) + i].eq(Mux(len(a) + i - cnt_a >= len(b), 0, b[len(a) + i - cnt_a])) - - next_level.append((merged, total_cnt)) - - # If we had an odd number of elements on the current level, - # move the item left to the next level. - if len(current_level) == 1: - next_level.append(current_level.pop(0)) - - current_level = next_level - - last_level, total_cnt = current_level.pop(0) - - for i in range(self.n): - m.d.comb += self.outputs[i].eq(last_level[i]) - - m.d.comb += self.output_cnt.eq(total_cnt) - - return m diff --git a/transactron/lib/dependencies.py b/transactron/lib/dependencies.py deleted file mode 100644 index c7b099b76..000000000 --- a/transactron/lib/dependencies.py +++ /dev/null @@ -1,34 +0,0 @@ -from collections.abc import Callable - -from .. import Method -from .transformers import Unifier -from ..utils.dependencies import * - - -__all__ = ["DependencyManager", "DependencyKey", "SimpleKey", "ListKey", "UnifierKey"] - - -class UnifierKey(DependencyKey["Method", tuple["Method", dict[str, "Unifier"]]]): - """Base class for method unifier dependency keys. - - Method unifier dependency keys are used to collect methods to be called by - some part of the core. As multiple modules may wish to be called, a method - unifier is used to present a single method interface to the caller, which - allows to customize the calling behavior. 
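# A sketch of declaring and using a unifier key; `Collector` is assumed to
# be one of the `Unifier` implementations from transactron.lib.transformers.
from dataclasses import dataclass
from transactron.lib.dependencies import UnifierKey
from transactron.lib.transformers import Collector
from transactron.utils.dependencies import DependencyContext

@dataclass(frozen=True)
class ResultCollectKey(UnifierKey, unifier=Collector):
    pass

# Providers register methods under the key; the consumer then receives one
# unified method together with unifier submodules to add to the design:
method, unifiers = DependencyContext.get().get_dependency(ResultCollectKey())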
-    """
-
-    unifier: Callable[[list["Method"]], "Unifier"]
-
-    def __init_subclass__(cls, unifier: Callable[[list["Method"]], "Unifier"], **kwargs) -> None:
-        cls.unifier = unifier
-        return super().__init_subclass__(**kwargs)
-
-    def combine(self, data: list["Method"]) -> tuple["Method", dict[str, "Unifier"]]:
-        if len(data) == 1:
-            return data[0], {}
-        else:
-            unifiers: dict[str, Unifier] = {}
-            unifier_inst = self.unifier(data)
-            unifiers[self.__class__.__name__ + "_unifier"] = unifier_inst
-            method = unifier_inst.method
-            return method, unifiers
diff --git a/transactron/lib/fifo.py b/transactron/lib/fifo.py
deleted file mode 100644
index f9d43c30f..000000000
--- a/transactron/lib/fifo.py
+++ /dev/null
@@ -1,165 +0,0 @@
-from amaranth import *
-import amaranth.lib.memory as memory
-from transactron import Method, def_method, Priority, TModule
-from transactron.utils._typing import ValueLike, MethodLayout, SrcLoc, MethodStruct
-from transactron.utils.amaranth_ext import mod_incr
-from transactron.utils.transactron_helpers import from_method_layout, get_src_loc
-
-
-class BasicFifo(Elaboratable):
-    """Transactional FIFO queue
-
-    Attributes
-    ----------
-    read: Method
-        Reads from the FIFO. Accepts an empty argument, returns a structure.
-        Ready only if the FIFO is not empty.
-    peek: Method
-        Returns the element at the front without deleting it. Ready only if
-        the FIFO is not empty. The method is nonexclusive.
-    write: Method
-        Writes to the FIFO. Accepts a structure, returns empty result.
-        Ready only if the FIFO is not full.
-    clear: Method
-        Clears the FIFO entries. Has priority over `read` and `write` methods.
-        Note that clearing the FIFO doesn't reinitialize it to the values passed
-        in the `init` parameter.
-
-    """
-
-    def __init__(self, layout: MethodLayout, depth: int, *, src_loc: int | SrcLoc = 0) -> None:
-        """
-        Parameters
-        ----------
-        layout: method layout
-            Layout of data stored in the FIFO.
-        depth: int
-            Size of the FIFO.
-        src_loc: int | SrcLoc
-            How many stack frames deep the source location is taken from.
-            Alternatively, the source location to use instead of the default.
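# A usage sketch for the BasicFifo defined above; `layout`, `produce` and
# `monitor` are assumed to exist in the surrounding design.
m.submodules.fifo = fifo = BasicFifo(layout, depth=8)

with Transaction().body(m):
    fifo.write(m, produce(m))  # ready while the FIFO is not full
with Transaction().body(m):
    monitor(m, fifo.peek(m))   # nonexclusive: may run alongside `read`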
-        """
-        self.layout = layout
-        self.width = from_method_layout(self.layout).size
-        self.depth = depth
-
-        src_loc = get_src_loc(src_loc)
-        self.read = Method(o=self.layout, src_loc=src_loc)
-        self.peek = Method(o=self.layout, nonexclusive=True, src_loc=src_loc)
-        self.write = Method(i=self.layout, src_loc=src_loc)
-        self.clear = Method(src_loc=src_loc)
-        self.head = Signal(from_method_layout(layout))
-
-        self.buff = memory.Memory(shape=self.width, depth=self.depth, init=[])
-
-        self.write_ready = Signal()
-        self.read_ready = Signal()
-
-        self.read_idx = Signal((self.depth - 1).bit_length())
-        self.write_idx = Signal((self.depth - 1).bit_length())
-        # current fifo depth
-        self.level = Signal((self.depth).bit_length())
-
-        # for interface compatibility with MultiportFifo
-        self.read_methods = [self.read]
-        self.write_methods = [self.write]
-
-    def elaborate(self, platform):
-        m = TModule()
-
-        next_read_idx = Signal.like(self.read_idx)
-        m.d.comb += next_read_idx.eq(mod_incr(self.read_idx, self.depth))
-
-        m.submodules.buff = self.buff
-        self.buff_wrport = self.buff.write_port()
-        self.buff_rdport = self.buff.read_port(domain="sync", transparent_for=[self.buff_wrport])
-
-        m.d.comb += self.read_ready.eq(self.level != 0)
-        m.d.comb += self.write_ready.eq(self.level != self.depth)
-
-        with m.If(self.read.run & ~self.write.run):
-            m.d.sync += self.level.eq(self.level - 1)
-        with m.If(self.write.run & ~self.read.run):
-            m.d.sync += self.level.eq(self.level + 1)
-        with m.If(self.clear.run):
-            m.d.sync += self.level.eq(0)
-
-        m.d.comb += self.buff_rdport.addr.eq(Mux(self.read.run, next_read_idx, self.read_idx))
-        m.d.comb += self.head.eq(self.buff_rdport.data)
-
-        @def_method(m, self.write, ready=self.write_ready)
-        def _(arg: MethodStruct) -> None:
-            m.d.top_comb += self.buff_wrport.addr.eq(self.write_idx)
-            m.d.top_comb += self.buff_wrport.data.eq(arg)
-            m.d.comb += self.buff_wrport.en.eq(1)
-
-            m.d.sync += self.write_idx.eq(mod_incr(self.write_idx, self.depth))
-
-        @def_method(m, self.read, self.read_ready)
-        def _() -> ValueLike:
-            m.d.sync += self.read_idx.eq(next_read_idx)
-            return self.head
-
-        @def_method(m, self.peek, self.read_ready)
-        def _() -> ValueLike:
-            return self.head
-
-        @def_method(m, self.clear)
-        def _() -> None:
-            m.d.sync += self.read_idx.eq(0)
-            m.d.sync += self.write_idx.eq(0)
-
-        return m
-
-
-class Semaphore(Elaboratable):
-    """Semaphore"""

-    def __init__(self, max_count: int) -> None:
-        """
-        Parameters
-        ----------
-        max_count: int
-            The maximum value of the semaphore counter.
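# A sketch of bounding in-flight operations with the Semaphore defined
# above; `issue` and `complete` are hypothetical methods of the design.
m.submodules.sem = sem = Semaphore(max_count=4)

with Transaction().body(m):
    sem.acquire(m)   # ready only while count < max_count
    issue(m)
with Transaction().body(m):
    complete(m)
    sem.release(m)   # ready only while count > 0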
-
-        """
-        self.max_count = max_count
-
-        self.acquire = Method()
-        self.release = Method()
-        self.clear = Method()
-
-        self.acquire_ready = Signal()
-        self.release_ready = Signal()
-
-        self.count = Signal(self.max_count.bit_length())
-        self.count_next = Signal(self.max_count.bit_length())
-
-        self.clear.add_conflict(self.acquire, Priority.LEFT)
-        self.clear.add_conflict(self.release, Priority.LEFT)
-
-    def elaborate(self, platform) -> TModule:
-        m = TModule()
-
-        m.d.comb += self.release_ready.eq(self.count > 0)
-        m.d.comb += self.acquire_ready.eq(self.count < self.max_count)
-
-        with m.If(self.clear.run):
-            m.d.comb += self.count_next.eq(0)
-        with m.Else():
-            m.d.comb += self.count_next.eq(self.count + self.acquire.run - self.release.run)
-
-        m.d.sync += self.count.eq(self.count_next)
-
-        @def_method(m, self.acquire, ready=self.acquire_ready)
-        def _() -> None:
-            pass
-
-        @def_method(m, self.release, ready=self.release_ready)
-        def _() -> None:
-            pass
-
-        @def_method(m, self.clear)
-        def _() -> None:
-            pass
-
-        return m
diff --git a/transactron/lib/logging.py b/transactron/lib/logging.py
deleted file mode 100644
index 7eb06deb1..000000000
--- a/transactron/lib/logging.py
+++ /dev/null
@@ -1,229 +0,0 @@
-import os
-import re
-import operator
-import logging
-from functools import reduce
-from dataclasses import dataclass, field
-from dataclasses_json import dataclass_json
-from typing import TypeAlias
-
-from amaranth import *
-from amaranth.tracer import get_src_loc
-
-from transactron.utils import SrcLoc
-from transactron.utils._typing import ModuleLike, ValueLike
-from transactron.utils.dependencies import DependencyContext, ListKey
-
-LogLevel: TypeAlias = int
-
-
-@dataclass_json
-@dataclass
-class LogRecordInfo:
-    """Simulator-backend-agnostic information about a log record that can
-    be serialized and used outside the Amaranth context.
-
-    Attributes
-    ----------
-    logger_name: str
-        Name of the logger which produced the record.
-    level: LogLevel
-        The severity level of the log.
-    format_str: str
-        The template of the message. Should follow the PEP 3101 standard.
-    location: SrcLoc
-        Source location of the log.
-    """
-
-    logger_name: str
-    level: LogLevel
-    format_str: str
-    location: SrcLoc
-
-    def format(self, *args) -> str:
-        """Format the log message with a set of concrete arguments."""
-
-        return self.format_str.format(*args)
-
-
-@dataclass
-class LogRecord(LogRecordInfo):
-    """A LogRecord instance represents an event being logged.
-
-    Attributes
-    ----------
-    trigger: Signal
-        Amaranth signal triggering the log.
-    fields: list[Signal]
-        Amaranth signals that will be used to format the message.
-    """
-
-    trigger: Signal
-    fields: list[Signal] = field(default_factory=list)
-
-
-@dataclass(frozen=True)
-class LogKey(ListKey[LogRecord]):
-    pass
-
-
-class HardwareLogger:
-    """A class for creating log messages in the hardware.
-
-    Intuitively, the hardware logger works similarly to a normal software
-    logger. You can log a message anywhere in the circuit, but due to the
-    parallel nature of the hardware you must specify a special trigger signal
-    which will indicate if a message shall be reported in that cycle.
-
-    Hardware logs are evaluated and printed during simulation, so both
-    the trigger and the format fields are Amaranth values, i.e.
-    signals or arbitrary Amaranth expressions.
-
-    Instances of the HardwareLogger class represent a logger for a single
-    submodule of the circuit. Exactly how a "submodule" is defined is up
-    to the developer. Submodules are identified by a unique string and
-    the names can be nested.
Names are organized into a namespace hierarchy
-    where levels are separated by periods, much like the Python package
-    namespace. So, for instance, submodule names might be "frontend"
-    for the upper level, and "frontend.icache" and "frontend.bpu" for
-    the sub-levels. There is no arbitrary limit to the depth of nesting.
-
-    Attributes
-    ----------
-    name: str
-        Name of this logger.
-    """
-
-    def __init__(self, name: str):
-        """
-        Parameters
-        ----------
-        name: str
-            Name of this logger. Hierarchy levels are separated by periods,
-            e.g. "backend.fu.jumpbranch".
-        """
-        self.name = name
-
-    def log(self, m: ModuleLike, level: LogLevel, trigger: ValueLike, format: str, *args, src_loc_at: int = 0):
-        """Registers a hardware log record with the given severity.
-
-        Parameters
-        ----------
-        m: ModuleLike
-            The module for which the log record is added.
-        trigger: ValueLike
-            If the value of this Amaranth expression is true, the log will be reported.
-        format: str
-            The format of the message as defined in PEP 3101.
-        *args
-            Amaranth values that will be read during simulation and used to format
-            the message.
-        src_loc_at: int, optional
-            How many stack frames below to look for the source location, used to
-            identify the failing assertion.
-        """
-
-        def local_src_loc(src_loc: SrcLoc):
-            return (os.path.relpath(src_loc[0]), src_loc[1])
-
-        src_loc = local_src_loc(get_src_loc(src_loc_at + 1))
-
-        trigger_signal = Signal()
-        m.d.comb += trigger_signal.eq(trigger)
-
-        record = LogRecord(
-            logger_name=self.name, level=level, format_str=format, location=src_loc, trigger=trigger_signal
-        )
-
-        for arg in args:
-            sig = Signal.like(arg)
-            m.d.top_comb += sig.eq(arg)
-            record.fields.append(sig)
-
-        dependencies = DependencyContext.get()
-        dependencies.add_dependency(LogKey(), record)
-
-    def debug(self, m: ModuleLike, trigger: ValueLike, format: str, *args, **kwargs):
-        """Log a message with severity 'DEBUG'.
-
-        See `HardwareLogger.log` function for more details.
-        """
-        self.log(m, logging.DEBUG, trigger, format, *args, **kwargs)
-
-    def info(self, m: ModuleLike, trigger: ValueLike, format: str, *args, **kwargs):
-        """Log a message with severity 'INFO'.
-
-        See `HardwareLogger.log` function for more details.
-        """
-        self.log(m, logging.INFO, trigger, format, *args, **kwargs)
-
-    def warning(self, m: ModuleLike, trigger: ValueLike, format: str, *args, **kwargs):
-        """Log a message with severity 'WARNING'.
-
-        See `HardwareLogger.log` function for more details.
-        """
-        self.log(m, logging.WARNING, trigger, format, *args, **kwargs)
-
-    def error(self, m: ModuleLike, trigger: ValueLike, format: str, *args, **kwargs):
-        """Log a message with severity 'ERROR'.
-
-        This severity level has special semantics. If a log with this severity
-        level is triggered, the simulation will be terminated.
-
-        See `HardwareLogger.log` function for more details.
-        """
-        self.log(m, logging.ERROR, trigger, format, *args, **kwargs)
-
-    def assertion(self, m: ModuleLike, value: Value, format: str = "", *args, src_loc_at: int = 0, **kwargs):
-        """Define an assertion.
-
-        This function might help find some hardware bugs which might otherwise be
-        hard to detect. If `value` is false, it will terminate the simulation;
-        it can also be used to turn on a warning LED on a board.
-
-        Internally, this is a convenience wrapper over log.error.
-
-        See `HardwareLogger.log` function for more details.
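# A usage sketch for the logger above, inside some module's elaborate();
# the logger name and all signals are hypothetical.
log = HardwareLogger("frontend.icache")

log.debug(m, refill_start, "starting refill of address 0x{:x}", refill_addr)
log.assertion(m, ~(resp_valid & flushing), "response arrived while flushing")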
-        """
-        self.error(m, ~value, format, *args, **kwargs, src_loc_at=src_loc_at + 1)
-
-
-def get_log_records(level: LogLevel, namespace_regexp: str = ".*") -> list[LogRecord]:
-    """Get log records for the given severity level and in the
-    specified namespace.
-
-    This function returns all log records with a severity greater than or
-    equal to the specified level and belonging to the specified namespace.
-
-    Parameters
-    ----------
-    level: LogLevel
-        The minimum severity level.
-    namespace_regexp: str, optional
-        The regexp of the namespace. If not specified, logs from all namespaces
-        will be processed.
-    """
-
-    dependencies = DependencyContext.get()
-    all_logs = dependencies.get_dependency(LogKey())
-    return [rec for rec in all_logs if rec.level >= level and re.search(namespace_regexp, rec.logger_name)]
-
-
-def get_trigger_bit(level: LogLevel, namespace_regexp: str = ".*") -> Value:
-    """Get a trigger bit for logs of the given severity level and
-    in the specified namespace.
-
-    The signal returned by this function is high whenever the trigger signal
-    of any of the records with a severity greater than or equal to the
-    specified level is high.
-
-    Parameters
-    ----------
-    level: LogLevel
-        The minimum severity level.
-    namespace_regexp: str, optional
-        The regexp of the namespace. If not specified, logs from all namespaces
-        will be processed.
-    """
-
-    return reduce(operator.or_, [rec.trigger for rec in get_log_records(level, namespace_regexp)], C(0))
diff --git a/transactron/lib/metrics.py b/transactron/lib/metrics.py
deleted file mode 100644
index 78f5c5e53..000000000
--- a/transactron/lib/metrics.py
+++ /dev/null
@@ -1,822 +0,0 @@
-from dataclasses import dataclass, field
-from dataclasses_json import dataclass_json
-from typing import Optional, Type
-from abc import ABC
-from enum import Enum
-
-from amaranth import *
-from amaranth.utils import bits_for, ceil_log2, exact_log2
-
-from transactron.utils import ValueLike, OneHotSwitchDynamic, SignalBundle
-from transactron import Method, def_method, TModule
-from transactron.lib import FIFO, AsyncMemoryBank, logging
-from transactron.utils.dependencies import ListKey, DependencyContext, SimpleKey
-
-__all__ = [
-    "MetricRegisterModel",
-    "MetricModel",
-    "HwMetric",
-    "HwCounter",
-    "TaggedCounter",
-    "HwExpHistogram",
-    "FIFOLatencyMeasurer",
-    "TaggedLatencyMeasurer",
-    "HardwareMetricsManager",
-    "HwMetricsEnabledKey",
-]
-
-
-@dataclass_json
-@dataclass(frozen=True)
-class MetricRegisterModel:
-    """
-    Represents a single register of a metric, serving as a fundamental
-    building block that holds a singular value.
-
-    Attributes
-    ----------
-    name: str
-        The unique identifier for the register (among remaining
-        registers of a specific metric).
-    description: str
-        A brief description of the metric's purpose.
-    width: int
-        The bit-width of the register.
-    """
-
-    name: str
-    description: str
-    width: int
-
-
-@dataclass_json
-@dataclass
-class MetricModel:
-    """
-    Provides information about a metric exposed by the circuit. Each metric
-    comprises multiple registers, each dedicated to storing specific values.
-
-    The configuration of registers is internally determined by a
-    specific metric type and is not user-configurable.
-
-    Attributes
-    ----------
-    fully_qualified_name: str
-        The fully qualified name of the metric, with name components joined by dots ('.'),
-        e.g., 'foo.bar.requests'.
-    description: str
-        A human-readable description of the metric's functionality.
-    regs: list[MetricRegisterModel]
-        A list of registers associated with the metric.
-    """
-
-    fully_qualified_name: str
-    description: str
-    regs: dict[str, MetricRegisterModel] = field(default_factory=dict)
-
-
-class HwMetricRegister(MetricRegisterModel):
-    """
-    A concrete implementation of a metric register that holds its value as an Amaranth signal.
-
-    Attributes
-    ----------
-    value: Signal
-        Amaranth signal representing the value of the register.
-    """
-
-    def __init__(self, name: str, width_bits: int, description: str = "", init: int = 0):
-        """
-        Parameters
-        ----------
-        name: str
-            The unique identifier for the register (among remaining
-            registers of a specific metric).
-        width_bits: int
-            The bit-width of the register.
-        description: str
-            A brief description of the metric's purpose.
-        init: int
-            The reset value of the register.
-        """
-        super().__init__(name, description, width_bits)
-
-        self.value = Signal(width_bits, init=init, name=name)
-
-
-@dataclass(frozen=True)
-class HwMetricsListKey(ListKey["HwMetric"]):
-    """DependencyManager key collecting hardware metrics globally as a list."""
-
-    pass
-
-
-@dataclass(frozen=True)
-class HwMetricsEnabledKey(SimpleKey[bool]):
-    """
-    DependencyManager key for enabling hardware metrics. If metrics are disabled,
-    none of their signals will be synthesized.
-    """
-
-    lock_on_get = False
-    empty_valid = True
-    default_value = False
-
-
-class HwMetric(ABC, MetricModel):
-    """
-    A base for all metric implementations. It should only be used for declaring
-    new types of metrics.
-
-    It takes care of registering the metric in the dependency manager.
-
-    Attributes
-    ----------
-    signals: dict[str, Signal]
-        A mapping from a register name to a Signal containing the value of that register.
-    """
-
-    def __init__(self, fully_qualified_name: str, description: str):
-        """
-        Parameters
-        ----------
-        fully_qualified_name: str
-            The fully qualified name of the metric.
-        description: str
-            A human-readable description of the metric's functionality.
-        """
-        super().__init__(fully_qualified_name, description)
-
-        self.signals: dict[str, Signal] = {}
-
-        # add the metric to the global list of all metrics
-        DependencyContext.get().add_dependency(HwMetricsListKey(), self)
-
-        # So Amaranth doesn't report that the module is unused when metrics are disabled
-        self._MustUse__silence = True  # type: ignore
-
-    def add_registers(self, regs: list[HwMetricRegister]):
-        """
-        Adds registers to a metric. Should only be called by inheriting classes
-        during initialization.
-
-        Parameters
-        ----------
-        regs: list[HwMetricRegister]
-            A list of registers to be registered.
-        """
-        for reg in regs:
-            if reg.name in self.regs:
-                raise RuntimeError(f"Register '{reg.name}' is already added to the metric {self.fully_qualified_name}")
-
-            self.regs[reg.name] = reg
-            self.signals[reg.name] = reg.value
-
-    def metrics_enabled(self) -> bool:
-        return DependencyContext.get().get_dependency(HwMetricsEnabledKey())
-
-    # To restore hashability lost by dataclass subclassing
-    def __hash__(self):
-        return object.__hash__(self)
-
-
-class HwCounter(Elaboratable, HwMetric):
-    """Hardware Counter
-
-    The most basic hardware metric that can just increase its value.
-    """
-
-    def __init__(self, fully_qualified_name: str, description: str = "", *, width_bits: int = 32):
-        """
-        Parameters
-        ----------
-        fully_qualified_name: str
-            The fully qualified name of the metric.
-        description: str
-            A human-readable description of the metric's functionality.
- width_bits: int - The bit-width of the register. Defaults to 32 bits. - """ - - super().__init__(fully_qualified_name, description) - - self.count = HwMetricRegister("count", width_bits, "the value of the counter") - - self.add_registers([self.count]) - - self._incr = Method() - - def elaborate(self, platform): - if not self.metrics_enabled(): - return TModule() - - m = TModule() - - @def_method(m, self._incr) - def _(): - m.d.sync += self.count.value.eq(self.count.value + 1) - - return m - - def incr(self, m: TModule, *, cond: ValueLike = C(1)): - """ - Increases the value of the counter by 1. - - Should be called in the body of either a transaction or a method. - - Parameters - ---------- - m: TModule - Transactron module - """ - if not self.metrics_enabled(): - return - - with m.If(cond): - self._incr(m) - - -class TaggedCounter(Elaboratable, HwMetric): - """Hardware Tagged Counter - - Like HwCounter, but contains multiple counters, each with its own tag. - At a time a single counter can be increased and the value of the tag - can be provided dynamically. The type of the tag can be either an int - enum, a range or a list of integers (negative numbers are ok). - - Internally, it detects if tag values can be one-hot encoded and if so, - it generates more optimized circuit. - - Attributes - ---------- - tag_width: int - The length of the signal holding a tag value. - one_hot: bool - Whether tag values can be one-hot encoded. - counters: dict[int, HwMetricRegisters] - Mapping from a tag value to a register holding a counter for that tag. - """ - - def __init__( - self, - fully_qualified_name: str, - description: str = "", - *, - tags: range | Type[Enum] | list[int], - registers_width: int = 32, - ): - """ - Parameters - ---------- - fully_qualified_name: str - The fully qualified name of the metric. - description: str - A human-readable description of the metric's functionality. - tags: range | Type[Enum] | list[int] - Tag values. - registers_width: int - Width of the underlying registers. Defaults to 32 bits. 
- """ - - super().__init__(fully_qualified_name, description) - - if isinstance(tags, range) or isinstance(tags, list): - counters_meta = [(i, f"{i}") for i in tags] - else: - counters_meta = [(i.value, i.name) for i in tags] - - values = [value for value, _ in counters_meta] - self.tag_width = max(bits_for(max(values)), bits_for(min(values))) - - self.one_hot = True - negative_values = False - for value in values: - if value < 0: - self.one_hot = False - negative_values = True - break - - log = ceil_log2(value) - if 2**log != value: - self.one_hot = False - - self._incr = Method(i=[("tag", Shape(self.tag_width, signed=negative_values))]) - - self.counters: dict[int, HwMetricRegister] = {} - for tag_value, name in counters_meta: - value_str = ("1<<" + str(exact_log2(tag_value))) if self.one_hot else str(tag_value) - description = f"the counter for tag {name} (value={value_str})" - - self.counters[tag_value] = HwMetricRegister( - name, - registers_width, - description, - ) - - self.add_registers(list(self.counters.values())) - - def elaborate(self, platform): - if not self.metrics_enabled(): - return TModule() - - m = TModule() - - @def_method(m, self._incr) - def _(tag): - if self.one_hot: - sorted_tags = sorted(list(self.counters.keys())) - for i in OneHotSwitchDynamic(m, tag): - counter = self.counters[sorted_tags[i]] - m.d.sync += counter.value.eq(counter.value + 1) - else: - for tag_value, counter in self.counters.items(): - with m.If(tag == tag_value): - m.d.sync += counter.value.eq(counter.value + 1) - - return m - - def incr(self, m: TModule, tag: ValueLike, *, cond: ValueLike = C(1)): - """ - Increases the counter of a given tag by 1. - - Should be called in the body of either a transaction or a method. - - Parameters - ---------- - m: TModule - Transactron module - tag: ValueLike - The tag of the counter. - cond: ValueLike - When set to high, the counter will be increased. By default set to high. - """ - if not self.metrics_enabled(): - return - - with m.If(cond): - self._incr(m, tag) - - -class HwExpHistogram(Elaboratable, HwMetric): - """Hardware Exponential Histogram - - Represents the distribution of sampled data through a histogram. A histogram - samples observations (usually things like request durations or queue sizes) and counts - them in a configurable number of buckets. The buckets are of exponential size. For example, - a histogram with 5 buckets would have the following value ranges: - [0, 1); [1, 2); [2, 4); [4, 8); [8, +inf). - - Additionally, the histogram tracks the number of observations, the sum - of observed values, and the minimum and maximum values. - """ - - def __init__( - self, - fully_qualified_name: str, - description: str = "", - *, - bucket_count: int, - sample_width: int = 32, - registers_width: int = 32, - ): - """ - Parameters - ---------- - fully_qualified_name: str - The fully qualified name of the metric. - description: str - A human-readable description of the metric's functionality. - max_value: int - The maximum value that the histogram would be able to count. This - value is used to calculate the number of buckets. 
- """ - - super().__init__(fully_qualified_name, description) - self.bucket_count = bucket_count - self.sample_width = sample_width - - self._add = Method(i=[("sample", self.sample_width)]) - - self.count = HwMetricRegister("count", registers_width, "the count of events that have been observed") - self.sum = HwMetricRegister("sum", registers_width, "the total sum of all observed values") - self.min = HwMetricRegister( - "min", - self.sample_width, - "the minimum of all observed values", - init=(1 << self.sample_width) - 1, - ) - self.max = HwMetricRegister("max", self.sample_width, "the maximum of all observed values") - - self.buckets = [] - for i in range(self.bucket_count): - bucket_start = 0 if i == 0 else 2 ** (i - 1) - bucket_end = "inf" if i == self.bucket_count - 1 else 2**i - - self.buckets.append( - HwMetricRegister( - f"bucket-{bucket_end}", - registers_width, - f"the cumulative counter for the observation bucket [{bucket_start}, {bucket_end})", - ) - ) - - self.add_registers([self.count, self.sum, self.max, self.min] + self.buckets) - - def elaborate(self, platform): - if not self.metrics_enabled(): - return TModule() - - m = TModule() - - @def_method(m, self._add) - def _(sample): - m.d.sync += self.count.value.eq(self.count.value + 1) - m.d.sync += self.sum.value.eq(self.sum.value + sample) - - with m.If(sample > self.max.value): - m.d.sync += self.max.value.eq(sample) - - with m.If(sample < self.min.value): - m.d.sync += self.min.value.eq(sample) - - # todo: perhaps replace with a recursive implementation of the priority encoder - bucket_idx = Signal(range(self.sample_width)) - for i in range(self.sample_width): - with m.If(sample[i]): - m.d.av_comb += bucket_idx.eq(i) - - for i, bucket in enumerate(self.buckets): - should_incr = C(0) - if i == 0: - # The first bucket has a range [0, 1). - should_incr = sample == 0 - elif i == self.bucket_count - 1: - # The last bucket should count values bigger or equal to 2**(self.bucket_count-1) - should_incr = (bucket_idx >= i - 1) & (sample != 0) - else: - should_incr = (bucket_idx == i - 1) & (sample != 0) - - with m.If(should_incr): - m.d.sync += bucket.value.eq(bucket.value + 1) - - return m - - def add(self, m: TModule, sample: Value): - """ - Adds a new sample to the histogram. - - Should be called in the body of either a transaction or a method. - - Parameters - ---------- - m: TModule - Transactron module - sample: ValueLike - The value that will be added to the histogram - """ - - if not self.metrics_enabled(): - return - - self._add(m, sample) - - -class FIFOLatencyMeasurer(Elaboratable): - """ - Measures duration between two events, e.g. request processing latency. - It can track multiple events at the same time, i.e. the second event can - be registered as started, before the first finishes. However, they must be - processed in the FIFO order. - - The module exposes an exponential histogram of the measured latencies. - """ - - def __init__( - self, - fully_qualified_name: str, - description: str = "", - *, - slots_number: int, - max_latency: int, - ): - """ - Parameters - ---------- - fully_qualified_name: str - The fully qualified name of the metric. - description: str - A human-readable description of the metric's functionality. - slots_number: int - A number of events that the module can track simultaneously. - max_latency: int - The maximum latency of an event. Used to set signal widths and - number of buckets in the histogram. 
If a latency turns to be - bigger than the maximum, it will overflow and result in a false - measurement. - """ - self.fully_qualified_name = fully_qualified_name - self.description = description - self.slots_number = slots_number - self.max_latency = max_latency - - self._start = Method() - self._stop = Method() - - # This bucket count gives us the best possible granularity. - bucket_count = bits_for(self.max_latency) + 1 - self.histogram = HwExpHistogram( - self.fully_qualified_name, - self.description, - bucket_count=bucket_count, - sample_width=bits_for(self.max_latency), - ) - - def elaborate(self, platform): - if not self.metrics_enabled(): - return TModule() - - m = TModule() - - epoch_width = bits_for(self.max_latency) - - m.submodules.fifo = self.fifo = FIFO([("epoch", epoch_width)], self.slots_number) - m.submodules.histogram = self.histogram - - epoch = Signal(epoch_width) - - m.d.sync += epoch.eq(epoch + 1) - - @def_method(m, self._start) - def _(): - self.fifo.write(m, epoch) - - @def_method(m, self._stop) - def _(): - ret = self.fifo.read(m) - # The result of substracting two unsigned n-bit is a signed (n+1)-bit value, - # so we need to cast the result and discard the most significant bit. - duration = (epoch - ret.epoch).as_unsigned()[:-1] - self.histogram.add(m, duration) - - return m - - def start(self, m: TModule): - """ - Registers the start of an event. Can be called before the previous events - finish. If there are no slots available, the method will be blocked. - - Should be called in the body of either a transaction or a method. - - Parameters - ---------- - m: TModule - Transactron module - """ - - if not self.metrics_enabled(): - return - - self._start(m) - - def stop(self, m: TModule): - """ - Registers the end of the oldest event (the FIFO order). If there are no - started events in the queue, the method will block. - - Should be called in the body of either a transaction or a method. - - Parameters - ---------- - m: TModule - Transactron module - """ - - if not self.metrics_enabled(): - return - - self._stop(m) - - def metrics_enabled(self) -> bool: - return DependencyContext.get().get_dependency(HwMetricsEnabledKey()) - - -class TaggedLatencyMeasurer(Elaboratable): - """ - Measures duration between two events, e.g. request processing latency. - It can track multiple events at the same time, i.e. the second event can - be registered as started, before the first finishes. However, each event - needs to have an unique slot tag. - - The module exposes an exponential histogram of the measured latencies. - """ - - def __init__( - self, - fully_qualified_name: str, - description: str = "", - *, - slots_number: int, - max_latency: int, - ): - """ - Parameters - ---------- - fully_qualified_name: str - The fully qualified name of the metric. - description: str - A human-readable description of the metric's functionality. - slots_number: int - A number of events that the module can track simultaneously. - max_latency: int - The maximum latency of an event. Used to set signal widths and - number of buckets in the histogram. If a latency turns to be - bigger than the maximum, it will overflow and result in a false - measurement. - """ - self.fully_qualified_name = fully_qualified_name - self.description = description - self.slots_number = slots_number - self.max_latency = max_latency - - self._start = Method(i=[("slot", range(0, slots_number))]) - self._stop = Method(i=[("slot", range(0, slots_number))]) - - # This bucket count gives us the best possible granularity. 
- bucket_count = bits_for(self.max_latency) + 1 - self.histogram = HwExpHistogram( - self.fully_qualified_name, - self.description, - bucket_count=bucket_count, - sample_width=bits_for(self.max_latency), - ) - - self.log = logging.HardwareLogger(fully_qualified_name) - - def elaborate(self, platform): - if not self.metrics_enabled(): - return TModule() - - m = TModule() - - epoch_width = bits_for(self.max_latency) - - m.submodules.slots = self.slots = AsyncMemoryBank( - data_layout=[("epoch", epoch_width)], elem_count=self.slots_number - ) - m.submodules.histogram = self.histogram - - slots_taken = Signal(self.slots_number) - slots_taken_start = Signal.like(slots_taken) - slots_taken_stop = Signal.like(slots_taken) - - m.d.comb += slots_taken_start.eq(slots_taken) - m.d.comb += slots_taken_stop.eq(slots_taken_start) - m.d.sync += slots_taken.eq(slots_taken_stop) - - epoch = Signal(epoch_width) - - m.d.sync += epoch.eq(epoch + 1) - - @def_method(m, self._start) - def _(slot: Value): - m.d.comb += slots_taken_start.eq(slots_taken | (1 << slot)) - self.log.error(m, (slots_taken & (1 << slot)).any(), "taken slot {} taken again", slot) - self.slots.write(m, addr=slot, data=epoch) - - @def_method(m, self._stop) - def _(slot: Value): - m.d.comb += slots_taken_stop.eq(slots_taken_start & ~(C(1, self.slots_number) << slot)) - self.log.error(m, ~(slots_taken & (1 << slot)).any(), "free slot {} freed again", slot) - ret = self.slots.read(m, addr=slot) - # The result of substracting two unsigned n-bit is a signed (n+1)-bit value, - # so we need to cast the result and discard the most significant bit. - duration = (epoch - ret.epoch).as_unsigned()[:-1] - self.histogram.add(m, duration) - - return m - - def start(self, m: TModule, *, slot: ValueLike): - """ - Registers the start of an event for a given slot tag. - - Should be called in the body of either a transaction or a method. - - Parameters - ---------- - m: TModule - Transactron module - slot: ValueLike - The slot tag of the event. - """ - - if not self.metrics_enabled(): - return - - self._start(m, slot) - - def stop(self, m: TModule, *, slot: ValueLike): - """ - Registers the end of the event for a given slot tag. - - Should be called in the body of either a transaction or a method. - - Parameters - ---------- - m: TModule - Transactron module - slot: ValueLike - The slot tag of the event. - """ - - if not self.metrics_enabled(): - return - - self._stop(m, slot) - - def metrics_enabled(self) -> bool: - return DependencyContext.get().get_dependency(HwMetricsEnabledKey()) - - -class HardwareMetricsManager: - """ - Collects all metrics registered in the circuit and provides an easy - access to them. - """ - - def __init__(self): - self._metrics: Optional[dict[str, HwMetric]] = None - - def _collect_metrics(self) -> dict[str, HwMetric]: - # We lazily collect all metrics so that the metrics manager can be - # constructed at any time. Otherwise, if a metric object was created - # after the manager object had been created, that metric wouldn't end up - # being registered. - metrics: dict[str, HwMetric] = {} - for metric in DependencyContext.get().get_dependency(HwMetricsListKey()): - if metric.fully_qualified_name in metrics: - raise RuntimeError(f"Metric '{metric.fully_qualified_name}' is already registered") - - metrics[metric.fully_qualified_name] = metric - - return metrics - - def get_metrics(self) -> dict[str, HwMetric]: - """ - Returns all metrics registered in the circuit. 
- """ - if self._metrics is None: - self._metrics = self._collect_metrics() - return self._metrics - - def get_register_value(self, metric_name: str, reg_name: str) -> Signal: - """ - Returns the signal holding the register value of the given metric. - - Parameters - ---------- - metric_name: str - The fully qualified name of the metric, for example 'frontend.icache.loads'. - reg_name: str - The name of the register from that metric, for example if - the metric is a histogram, the 'reg_name' could be 'min' - or 'bucket-32'. - """ - - metrics = self.get_metrics() - if metric_name not in metrics: - raise RuntimeError(f"Couldn't find metric '{metric_name}'") - return metrics[metric_name].signals[reg_name] - - def debug_signals(self) -> SignalBundle: - """ - Returns tree-like SignalBundle composed of all metric registers. - """ - metrics = self.get_metrics() - - def rec(metric_names: list[str], depth: int = 1): - bundle: list[SignalBundle] = [] - components: dict[str, list[str]] = {} - - for metric in metric_names: - parts = metric.split(".") - - if len(parts) == depth: - signals = metrics[metric].signals - reg_values = [signals[reg_name] for reg_name in signals] - - bundle.append({metric: reg_values}) - - continue - - component_prefix = ".".join(parts[:depth]) - - if component_prefix not in components: - components[component_prefix] = [] - components[component_prefix].append(metric) - - for component_name, elements in components.items(): - bundle.append({component_name: rec(elements, depth + 1)}) - - return bundle - - return {"metrics": rec(list(self.get_metrics().keys()))} diff --git a/transactron/lib/reqres.py b/transactron/lib/reqres.py deleted file mode 100644 index a3f6e2908..000000000 --- a/transactron/lib/reqres.py +++ /dev/null @@ -1,185 +0,0 @@ -from amaranth import * -from ..core import * -from ..utils import SrcLoc, get_src_loc, MethodLayout -from .connectors import Forwarder -from transactron.lib import BasicFifo -from amaranth.utils import * - -__all__ = [ - "ArgumentsToResultsZipper", - "Serializer", -] - - -class ArgumentsToResultsZipper(Elaboratable): - """Zips arguments used to call method with results, cutting critical path. - - This module provides possibility to pass arguments from caller and connect it with results - from callee. Arguments are stored in 2-FIFO and results in Forwarder. Because of this asymmetry, - the callee should provide results as long as they aren't correctly received. - - FIFO is used as rate-limiter, so when FIFO reaches full capacity there should be no new requests issued. - - Example topology: - - .. mermaid:: - - graph LR - Caller -- write_arguments --> 2-FIFO; - Caller -- invoke --> Callee["Callee \\n (1+ cycle delay)"]; - Callee -- write_results --> Forwarder; - Forwarder -- read --> Zip; - 2-FIFO -- read --> Zip; - Zip -- read --> User; - subgraph ArgumentsToResultsZipper - Forwarder; - 2-FIFO; - Zip; - end - - Attributes - ---------- - peek_arg: Method - A nonexclusive method to read (but not delete) the head of the arg queue. - write_args: Method - Method to write arguments with `args_layout` format to 2-FIFO. - write_results: Method - Method to save results with `results_layout` in the Forwarder. - read: Method - Reads latest entries from the fifo and the forwarder and return them as - a structure with two fields: 'args' and 'results'. - """ - - def __init__(self, args_layout: MethodLayout, results_layout: MethodLayout, src_loc: int | SrcLoc = 0): - """ - Parameters - ---------- - args_layout: method layout - The format of arguments. 
- results_layout: method layout - The format of results. - src_loc: int | SrcLoc - How many stack frames deep the source location is taken from. - Alternatively, the source location to use instead of the default. - """ - self.src_loc = get_src_loc(src_loc) - self.results_layout = results_layout - self.args_layout = args_layout - self.output_layout = [("args", self.args_layout), ("results", results_layout)] - - self.peek_arg = Method(o=self.args_layout, nonexclusive=True, src_loc=self.src_loc) - self.write_args = Method(i=self.args_layout, src_loc=self.src_loc) - self.write_results = Method(i=self.results_layout, src_loc=self.src_loc) - self.read = Method(o=self.output_layout, src_loc=self.src_loc) - - def elaborate(self, platform): - m = TModule() - - fifo = BasicFifo(self.args_layout, depth=2, src_loc=self.src_loc) - forwarder = Forwarder(self.results_layout, src_loc=self.src_loc) - - m.submodules.fifo = fifo - m.submodules.forwarder = forwarder - - @def_method(m, self.write_args) - def _(arg): - fifo.write(m, arg) - - @def_method(m, self.write_results) - def _(arg): - forwarder.write(m, arg) - - @def_method(m, self.read) - def _(): - args = fifo.read(m) - results = forwarder.read(m) - return {"args": args, "results": results} - - self.peek_arg.proxy(m, fifo.peek) - - return m - - -class Serializer(Elaboratable): - """Module to serialize request-response methods. - - Provides a transactional interface to connect many client `Module`\\s (which request somethig using method call) - with a server `Module` which provides method to request operation and method to get response. - - Requests are being serialized from many clients and forwarded to a server which can process only one request - at the time. Responses from server are deserialized and passed to proper client. `Serializer` assumes, that - responses from the server are in-order, so the order of responses is the same as order of requests. - - - Attributes - ---------- - serialize_in: list[Method] - List of request methods. Data layouts are the same as for `serialized_req_method`. - serialize_out: list[Method] - List of response methods. Data layouts are the same as for `serialized_resp_method`. - `i`-th response method provides responses for requests from `i`-th `serialize_in` method. - """ - - def __init__( - self, - *, - port_count: int, - serialized_req_method: Method, - serialized_resp_method: Method, - depth: int = 4, - src_loc: int | SrcLoc = 0 - ): - """ - Parameters - ---------- - port_count: int - Number of ports, which should be generated. `len(serialize_in)=len(serialize_out)=port_count` - serialized_req_method: Method - Request method provided by server's `Module`. - serialized_resp_method: Method - Response method provided by server's `Module`. - depth: int - Number of requests which can be forwarded to server, before server provides first response. Describe - the resistance of `Serializer` to latency of server in case when server is fully pipelined. - src_loc: int | SrcLoc - How many stack frames deep the source location is taken from. - Alternatively, the source location to use instead of the default. 
- """ - self.src_loc = get_src_loc(src_loc) - self.port_count = port_count - self.serialized_req_method = serialized_req_method - self.serialized_resp_method = serialized_resp_method - - self.depth = depth - - self.id_layout = [("id", exact_log2(self.port_count))] - - self.clear = Method() - self.serialize_in = [ - Method.like(self.serialized_req_method, src_loc=self.src_loc) for _ in range(self.port_count) - ] - self.serialize_out = [ - Method.like(self.serialized_resp_method, src_loc=self.src_loc) for _ in range(self.port_count) - ] - - def elaborate(self, platform) -> TModule: - m = TModule() - - pending_requests = BasicFifo(self.id_layout, self.depth, src_loc=self.src_loc) - m.submodules.pending_requests = pending_requests - - for i in range(self.port_count): - - @def_method(m, self.serialize_in[i]) - def _(arg): - pending_requests.write(m, {"id": i}) - self.serialized_req_method(m, arg) - - @def_method(m, self.serialize_out[i], ready=(pending_requests.head.id == i)) - def _(): - pending_requests.read(m) - return self.serialized_resp_method(m) - - self.clear.proxy(m, pending_requests.clear) - - return m diff --git a/transactron/lib/simultaneous.py b/transactron/lib/simultaneous.py deleted file mode 100644 index 7b00f93ff..000000000 --- a/transactron/lib/simultaneous.py +++ /dev/null @@ -1,87 +0,0 @@ -from amaranth import * - -from ..utils import SrcLoc -from ..core import * -from ..core import TransactionBase -from contextlib import contextmanager -from typing import Optional -from transactron.utils import ValueLike - -__all__ = [ - "condition", -] - - -@contextmanager -def condition(m: TModule, *, nonblocking: bool = False, priority: bool = False): - """Conditions using simultaneous transactions. - - This context manager allows to easily define conditions utilizing - nested transactions and the simultaneous transactions mechanism. - It is similar to Amaranth's `If`, but allows to call different and - possibly overlapping method sets in each branch. Each of the branches is - defined using a separate nested transaction. - - Inside the condition body, branches can be added, which are guarded - by Boolean conditions. A branch is considered for execution if its - condition is true and the called methods can be run. A catch-all, - default branch can be added, which can be executed only if none of - the other branches execute. The condition of the default branch is - the negated alternative of all the other conditions. - - Parameters - ---------- - m : TModule - A module where the condition is defined. - nonblocking : bool - States that the condition should not block the containing method - or transaction from running, even when none of the branch - conditions is true. In case of a blocking method call, the - containing method or transaction is still blocked. - priority : bool - States that when conditions are not mutually exclusive and multiple - branches could be executed, the first one will be selected. This - influences the scheduling order of generated transactions. - - Examples - -------- - .. highlight:: python - .. code-block:: python - - with condition(m) as branch: - with branch(cond1): - ... - with branch(cond2): - ... - with branch(): # default, optional - ... 
- """ - this = TransactionBase.get() - transactions = list[Transaction]() - last = False - conds = list[Signal]() - - @contextmanager - def branch(cond: Optional[ValueLike] = None, *, src_loc: int | SrcLoc = 2): - nonlocal last - if last: - raise RuntimeError("Condition clause added after catch-all") - req = Signal() - m.d.top_comb += req.eq(cond if cond is not None else ~Cat(*conds).any()) - conds.append(req) - name = f"{this.name}_cond{len(transactions)}" - with (transaction := Transaction(name=name, src_loc=src_loc)).body(m, request=req): - yield - if transactions and priority: - transactions[-1].schedule_before(transaction) - if cond is None: - last = True - transactions.append(transaction) - - yield branch - - if nonblocking and not last: - with branch(): - pass - - this.simultaneous_alternatives(*transactions) diff --git a/transactron/lib/storage.py b/transactron/lib/storage.py deleted file mode 100644 index 0b25dd6c2..000000000 --- a/transactron/lib/storage.py +++ /dev/null @@ -1,350 +0,0 @@ -from amaranth import * -from amaranth.utils import * -import amaranth.lib.memory as memory -import amaranth_types.memory as amemory - -from transactron.utils.transactron_helpers import from_method_layout, make_layout -from ..core import * -from ..utils import SrcLoc, get_src_loc, MultiPriorityEncoder -from typing import Optional -from transactron.utils import LayoutList, MethodLayout - -__all__ = ["MemoryBank", "ContentAddressableMemory", "AsyncMemoryBank"] - - -class MemoryBank(Elaboratable): - """MemoryBank module. - - Provides a transactional interface to synchronous Amaranth Memory with arbitrary - number of read and write ports. It supports optionally writing with given granularity. - - Attributes - ---------- - read_reqs: list[Method] - The read request methods, one for each read port. Accepts an `addr` from which data should be read. - Only ready if there is there is a place to buffer response. After calling `read_reqs[i]`, the result - will be available via the method `read_resps[i]`. - read_resps: list[Method] - The read response methods, one for each read port. Return `data_layout` View which was saved on `addr` given - by last corresponding `read_reqs` method call. Only ready after corresponding `read_reqs` call. - writes: list[Method] - The write methods, one for each write port. Accepts write address `addr`, `data` in form of `data_layout` - and optionally `mask` if `granularity` is not None. `1` in mask means that appropriate part should be written. - read_req: Method - The only method from `read_reqs`, if the memory has a single read port. If it has more ports, this method - is unavailable and `read_reqs` should be used instead. - read_resp: Method - The only method from `read_resps`, if the memory has a single read port. If it has more ports, this method - is unavailable and `read_resps` should be used instead. - write: Method - The only method from `writes`, if the memory has a single write port. If it has more ports, this method - is unavailable and `writes` should be used instead. - """ - - def __init__( - self, - *, - data_layout: LayoutList, - elem_count: int, - granularity: Optional[int] = None, - transparent: bool = False, - read_ports: int = 1, - write_ports: int = 1, - memory_type: amemory.AbstractMemoryConstructor[int, Value] = memory.Memory, - src_loc: int | SrcLoc = 0, - ): - """ - Parameters - ---------- - data_layout: method layout - The format of structures stored in the Memory. - elem_count: int - Number of elements stored in Memory. 
- granularity: Optional[int] - Granularity of write, forwarded to Amaranth. If `None` the whole structure is always saved at once. - If not, the width of `data_layout` is split into `granularity` parts, which can be saved independently. - transparent: bool - Read port transparency, false by default. When a read port is transparent, if a given memory address - is read and written in the same clock cycle, the read returns the written value instead of the value - which was in the memory in that cycle. - read_ports: int - Number of read ports. - write_ports: int - Number of write ports. - src_loc: int | SrcLoc - How many stack frames deep the source location is taken from. - Alternatively, the source location to use instead of the default. - """ - self.src_loc = get_src_loc(src_loc) - self.data_layout = make_layout(*data_layout) - self.elem_count = elem_count - self.granularity = granularity - self.width = from_method_layout(self.data_layout).size - self.addr_width = bits_for(self.elem_count - 1) - self.transparent = transparent - self.reads_ports = read_ports - self.writes_ports = write_ports - self.memory_type = memory_type - - self.read_reqs_layout: LayoutList = [("addr", self.addr_width)] - write_layout = [("addr", self.addr_width), ("data", self.data_layout)] - if self.granularity is not None: - write_layout.append(("mask", self.width // self.granularity)) - self.writes_layout = make_layout(*write_layout) - - self.read_reqs = [Method(i=self.read_reqs_layout, src_loc=self.src_loc) for _ in range(read_ports)] - self.read_resps = [Method(o=self.data_layout, src_loc=self.src_loc) for _ in range(read_ports)] - self.writes = [Method(i=self.writes_layout, src_loc=self.src_loc) for _ in range(write_ports)] - - if read_ports == 1: - self.read_req = self.read_reqs[0] - self.read_resp = self.read_resps[0] - if write_ports == 1: - self.write = self.writes[0] - - def elaborate(self, platform) -> TModule: - m = TModule() - - m.submodules.mem = mem = self.memory_type(shape=self.width, depth=self.elem_count, init=[]) - write_port = [mem.write_port() for _ in range(self.writes_ports)] - read_port = [ - mem.read_port(transparent_for=write_port if self.transparent else []) for _ in range(self.reads_ports) - ] - read_output_valid = [Signal() for _ in range(self.reads_ports)] - overflow_valid = [Signal() for _ in range(self.reads_ports)] - overflow_data = [Signal(self.width) for _ in range(self.reads_ports)] - - # The read request method can be called at most twice when not reading the response. - # The first result is stored in the overflow buffer, the second - in the read value buffer of the memory. - # If the responses are always read as they arrive, overflow is never written and no stalls occur. 
- - for i in range(self.reads_ports): - with m.If(read_output_valid[i] & ~overflow_valid[i] & self.read_reqs[i].run & ~self.read_resps[i].run): - m.d.sync += overflow_valid[i].eq(1) - m.d.sync += overflow_data[i].eq(read_port[i].data) - - @def_methods(m, self.read_resps, lambda i: read_output_valid[i] | overflow_valid[i]) - def _(i: int): - with m.If(overflow_valid[i]): - m.d.sync += overflow_valid[i].eq(0) - with m.Else(): - m.d.sync += read_output_valid[i].eq(0) - return Mux(overflow_valid[i], overflow_data[i], read_port[i].data) - - for i in range(self.reads_ports): - m.d.comb += read_port[i].en.eq(0) # because the init value is 1 - - @def_methods(m, self.read_reqs, lambda i: ~overflow_valid[i]) - def _(i: int, addr): - m.d.sync += read_output_valid[i].eq(1) - m.d.comb += read_port[i].en.eq(1) - m.d.comb += read_port[i].addr.eq(addr) - - @def_methods(m, self.writes) - def _(i: int, arg): - m.d.comb += write_port[i].addr.eq(arg.addr) - m.d.comb += write_port[i].data.eq(arg.data) - if self.granularity is None: - m.d.comb += write_port[i].en.eq(1) - else: - m.d.comb += write_port[i].en.eq(arg.mask) - - return m - - -class ContentAddressableMemory(Elaboratable): - """Content addresable memory - - This module implements a content-addressable memory (in short CAM) with Transactron interface. - CAM is a type of memory where instead of predefined indexes there are used values fed in runtime - as keys (similar as in python dictionary). To insert new entry a pair `(key, value)` has to be - provided. Such pair takes an free slot which depends on internal implementation. To read value - a `key` has to be provided. It is compared with every valid key stored in CAM. If there is a hit, - a value is read. There can be many instances of the same key in CAM. In such case it is undefined - which value will be read. - - - .. warning:: - Pushing the value with index already present in CAM is an undefined behaviour. - - Attributes - ---------- - read : Method - Nondestructive read - write : Method - If index present - do update - remove : Method - Remove - push : Method - Inserts new data. - """ - - def __init__(self, address_layout: MethodLayout, data_layout: MethodLayout, entries_number: int): - """ - Parameters - ---------- - address_layout : LayoutLike - The layout of the address records. - data_layout : LayoutLike - The layout of the data. - entries_number : int - The number of slots to create in memory. 
- """ - self.address_layout = from_method_layout(address_layout) - self.data_layout = from_method_layout(data_layout) - self.entries_number = entries_number - - self.read = Method(i=[("addr", self.address_layout)], o=[("data", self.data_layout), ("not_found", 1)]) - self.remove = Method(i=[("addr", self.address_layout)]) - self.push = Method(i=[("addr", self.address_layout), ("data", self.data_layout)]) - self.write = Method(i=[("addr", self.address_layout), ("data", self.data_layout)], o=[("not_found", 1)]) - - def elaborate(self, platform) -> TModule: - m = TModule() - - address_array = Array( - [Signal(self.address_layout, name=f"address_array_{i}") for i in range(self.entries_number)] - ) - data_array = Array([Signal(self.data_layout, name=f"data_array_{i}") for i in range(self.entries_number)]) - valids = Signal(self.entries_number, name="valids") - - m.submodules.encoder_read = encoder_read = MultiPriorityEncoder(self.entries_number, 1) - m.submodules.encoder_write = encoder_write = MultiPriorityEncoder(self.entries_number, 1) - m.submodules.encoder_push = encoder_push = MultiPriorityEncoder(self.entries_number, 1) - m.submodules.encoder_remove = encoder_remove = MultiPriorityEncoder(self.entries_number, 1) - m.d.top_comb += encoder_push.input.eq(~valids) - - @def_method(m, self.push, ready=~valids.all()) - def _(addr, data): - id = Signal(range(self.entries_number), name="id_push") - m.d.top_comb += id.eq(encoder_push.outputs[0]) - m.d.sync += address_array[id].eq(addr) - m.d.sync += data_array[id].eq(data) - m.d.sync += valids.bit_select(id, 1).eq(1) - - @def_method(m, self.write) - def _(addr, data): - write_mask = Signal(self.entries_number, name="write_mask") - m.d.top_comb += write_mask.eq(Cat([addr == stored_addr for stored_addr in address_array]) & valids) - m.d.top_comb += encoder_write.input.eq(write_mask) - with m.If(write_mask.any()): - m.d.sync += data_array[encoder_write.outputs[0]].eq(data) - return {"not_found": ~write_mask.any()} - - @def_method(m, self.read) - def _(addr): - read_mask = Signal(self.entries_number, name="read_mask") - m.d.top_comb += read_mask.eq(Cat([addr == stored_addr for stored_addr in address_array]) & valids) - m.d.top_comb += encoder_read.input.eq(read_mask) - return {"data": data_array[encoder_read.outputs[0]], "not_found": ~read_mask.any()} - - @def_method(m, self.remove) - def _(addr): - rm_mask = Signal(self.entries_number, name="rm_mask") - m.d.top_comb += rm_mask.eq(Cat([addr == stored_addr for stored_addr in address_array]) & valids) - m.d.top_comb += encoder_remove.input.eq(rm_mask) - with m.If(rm_mask.any()): - m.d.sync += valids.bit_select(encoder_remove.outputs[0], 1).eq(0) - - return m - - -class AsyncMemoryBank(Elaboratable): - """AsyncMemoryBank module. - - Provides a transactional interface to asynchronous Amaranth Memory with arbitrary number of - read and write ports. It supports optionally writing with given granularity. - - Attributes - ---------- - reads: list[Method] - The read methods, one for each read port. Accepts an `addr` from which data should be read. - The read response method. Return `data_layout` View which was saved on `addr` given by last - `write` method call. - writes: list[Method] - The write methods, one for each write port. Accepts write address `addr`, `data` in form of `data_layout` - and optionally `mask` if `granularity` is not None. `1` in mask means that appropriate part should be written. - read: Method - The only method from `reads`, if the memory has a single read port. 
- write: Method - The only method from `writes`, if the memory has a single write port. - """ - - def __init__( - self, - *, - data_layout: LayoutList, - elem_count: int, - granularity: Optional[int] = None, - read_ports: int = 1, - write_ports: int = 1, - memory_type: amemory.AbstractMemoryConstructor[int, Value] = memory.Memory, - src_loc: int | SrcLoc = 0, - ): - """ - Parameters - ---------- - data_layout: method layout - The format of structures stored in the Memory. - elem_count: int - Number of elements stored in Memory. - granularity: Optional[int] - Granularity of write, forwarded to Amaranth. If `None` the whole structure is always saved at once. - If not, the width of `data_layout` is split into `granularity` parts, which can be saved independently. - read_ports: int - Number of read ports. - write_ports: int - Number of write ports. - src_loc: int | SrcLoc - How many stack frames deep the source location is taken from. - Alternatively, the source location to use instead of the default. - """ - self.src_loc = get_src_loc(src_loc) - self.data_layout = make_layout(*data_layout) - self.elem_count = elem_count - self.granularity = granularity - self.width = from_method_layout(self.data_layout).size - self.addr_width = bits_for(self.elem_count - 1) - self.reads_ports = read_ports - self.writes_ports = write_ports - self.memory_type = memory_type - - self.read_reqs_layout: LayoutList = [("addr", self.addr_width)] - write_layout = [("addr", self.addr_width), ("data", self.data_layout)] - if self.granularity is not None: - write_layout.append(("mask", self.width // self.granularity)) - self.writes_layout = make_layout(*write_layout) - - self.reads = [ - Method(i=self.read_reqs_layout, o=self.data_layout, src_loc=self.src_loc) for _ in range(read_ports) - ] - self.writes = [Method(i=self.writes_layout, src_loc=self.src_loc) for _ in range(write_ports)] - - if read_ports == 1: - self.read = self.reads[0] - if write_ports == 1: - self.write = self.writes[0] - - def elaborate(self, platform) -> TModule: - m = TModule() - - mem = self.memory_type(shape=self.width, depth=self.elem_count, init=[]) - m.submodules.mem = mem - write_port = [mem.write_port() for _ in range(self.writes_ports)] - read_port = [mem.read_port(domain="comb") for _ in range(self.reads_ports)] - - @def_methods(m, self.reads) - def _(i: int, addr): - m.d.comb += read_port[i].addr.eq(addr) - return read_port[i].data - - @def_methods(m, self.writes) - def _(i: int, arg): - m.d.comb += write_port[i].addr.eq(arg.addr) - m.d.comb += write_port[i].data.eq(arg.data) - if self.granularity is None: - m.d.comb += write_port[i].en.eq(1) - else: - m.d.comb += write_port[i].en.eq(arg.mask) - - return m diff --git a/transactron/lib/transformers.py b/transactron/lib/transformers.py deleted file mode 100644 index 91b94f9c6..000000000 --- a/transactron/lib/transformers.py +++ /dev/null @@ -1,455 +0,0 @@ -from amaranth import * - -from transactron.utils.transactron_helpers import get_src_loc -from ..core import * -from ..utils import SrcLoc -from typing import Optional, Protocol -from collections.abc import Callable -from transactron.utils import ( - ValueLike, - assign, - AssignType, - ModuleLike, - MethodStruct, - HasElaborate, - MethodLayout, - RecordDict, -) -from .connectors import Forwarder, ManyToOneConnectTrans, ConnectTrans -from .simultaneous import condition - -__all__ = [ - "Transformer", - "Unifier", - "MethodMap", - "MethodFilter", - "MethodProduct", - "MethodTryProduct", - "Collector", - "CatTrans", - "ConnectAndMapTrans", -] 
- - -class Transformer(HasElaborate, Protocol): - """Method transformer abstract class. - - Method transformers construct a new method which utilizes other methods. - - Attributes - ---------- - method: Method - The method. - """ - - method: Method - - def use(self, m: ModuleLike): - """ - Returns the method and adds the transformer to a module. - - Parameters - ---------- - m: Module or TModule - The module to which this transformer is added as a submodule. - """ - m.submodules += self - return self.method - - -class Unifier(Transformer, Protocol): - method: Method - - def __init__(self, targets: list[Method]): ... - - -class MethodMap(Elaboratable, Transformer): - """Bidirectional map for methods. - - Takes a target method and creates a transformed method which calls the - original target method, mapping the input and output values with - functions. The mapping functions take two parameters, a `Module` and the - structure being transformed. Alternatively, a `Method` can be - passed. - - Attributes - ---------- - method: Method - The transformed method. - """ - - def __init__( - self, - target: Method, - *, - i_transform: Optional[tuple[MethodLayout, Callable[[TModule, MethodStruct], RecordDict]]] = None, - o_transform: Optional[tuple[MethodLayout, Callable[[TModule, MethodStruct], RecordDict]]] = None, - src_loc: int | SrcLoc = 0 - ): - """ - Parameters - ---------- - target: Method - The target method. - i_transform: (method layout, function or Method), optional - Input mapping function. If specified, it should be a pair of a - function and a input layout for the transformed method. - If not present, input is passed unmodified. - o_transform: (method layout, function or Method), optional - Output mapping function. If specified, it should be a pair of a - function and a output layout for the transformed method. - If not present, output is passed unmodified. - src_loc: int | SrcLoc - How many stack frames deep the source location is taken from. - Alternatively, the source location to use instead of the default. - """ - if i_transform is None: - i_transform = (target.layout_in, lambda _, x: x) - if o_transform is None: - o_transform = (target.layout_out, lambda _, x: x) - - self.target = target - src_loc = get_src_loc(src_loc) - self.method = Method(i=i_transform[0], o=o_transform[0], src_loc=src_loc) - self.i_fun = i_transform[1] - self.o_fun = o_transform[1] - - def elaborate(self, platform): - m = TModule() - - @def_method(m, self.method) - def _(arg): - return self.o_fun(m, self.target(m, self.i_fun(m, arg))) - - return m - - -class MethodFilter(Elaboratable, Transformer): - """Method filter. - - Takes a target method and creates a method which calls the target method - only when some condition is true. The condition function takes two - parameters, a module and the input structure of the method. Non-zero - return value is interpreted as true. Alternatively to using a function, - a `Method` can be passed as a condition. - By default, the target method is locked for use even if it is not called. - If this is not the desired effect, set `use_condition` to True, but this will - cause that the provided method will be `single_caller` and all other `condition` - drawbacks will be in place (e.g. risk of exponential complexity). - - Attributes - ---------- - method: Method - The transformed method. 
- """ - - def __init__( - self, - target: Method, - condition: Callable[[TModule, MethodStruct], ValueLike], - default: Optional[RecordDict] = None, - *, - use_condition: bool = False, - src_loc: int | SrcLoc = 0 - ): - """ - Parameters - ---------- - target: Method - The target method. - condition: function or Method - The condition which, when true, allows the call to `target`. When - false, `default` is returned. - default: Value or dict, optional - The default value returned from the filtered method when the condition - is false. If omitted, zero is returned. - use_condition : bool - Instead of `m.If` use simultaneus `condition` which allow to execute - this filter if the condition is False and target is not ready. - When `use_condition` is true, `condition` must not be a `Method`. - src_loc: int | SrcLoc - How many stack frames deep the source location is taken from. - Alternatively, the source location to use instead of the default. - """ - if default is None: - default = Signal.like(target.data_out) - - self.target = target - self.use_condition = use_condition - src_loc = get_src_loc(src_loc) - self.method = Method(i=target.layout_in, o=target.layout_out, single_caller=self.use_condition, src_loc=src_loc) - self.condition = condition - self.default = default - - assert not (use_condition and isinstance(condition, Method)) - - def elaborate(self, platform): - m = TModule() - - ret = Signal.like(self.target.data_out) - m.d.comb += assign(ret, self.default, fields=AssignType.ALL) - - @def_method(m, self.method) - def _(arg): - if self.use_condition: - cond = Signal() - m.d.top_comb += cond.eq(self.condition(m, arg)) - with condition(m, nonblocking=True) as branch: - with branch(cond): - m.d.comb += ret.eq(self.target(m, arg)) - else: - with m.If(self.condition(m, arg)): - m.d.comb += ret.eq(self.target(m, arg)) - return ret - - return m - - -class MethodProduct(Elaboratable, Unifier): - def __init__( - self, - targets: list[Method], - combiner: Optional[tuple[MethodLayout, Callable[[TModule, list[MethodStruct]], RecordDict]]] = None, - *, - src_loc: int | SrcLoc = 0 - ): - """Method product. - - Takes arbitrary, non-zero number of target methods, and constructs - a method which calls all of the target methods using the same - argument. The return value of the resulting method is, by default, - the return value of the first of the target methods. A combiner - function can be passed, which can compute the return value from - the results of every target method. - - Parameters - ---------- - targets: list[Method] - A list of methods to be called. - combiner: (int or method layout, function), optional - A pair of the output layout and the combiner function. The - combiner function takes two parameters: a `Module` and - a list of outputs of the target methods. - src_loc: int | SrcLoc - How many stack frames deep the source location is taken from. - Alternatively, the source location to use instead of the default. - - Attributes - ---------- - method: Method - The product method. 
- """ - if combiner is None: - combiner = (targets[0].layout_out, lambda _, x: x[0]) - self.targets = targets - self.combiner = combiner - src_loc = get_src_loc(src_loc) - self.method = Method(i=targets[0].layout_in, o=combiner[0], src_loc=src_loc) - - def elaborate(self, platform): - m = TModule() - - @def_method(m, self.method) - def _(arg): - results = [] - for target in self.targets: - results.append(target(m, arg)) - return self.combiner[1](m, results) - - return m - - -class MethodTryProduct(Elaboratable, Unifier): - def __init__( - self, - targets: list[Method], - combiner: Optional[ - tuple[MethodLayout, Callable[[TModule, list[tuple[Value, MethodStruct]]], RecordDict]] - ] = None, - *, - src_loc: int | SrcLoc = 0 - ): - """Method product with optional calling. - - Takes arbitrary, non-zero number of target methods, and constructs - a method which tries to call all of the target methods using the same - argument. The methods which are not ready are not called. The return - value of the resulting method is, by default, empty. A combiner - function can be passed, which can compute the return value from the - results of every target method. - - Parameters - ---------- - targets: list[Method] - A list of methods to be called. - combiner: (int or method layout, function), optional - A pair of the output layout and the combiner function. The - combiner function takes two parameters: a `Module` and - a list of pairs. Each pair contains a bit which signals - that a given call succeeded, and the result of the call. - src_loc: int | SrcLoc - How many stack frames deep the source location is taken from. - Alternatively, the source location to use instead of the default. - - Attributes - ---------- - method: Method - The product method. - """ - if combiner is None: - combiner = ([], lambda _, __: {}) - self.targets = targets - self.combiner = combiner - self.src_loc = get_src_loc(src_loc) - self.method = Method(i=targets[0].layout_in, o=combiner[0], src_loc=self.src_loc) - - def elaborate(self, platform): - m = TModule() - - @def_method(m, self.method) - def _(arg): - results: list[tuple[Value, MethodStruct]] = [] - for target in self.targets: - success = Signal() - with Transaction(src_loc=self.src_loc).body(m): - m.d.comb += success.eq(1) - results.append((success, target(m, arg))) - return self.combiner[1](m, results) - - return m - - -class Collector(Elaboratable, Unifier): - """Single result collector. - - Creates method that collects results of many methods with identical - layouts. Each call of this method will return a single result of one - of the provided methods. - - Attributes - ---------- - method: Method - Method which returns single result of provided methods. - """ - - def __init__(self, targets: list[Method], *, src_loc: int | SrcLoc = 0): - """ - Parameters - ---------- - method_list: list[Method] - List of methods from which results will be collected. - src_loc: int | SrcLoc - How many stack frames deep the source location is taken from. - Alternatively, the source location to use instead of the default. 
- """ - self.method_list = targets - layout = targets[0].layout_out - self.src_loc = get_src_loc(src_loc) - self.method = Method(o=layout, src_loc=self.src_loc) - - for method in targets: - if layout != method.layout_out: - raise Exception("Not all methods have this same layout") - - def elaborate(self, platform): - m = TModule() - - m.submodules.forwarder = forwarder = Forwarder(self.method.layout_out, src_loc=self.src_loc) - - m.submodules.connect = ManyToOneConnectTrans( - get_results=[get for get in self.method_list], put_result=forwarder.write, src_loc=self.src_loc - ) - - self.method.proxy(m, forwarder.read) - - return m - - -class CatTrans(Elaboratable): - """Concatenating transaction. - - Concatenates the results of two methods and passes the result to the - third method. - """ - - def __init__(self, src1: Method, src2: Method, dst: Method): - """ - Parameters - ---------- - src1: Method - First input method. - src2: Method - Second input method. - dst: Method - The method which receives the concatenation of the results of input - methods. - """ - self.src1 = src1 - self.src2 = src2 - self.dst = dst - - def elaborate(self, platform): - m = TModule() - - with Transaction().body(m): - sdata1 = self.src1(m) - sdata2 = self.src2(m) - ddata = Signal.like(self.dst.data_in) - self.dst(m, ddata) - - m.d.comb += ddata.eq(Cat(sdata1, sdata2)) - - return m - - -class ConnectAndMapTrans(Elaboratable): - """Connecting transaction with mapping functions. - - Behaves like `ConnectTrans`, but modifies the transferred data using - functions or `Method`s. Equivalent to a combination of `ConnectTrans` - and `MethodMap`. The mapping functions take two parameters, a `Module` - and the structure being transformed. - """ - - def __init__( - self, - method1: Method, - method2: Method, - *, - i_fun: Optional[Callable[[TModule, MethodStruct], RecordDict]] = None, - o_fun: Optional[Callable[[TModule, MethodStruct], RecordDict]] = None, - src_loc: int | SrcLoc = 0 - ): - """ - Parameters - ---------- - method1: Method - First method. - method2: Method - Second method, and the method being transformed. - i_fun: function or Method, optional - Input transformation (`method1` to `method2`). - o_fun: function or Method, optional - Output transformation (`method2` to `method1`). - src_loc: int | SrcLoc - How many stack frames deep the source location is taken from. - Alternatively, the source location to use instead of the default. 
- """ - self.method1 = method1 - self.method2 = method2 - self.i_fun = i_fun or (lambda _, x: x) - self.o_fun = o_fun or (lambda _, x: x) - self.src_loc = get_src_loc(src_loc) - - def elaborate(self, platform): - m = TModule() - - m.submodules.transformer = transformer = MethodMap( - self.method2, - i_transform=(self.method1.layout_out, self.i_fun), - o_transform=(self.method1.layout_in, self.o_fun), - src_loc=self.src_loc, - ) - m.submodules.connect = ConnectTrans(self.method1, transformer.method) - - return m diff --git a/transactron/profiler.py b/transactron/profiler.py deleted file mode 100644 index fcea59387..000000000 --- a/transactron/profiler.py +++ /dev/null @@ -1,356 +0,0 @@ -import os -from collections import defaultdict -from typing import Optional -from dataclasses import dataclass, field -from dataclasses_json import dataclass_json -from transactron.utils import SrcLoc, IdGenerator -from transactron.core import TransactionManager -from transactron.core.manager import MethodMap - - -__all__ = [ - "ProfileInfo", - "ProfileData", - "RunStat", - "RunStatNode", - "Profile", - "TransactionSamples", - "MethodSamples", - "ProfileSamples", -] - - -@dataclass_json -@dataclass -class ProfileInfo: - """Information about transactions and methods. - - In `Profile`, transactions and methods are referred to by their unique ID - numbers. - - Attributes - ---------- - name : str - The name. - src_loc : SrcLoc - Source location. - is_transaction : bool - If true, this object describes a transaction; if false, a method. - """ - - name: str - src_loc: SrcLoc - is_transaction: bool - - -@dataclass -class ProfileData: - """Information about transactions and methods from the transaction manager. - - This data is required for transaction profile generation in simulators. - Transactions and methods are referred to by their unique ID numbers. - - Attributes - ---------- - transactions_and_methods: dict[int, ProfileInfo] - Information about individual transactions and methods. - method_parents: dict[int, list[int]] - Lists the callers (transactions and methods) for each method. Key is - method ID. - transactions_by_method: dict[int, list[int]] - Lists which transactions are calling each method. Key is method ID. - transaction_conflicts: dict[int, list[int]] - List which other transactions conflict with each transaction. 
- """ - - transactions_and_methods: dict[int, ProfileInfo] - method_parents: dict[int, list[int]] - transactions_by_method: dict[int, list[int]] - transaction_conflicts: dict[int, list[int]] - - @staticmethod - def make(transaction_manager: TransactionManager): - transactions_and_methods = dict[int, ProfileInfo]() - method_parents = dict[int, list[int]]() - transactions_by_method = dict[int, list[int]]() - transaction_conflicts = dict[int, list[int]]() - - method_map = MethodMap(transaction_manager.transactions) - cgr, _ = TransactionManager._conflict_graph(method_map) - get_id = IdGenerator() - - def local_src_loc(src_loc: SrcLoc): - return (os.path.relpath(src_loc[0]), src_loc[1]) - - for transaction in method_map.transactions: - transactions_and_methods[get_id(transaction)] = ProfileInfo( - transaction.owned_name, local_src_loc(transaction.src_loc), True - ) - - for method in method_map.methods: - transactions_and_methods[get_id(method)] = ProfileInfo( - method.owned_name, local_src_loc(method.src_loc), False - ) - method_parents[get_id(method)] = [get_id(t_or_m) for t_or_m in method_map.method_parents[method]] - transactions_by_method[get_id(method)] = [ - get_id(t_or_m) for t_or_m in method_map.transactions_by_method[method] - ] - - for transaction, transactions in cgr.items(): - transaction_conflicts[get_id(transaction)] = [get_id(transaction2) for transaction2 in transactions] - - return ( - ProfileData(transactions_and_methods, method_parents, transactions_by_method, transaction_conflicts), - get_id, - ) - - -@dataclass -class RunStat: - """Collected statistics about a transaction or method. - - Attributes - ---------- - name : str - The name. - src_loc : SrcLoc - Source location. - locked : int - For methods: the number of cycles this method was locked because of - a disabled call (a call under a false condition). For transactions: - the number of cycles this transaction was ready to run, but did not - run because a conflicting transaction has run instead. - """ - - name: str - src_loc: str - locked: int = 0 - run: int = 0 - - @staticmethod - def make(info: ProfileInfo): - return RunStat(info.name, f"{info.src_loc[0]}:{info.src_loc[1]}") - - -@dataclass -class RunStatNode: - """A statistics tree. Summarizes call graph information. - - Attributes - ---------- - stat : RunStat - Statistics. - callers : dict[int, RunStatNode] - Statistics for the method callers. For transactions, this is empty. - """ - - stat: RunStat - callers: dict[int, "RunStatNode"] = field(default_factory=dict) - - @staticmethod - def make(info: ProfileInfo): - return RunStatNode(RunStat.make(info)) - - -@dataclass -class TransactionSamples: - """Runtime value of transaction control signals in a given clock cycle. - - Attributes - ---------- - request: bool - The value of the transaction's ``request`` signal. - runnable: bool - The value of the transaction's ``runnable`` signal. - grant: bool - The value of the transaction's ``grant`` signal. - """ - - request: bool - runnable: bool - grant: bool - - -@dataclass -class MethodSamples: - """Runtime value of method control signals in a given clock cycle. - - Attributes - ---------- - run: bool - The value of the method's ``run`` signal. - """ - - run: bool - - -@dataclass -class ProfileSamples: - """Runtime values of all transaction and method control signals. - - Attributes - ---------- - transactions: dict[int, TransactionSamples] - Runtime values of transaction control signals for each transaction. 
- methods: dict[int, MethodSamples] - Runtime values of method control signals for each method. - """ - - transactions: dict[int, TransactionSamples] = field(default_factory=dict) - methods: dict[int, MethodSamples] = field(default_factory=dict) - - -@dataclass_json -@dataclass -class CycleProfile: - """Profile information for a single clock cycle. - - Transactions and methods are referred to by unique IDs. - - Attributes - ---------- - locked : dict[int, int] - For each transaction which didn't run because of a conflict, the - transaction which has run instead. For each method which was used - but didn't run because of a disabled call, the caller which - used it. - running : dict[int, Optional[int]] - For each running method, its caller. Running transactions don't - have a caller (the value is `None`). - """ - - locked: dict[int, int] = field(default_factory=dict) - running: dict[int, Optional[int]] = field(default_factory=dict) - - @staticmethod - def make(samples: ProfileSamples, data: ProfileData): - cprof = CycleProfile() - - for transaction_id, transaction_samples in samples.transactions.items(): - if transaction_samples.grant: - cprof.running[transaction_id] = None - elif transaction_samples.request and transaction_samples.runnable: - for transaction2_id in data.transaction_conflicts[transaction_id]: - if samples.transactions[transaction2_id].grant: - cprof.locked[transaction_id] = transaction2_id - - running = set(cprof.running) - for method_id, method_samples in samples.methods.items(): - if method_samples.run: - running.add(method_id) - - locked_methods = set[int]() - for method_id in samples.methods.keys(): - if method_id not in running: - if any(transaction_id in running for transaction_id in data.transactions_by_method[method_id]): - locked_methods.add(method_id) - - for method_id in samples.methods.keys(): - if method_id in running: - for t_or_m_id in data.method_parents[method_id]: - if t_or_m_id in running: - cprof.running[method_id] = t_or_m_id - elif method_id in locked_methods: - caller = next( - t_or_m_id - for t_or_m_id in data.method_parents[method_id] - if t_or_m_id in running or t_or_m_id in locked_methods - ) - cprof.locked[method_id] = caller - - return cprof - - -@dataclass_json -@dataclass -class Profile: - """Transactron execution profile. - - Can be saved by the simulator, and then restored by an analysis tool. - In the profile data structure, methods and transactions are referred to - by their unique ID numbers. - - Attributes - ---------- - transactions_and_methods : dict[int, ProfileInfo] - Information about transactions and methods indexed by ID numbers. - cycles : list[CycleProfile] - Profile information for each cycle of the simulation. 
- """ - - transactions_and_methods: dict[int, ProfileInfo] = field(default_factory=dict) - cycles: list[CycleProfile] = field(default_factory=list) - - def encode(self, file_name: str): - with open(file_name, "w") as fp: - fp.write(self.to_json()) # type: ignore - - @staticmethod - def decode(file_name: str) -> "Profile": - with open(file_name, "r") as fp: - return Profile.from_json(fp.read()) # type: ignore - - def analyze_transactions(self, recursive=False) -> list[RunStatNode]: - stats = {i: RunStatNode.make(info) for i, info in self.transactions_and_methods.items() if info.is_transaction} - - def rec(c: CycleProfile, node: RunStatNode, i: int): - if i in c.running: - node.stat.run += 1 - elif i in c.locked: - node.stat.locked += 1 - if recursive: - for j in called[i]: - if j not in node.callers: - node.callers[j] = RunStatNode.make(self.transactions_and_methods[j]) - rec(c, node.callers[j], j) - - for c in self.cycles: - called = defaultdict[int, set[int]](set) - - for i, j in c.running.items(): - if j is not None: - called[j].add(i) - - for i, j in c.locked.items(): - called[j].add(i) - - for i in c.running: - if i in stats: - rec(c, stats[i], i) - - for i in c.locked: - if i in stats: - stats[i].stat.locked += 1 - - return list(stats.values()) - - def analyze_methods(self, recursive=False) -> list[RunStatNode]: - stats = { - i: RunStatNode.make(info) for i, info in self.transactions_and_methods.items() if not info.is_transaction - } - - def rec(c: CycleProfile, node: RunStatNode, i: int, locking_call=False): - if i in c.running: - if not locking_call: - node.stat.run += 1 - else: - node.stat.locked += 1 - caller = c.running[i] - else: - node.stat.locked += 1 - caller = c.locked[i] - if recursive and caller is not None: - if caller not in node.callers: - node.callers[caller] = RunStatNode.make(self.transactions_and_methods[caller]) - rec(c, node.callers[caller], caller, locking_call) - - for c in self.cycles: - for i in c.running: - if i in stats: - rec(c, stats[i], i) - - for i in c.locked: - if i in stats: - rec(c, stats[i], i, locking_call=True) - - return list(stats.values()) diff --git a/transactron/testing/__init__.py b/transactron/testing/__init__.py deleted file mode 100644 index 8c4940038..000000000 --- a/transactron/testing/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -from amaranth.sim._async import TestbenchContext, ProcessContext, SimulatorContext # noqa: F401 -from .input_generation import * # noqa: F401 -from .functions import * # noqa: F401 -from .infrastructure import * # noqa: F401 -from .method_mock import * # noqa: F401 -from .testbenchio import * # noqa: F401 -from .profiler import * # noqa: F401 -from .logging import * # noqa: F401 -from .tick_count import * # noqa: F401 -from transactron.utils import data_layout # noqa: F401 diff --git a/transactron/testing/functions.py b/transactron/testing/functions.py deleted file mode 100644 index ee1225154..000000000 --- a/transactron/testing/functions.py +++ /dev/null @@ -1,15 +0,0 @@ -import amaranth.lib.data as data -from typing import TypeAlias - - -MethodData: TypeAlias = "data.Const[data.StructLayout]" - - -def data_const_to_dict(c: "data.Const[data.Layout]"): - ret = {} - for k, _ in c.shape(): - v = c[k] - if isinstance(v, data.Const): - v = data_const_to_dict(v) - ret[k] = v - return ret diff --git a/transactron/testing/infrastructure.py b/transactron/testing/infrastructure.py deleted file mode 100644 index c41829c80..000000000 --- a/transactron/testing/infrastructure.py +++ /dev/null @@ -1,337 +0,0 @@ -import sys 
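
For context, a minimal sketch of how the `Profile` analysis API removed above was meant to be consumed, assuming the module keeps its pre-removal import path and that a profile was saved earlier with `Profile.encode` under the illustrative name `profile.json`:

```python
# Minimal sketch: load a saved execution profile and print per-transaction
# statistics using the Profile API deleted above; the file name is assumed.
from transactron.profiler import Profile

profile = Profile.decode("profile.json")
for node in profile.analyze_transactions():
    stat = node.stat
    print(f"{stat.name} ({stat.src_loc}): ran {stat.run}, locked {stat.locked}")
```
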
-import pytest -import logging -import os -import random -import functools -from contextlib import contextmanager, nullcontext -from collections.abc import Callable -from typing import TypeVar, Generic, Type, TypeGuard, Any, cast, TypeAlias, Optional -from amaranth import * -from amaranth.sim import * -from amaranth.sim._async import SimulatorContext - -from transactron.utils.dependencies import DependencyContext, DependencyManager -from .testbenchio import TestbenchIO -from .profiler import profiler_process, Profile -from .logging import make_logging_process, parse_logging_level, _LogFormatter -from .tick_count import make_tick_count_process -from .method_mock import MethodMock -from transactron import Method -from transactron.lib import AdapterTrans -from transactron.core.keys import TransactionManagerKey -from transactron.core import TransactionModule -from transactron.utils import ModuleConnector, HasElaborate, auto_debug_signals, HasDebugSignals - - -__all__ = ["SimpleTestCircuit", "PysimSimulator", "TestCaseWithSimulator"] - - -T = TypeVar("T") -_T_nested_collection: TypeAlias = T | list["_T_nested_collection[T]"] | dict[str, "_T_nested_collection[T]"] - - -def guard_nested_collection(cont: Any, t: Type[T]) -> TypeGuard[_T_nested_collection[T]]: - if isinstance(cont, (list, dict)): - if isinstance(cont, dict): - cont = cont.values() - return all([guard_nested_collection(elem, t) for elem in cont]) - elif isinstance(cont, t): - return True - else: - return False - - -_T_HasElaborate = TypeVar("_T_HasElaborate", bound=HasElaborate) - - -class SimpleTestCircuit(Elaboratable, Generic[_T_HasElaborate]): - def __init__(self, dut: _T_HasElaborate): - self._dut = dut - self._io: dict[str, _T_nested_collection[TestbenchIO]] = {} - - def __getattr__(self, name: str) -> Any: - try: - return self._io[name] - except KeyError: - raise AttributeError(f"No mock for '{name}'") - - def elaborate(self, platform): - def transform_methods_to_testbenchios( - container: _T_nested_collection[Method], - ) -> tuple[ - _T_nested_collection["TestbenchIO"], - "ModuleConnector | TestbenchIO", - ]: - if isinstance(container, list): - tb_list = [] - mc_list = [] - for elem in container: - tb, mc = transform_methods_to_testbenchios(elem) - tb_list.append(tb) - mc_list.append(mc) - return tb_list, ModuleConnector(*mc_list) - elif isinstance(container, dict): - tb_dict = {} - mc_dict = {} - for name, elem in container.items(): - tb, mc = transform_methods_to_testbenchios(elem) - tb_dict[name] = tb - mc_dict[name] = mc - return tb_dict, ModuleConnector(*mc_dict) - else: - tb = TestbenchIO(AdapterTrans(container)) - return tb, tb - - m = Module() - - m.submodules.dut = self._dut - - for name, attr in vars(self._dut).items(): - if guard_nested_collection(attr, Method) and attr: - tb_cont, mc = transform_methods_to_testbenchios(attr) - self._io[name] = tb_cont - m.submodules[name] = mc - - return m - - def debug_signals(self): - sigs = {"_dut": auto_debug_signals(self._dut)} - for name, io in self._io.items(): - sigs[name] = auto_debug_signals(io) - return sigs - - -class _TestModule(Elaboratable): - def __init__(self, tested_module: HasElaborate, add_transaction_module: bool): - self.tested_module = ( - TransactionModule(tested_module, dependency_manager=DependencyContext.get()) - if add_transaction_module - else tested_module - ) - self.add_transaction_module = add_transaction_module - - def elaborate(self, platform) -> HasElaborate: - m = Module() - - # so that Amaranth allows us to use add_clock - _dummy = Signal() - 
m.d.sync += _dummy.eq(1) - - m.submodules.tested_module = self.tested_module - - m.domains.sync_neg = ClockDomain(clk_edge="neg", local=True) - - return m - - -class PysimSimulator(Simulator): - def __init__( - self, - module: HasElaborate, - max_cycles: float = 10e4, - add_transaction_module=True, - traces_file=None, - clk_period=1e-6, - ): - test_module = _TestModule(module, add_transaction_module) - self.tested_module = tested_module = test_module.tested_module - super().__init__(test_module) - - self.add_clock(clk_period) - self.add_clock(clk_period, domain="sync_neg") - - if isinstance(tested_module, HasDebugSignals): - extra_signals = tested_module.debug_signals - else: - extra_signals = functools.partial(auto_debug_signals, tested_module) - - if traces_file: - traces_dir = "test/__traces__" - os.makedirs(traces_dir, exist_ok=True) - # Signal handling is hacky and accesses Simulator internals. - # TODO: try to merge with Amaranth. - if isinstance(extra_signals, Callable): - extra_signals = extra_signals() - clocks = [d.clk for d in cast(Any, self)._design.fragment.domains.values()] - - self.ctx = self.write_vcd( - f"{traces_dir}/{traces_file}.vcd", - f"{traces_dir}/{traces_file}.gtkw", - traces=[clocks, extra_signals], - ) - else: - self.ctx = nullcontext() - - self.timeouted = False - - async def timeout_testbench(sim: SimulatorContext): - await sim.delay(clk_period * max_cycles) - self.timeouted = True - - self.add_testbench(timeout_testbench, background=True) - - def run(self) -> bool: - with self.ctx: - super().run() - - return not self.timeouted - - -class TestCaseWithSimulator: - dependency_manager: DependencyManager - - @contextmanager - def _configure_dependency_context(self): - self.dependency_manager = DependencyManager() - with DependencyContext(self.dependency_manager): - yield Tick() - - def add_mock(self, sim: PysimSimulator, val: MethodMock): - sim.add_process(val.output_process) - if val.validate_arguments is not None: - sim.add_process(val.validate_arguments_process) - sim.add_testbench(val.effect_process, background=True) - - def _add_class_mocks(self, sim: PysimSimulator) -> None: - for key in dir(self): - val = getattr(self, key) - if hasattr(val, "_transactron_testing_process"): - sim.add_process(val) - elif hasattr(val, "_transactron_method_mock"): - self.add_mock(sim, val()) - - def _add_local_mocks(self, sim: PysimSimulator, frame_locals: dict) -> None: - for key, val in frame_locals.items(): - if hasattr(val, "_transactron_testing_process"): - sim.add_process(val) - elif hasattr(val, "_transactron_method_mock"): - self.add_mock(sim, val()) - - def _add_all_mocks(self, sim: PysimSimulator, frame_locals: dict) -> None: - self._add_class_mocks(sim) - self._add_local_mocks(sim, frame_locals) - - def _configure_traces(self): - traces_file = None - if "__TRANSACTRON_DUMP_TRACES" in os.environ: - traces_file = self._transactron_current_output_file_name - self._transactron_infrastructure_traces_file = traces_file - - @contextmanager - def _configure_profiles(self): - profile = None - if "__TRANSACTRON_PROFILE" in os.environ: - - def f(): - nonlocal profile - try: - transaction_manager = DependencyContext.get().get_dependency(TransactionManagerKey()) - profile = Profile() - return profiler_process(transaction_manager, profile) - except KeyError: - pass - return None - - self._transactron_sim_processes_to_add.append(f) - - yield - - if profile is not None: - profile_dir = "test/__profiles__" - profile_file = self._transactron_current_output_file_name - 
os.makedirs(profile_dir, exist_ok=True)
-            profile.encode(f"{profile_dir}/{profile_file}.json")
-
-    @contextmanager
-    def _configure_logging(self):
-        def on_error():
-            assert False, "Simulation finished due to an error"
-
-        log_level = parse_logging_level(os.environ["__TRANSACTRON_LOG_LEVEL"])
-        log_filter = os.environ["__TRANSACTRON_LOG_FILTER"]
-        self._transactron_sim_processes_to_add.append(lambda: make_logging_process(log_level, log_filter, on_error))
-
-        ch = logging.StreamHandler()
-        formatter = _LogFormatter()
-        ch.setFormatter(formatter)
-
-        root_logger = logging.getLogger()
-        handlers_before = root_logger.handlers.copy()
-        root_logger.handlers.append(ch)
-        yield
-        root_logger.handlers = handlers_before
-
-    @contextmanager
-    def reinitialize_fixtures(self):
-        # File name to be used in the current test run (either a standard run or a hypothesis iteration).
-        # For standard tests it will always have the suffix "_0". For hypothesis tests, it will be suffixed
-        # with the current hypothesis iteration number, so that each hypothesis run is saved to a
-        # different file.
-        self._transactron_current_output_file_name = (
-            self._transactron_base_output_file_name + "_" + str(self._transactron_hypothesis_iter_counter)
-        )
-        self._transactron_sim_processes_to_add: list[Callable[[], Optional[Callable]]] = []
-        with self._configure_dependency_context():
-            self._configure_traces()
-            with self._configure_profiles():
-                with self._configure_logging():
-                    self._transactron_sim_processes_to_add.append(make_tick_count_process)
-                    yield
-        self._transactron_hypothesis_iter_counter += 1
-
-    @pytest.fixture(autouse=True)
-    def fixture_initialize_testing_env(self, request):
-        # Hypothesis creates a single instance of a test class, which is later reused multiple times.
-        # This means that pytest fixtures are only run once. We can take advantage of this behaviour and
-        # initialise hypothesis-related variables.
-
-        # The counter for distinguishing between successive hypothesis iterations; it is incremented
-        # by `reinitialize_fixtures`, which should be started at the beginning of each hypothesis run.
-        self._transactron_hypothesis_iter_counter = 0
-        # Base name which will be used later to create file names for particular outputs.
-        self._transactron_base_output_file_name = ".".join(request.node.nodeid.split("/"))
-        with self.reinitialize_fixtures():
-            yield
-
-    @contextmanager
-    def run_simulation(self, module: HasElaborate, max_cycles: float = 10e4, add_transaction_module=True):
-        clk_period = 1e-6
-        sim = PysimSimulator(
-            module,
-            max_cycles=max_cycles,
-            add_transaction_module=add_transaction_module,
-            traces_file=self._transactron_infrastructure_traces_file,
-            clk_period=clk_period,
-        )
-        self._add_all_mocks(sim, sys._getframe(2).f_locals)
-
-        yield sim
-
-        for f in self._transactron_sim_processes_to_add:
-            ret = f()
-            if ret is not None:
-                sim.add_process(ret)
-
-        res = sim.run()
-        assert res, "Simulation time limit exceeded"
-
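
A minimal sketch of a test written against this harness; the `ExampleUnit` DUT and its `add` method are illustrative, and the import paths assume the pre-removal layout:

```python
from amaranth import Elaboratable
from transactron import TModule, Method, def_method
from transactron.testing import SimpleTestCircuit, TestCaseWithSimulator


class ExampleUnit(Elaboratable):  # illustrative DUT, not part of this diff
    def __init__(self):
        self.add = Method(i=[("a", 8), ("b", 8)], o=[("sum", 8)])

    def elaborate(self, platform):
        m = TModule()

        @def_method(m, self.add)
        def _(a, b):
            return {"sum": a + b}

        return m


class TestExampleUnit(TestCaseWithSimulator):
    def test_add(self):
        circ = SimpleTestCircuit(ExampleUnit())

        async def testbench(sim):
            # SimpleTestCircuit wrapped the `add` method in a TestbenchIO mock.
            assert (await circ.add.call(sim, a=2, b=3))["sum"] == 5

        with self.run_simulation(circ) as sim:
            sim.add_testbench(testbench)
```
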
-    async def tick(self, sim: SimulatorContext, cycle_cnt: int = 1):
-        """
-        Waits for the given number of cycles.
-        """
-        for _ in range(cycle_cnt):
-            await sim.tick()
-
-    async def random_wait(self, sim: SimulatorContext, max_cycle_cnt: int, *, min_cycle_cnt: int = 0):
-        """
-        Waits for a random number of cycles in the range [min_cycle_cnt, max_cycle_cnt].
-        """
-        await self.tick(sim, random.randrange(min_cycle_cnt, max_cycle_cnt + 1))
-
-    async def random_wait_geom(self, sim: SimulatorContext, prob: float = 0.5):
-        """
-        Waits until the first success, where each cycle succeeds with probability `prob`.
-        """
-        while random.random() > prob:
-            await sim.tick()
diff --git a/transactron/testing/input_generation.py b/transactron/testing/input_generation.py
deleted file mode 100644
index 909da7a43..000000000
--- a/transactron/testing/input_generation.py
+++ /dev/null
@@ -1,97 +0,0 @@
-from amaranth import *
-from amaranth.lib.data import StructLayout
-from typing import TypeVar
-import hypothesis.strategies as st
-from hypothesis.strategies import composite, DrawFn, integers, SearchStrategy
-from transactron.utils import MethodLayout, RecordIntDict
-
-
-class OpNOP:
-    def __repr__(self):
-        return "OpNOP()"
-
-
-T = TypeVar("T")
-
-
-@composite
-def generate_shrinkable_list(draw: DrawFn, length: int, generator: SearchStrategy[T]) -> list[T]:
-    """
-    Trick based on https://github.com/HypothesisWorks/hypothesis/blob/
-    6867da71beae0e4ed004b54b92ef7c74d0722815/hypothesis-python/src/hypothesis/stateful.py#L143
-    """
-    hp_data = draw(st.data())
-    lst = []
-    if length == 0:
-        return lst
-    i = 0
-    force_val = None
-    while True:
-        b = hp_data.conjecture_data.draw_boolean(p=2**-16, forced=force_val)
-        if b:
-            break
-        lst.append(draw(generator))
-        i += 1
-        if i == length:
-            force_val = True
-    return lst
-
-
-@composite
-def generate_based_on_layout(draw: DrawFn, layout: MethodLayout) -> RecordIntDict:
-    if isinstance(layout, StructLayout):
-        raise NotImplementedError("StructLayout is not supported in automatic value generation.")
-    d = {}
-    for name, sublayout in layout:
-        if isinstance(sublayout, list):
-            elem = draw(generate_based_on_layout(sublayout))
-        elif isinstance(sublayout, int):
-            # The field is `sublayout` bits wide, so draw a value that fits in it.
-            elem = draw(integers(min_value=0, max_value=2**sublayout - 1))
-        elif isinstance(sublayout, range):
-            elem = draw(integers(min_value=sublayout.start, max_value=sublayout.stop - 1))
-        elif isinstance(sublayout, Shape):
-            if sublayout.signed:
-                min_value = -(2 ** (sublayout.width - 1))
-                max_value = 2 ** (sublayout.width - 1) - 1
-            else:
-                min_value = 0
-                max_value = 2**sublayout.width - 1
-            elem = draw(integers(min_value=min_value, max_value=max_value))
-        else:
-            # Currently type[Enum] and ShapeCastable
-            raise NotImplementedError("Passed LayoutList with syntax not yet supported in automatic value generation.")
-        d[name] = elem
-    return d
-
-
-def insert_nops(draw: DrawFn, max_nops: int, lst: list):
-    nops_nr = draw(integers(min_value=0, max_value=max_nops))
-    for i in range(nops_nr):
-        lst.append(OpNOP())
-    return lst
-
-
-@composite
-def generate_nops_in_list(draw: DrawFn, max_nops: int, generate_list: SearchStrategy[list[T]]) -> list[T | OpNOP]:
-    lst = draw(generate_list)
-    out_lst = []
-    out_lst = insert_nops(draw, max_nops, out_lst)
-    for i in lst:
-        out_lst.append(i)
-        out_lst = insert_nops(draw, max_nops, out_lst)
-    return out_lst
-
-
-@composite
-def generate_method_input(draw: DrawFn, args: list[tuple[str, MethodLayout]]) -> dict[str, RecordIntDict]:
-    out = []
-    for name, layout in args:
-        out.append((name, draw(generate_based_on_layout(layout))))
-    return dict(out)
-
-
-@composite
-def generate_process_input(
-    draw: DrawFn, 
elem_count: int, max_nops: int, layouts: list[tuple[str, MethodLayout]] -) -> list[dict[str, RecordIntDict] | OpNOP]: - return draw(generate_nops_in_list(max_nops, generate_shrinkable_list(elem_count, generate_method_input(layouts)))) diff --git a/transactron/testing/logging.py b/transactron/testing/logging.py deleted file mode 100644 index 449a43ced..000000000 --- a/transactron/testing/logging.py +++ /dev/null @@ -1,109 +0,0 @@ -from collections.abc import Callable, Iterable -from typing import Any -import logging -import itertools - -from amaranth.sim._async import ProcessContext -from transactron.lib import logging as tlog -from transactron.utils.dependencies import DependencyContext -from .tick_count import TicksKey - - -__all__ = ["make_logging_process", "parse_logging_level"] - - -def parse_logging_level(str: str) -> tlog.LogLevel: - """Parse the log level from a string. - - The level can be either a non-negative integer or a string representation - of one of the predefined levels. - - Raises an exception if the level cannot be parsed. - """ - str = str.upper() - names_mapping = logging.getLevelNamesMapping() - if str in names_mapping: - return names_mapping[str] - - # try convert to int - try: - return int(str) - except ValueError: - pass - - raise ValueError("Log level must be either {error, warn, info, debug} or a non-negative integer.") - - -_sim_cycle: int = 0 - - -class _LogFormatter(logging.Formatter): - """ - Log formatter to provide colors and to inject simulator times into - the log messages. Adapted from https://stackoverflow.com/a/56944256/3638629 - """ - - magenta = "\033[0;35m" - grey = "\033[0;34m" - blue = "\033[0;34m" - yellow = "\033[0;33m" - red = "\033[0;31m" - reset = "\033[0m" - - loglevel2colour = { - logging.DEBUG: grey + "{}" + reset, - logging.INFO: magenta + "{}" + reset, - logging.WARNING: yellow + "{}" + reset, - logging.ERROR: red + "{}" + reset, - } - - def format(self, record: logging.LogRecord): - level_name = self.loglevel2colour[record.levelno].format(record.levelname) - return f"{_sim_cycle} {level_name} {record.name} {record.getMessage()}" - - -def make_logging_process(level: tlog.LogLevel, namespace_regexp: str, on_error: Callable[[], Any]): - combined_trigger = tlog.get_trigger_bit(level, namespace_regexp) - records = tlog.get_log_records(level, namespace_regexp) - - root_logger = logging.getLogger() - - def handle_logs(record_vals: Iterable[int]) -> None: - it = iter(record_vals) - - for record in records: - trigger = next(it) - values = [next(it) for _ in record.fields] - - if not trigger: - continue - - formatted_msg = record.format(*values) - - logger = root_logger.getChild(record.logger_name) - logger.log( - record.level, - "[%s:%d] %s", - record.location[0], - record.location[1], - formatted_msg, - ) - - if record.level >= logging.ERROR: - on_error() - - async def log_process(sim: ProcessContext) -> None: - global _sim_cycle - ticks = DependencyContext.get().get_dependency(TicksKey()) - - async for _, _, ticks_val, combined_trigger_val, *record_vals in ( - sim.tick() - .sample(ticks, combined_trigger) - .sample(*itertools.chain(*([record.trigger] + record.fields for record in records))) - ): - if not combined_trigger_val: - continue - _sim_cycle = ticks_val - handle_logs(record_vals) - - return log_process diff --git a/transactron/testing/method_mock.py b/transactron/testing/method_mock.py deleted file mode 100644 index 9587ae19f..000000000 --- a/transactron/testing/method_mock.py +++ /dev/null @@ -1,175 +0,0 @@ -from contextlib import 
contextmanager -import functools -from typing import Callable, Any, Optional - -from amaranth.sim._async import SimulatorContext -from transactron.lib.adapters import Adapter -from transactron.utils.transactron_helpers import async_mock_def_helper -from .testbenchio import TestbenchIO -from transactron.utils._typing import RecordIntDict - - -__all__ = ["MethodMock", "def_method_mock"] - - -class MethodMock: - def __init__( - self, - adapter: Adapter, - function: Callable[..., Optional[RecordIntDict]], - *, - validate_arguments: Optional[Callable[..., bool]] = None, - enable: Callable[[], bool] = lambda: True, - delay: float = 0, - ): - self.adapter = adapter - self.function = function - self.validate_arguments = validate_arguments - self.enable = enable - self.delay = delay - self._effects: list[Callable[[], None]] = [] - self._freeze = False - - _current_mock: Optional["MethodMock"] = None - - @staticmethod - def effect(effect: Callable[[], None]): - assert MethodMock._current_mock is not None - MethodMock._current_mock._effects.append(effect) - - @contextmanager - def _context(self): - assert MethodMock._current_mock is None - MethodMock._current_mock = self - try: - yield - finally: - MethodMock._current_mock = None - - async def output_process( - self, - sim: SimulatorContext, - ) -> None: - sync = sim._design.lookup_domain("sync", None) # type: ignore - async for *_, done, arg, clk in sim.changed(self.adapter.done, self.adapter.data_out).edge(sync.clk, 1): - if clk: - self._freeze = True - if not done or self._freeze: - continue - self._effects = [] - with self._context(): - ret = async_mock_def_helper(self, self.function, arg) - sim.set(self.adapter.data_in, ret) - - async def validate_arguments_process(self, sim: SimulatorContext) -> None: - assert self.validate_arguments is not None - sync = sim._design.lookup_domain("sync", None) # type: ignore - async for *args, clk, _ in ( - sim.changed(*(a for a, _ in self.adapter.validators)).edge(sync.clk, 1).edge(self.adapter.en, 1) - ): - assert len(args) == len(self.adapter.validators) # TODO: remove later - if clk: - self._freeze = True - if self._freeze: - continue - for arg, r in zip(args, (r for _, r in self.adapter.validators)): - sim.set(r, async_mock_def_helper(self, self.validate_arguments, arg)) - - async def effect_process(self, sim: SimulatorContext) -> None: - sim.set(self.adapter.en, self.enable()) - async for *_, done in sim.tick().sample(self.adapter.done): - # Disabling the method on each cycle forces an edge when it is reenabled again. - # The method body won't be executed until the effects are done. - sim.set(self.adapter.en, False) - - # First, perform pending effects, updating internal state. - with sim.critical(): - if done: - for eff in self._effects: - eff() - - # Ensure that the effects of all mocks are applied. Delay 0 also does this! - await sim.delay(self.delay) - - # Next, enable the method. The output will be updated by a combinational process. - self._effects = [] - self._freeze = False - sim.set(self.adapter.en, self.enable()) - - -def def_method_mock( - tb_getter: Callable[[], TestbenchIO] | Callable[[Any], TestbenchIO], **kwargs -) -> Callable[[Callable[..., Optional[RecordIntDict]]], Callable[[], MethodMock]]: - """ - Decorator function to create method mock handlers. It should be applied on - a function which describes functionality which we want to invoke on method call. - This function will be called on every clock cycle when the method is active, - and also on combinational changes to inputs. 
- - The decorated function can have a single argument `arg`, which receives - the arguments passed to a method as a `data.Const`, or multiple named arguments, - which correspond to named arguments of the method. - - This decorator can be applied to function definitions or method definitions. - When applied to a method definition, lambdas passed to `def_method_mock` - need to take a `self` argument, which should be the first. - - Mocks defined at class level or at test level are automatically discovered and - don't need to be manually added to the simulation. - - Any side effects (state modification, assertions, etc.) need to be guarded - using the `MethodMock.effect` decorator. - - Make sure to defer accessing state, since decorators are evaluated eagerly - during function declaration. - - Parameters - ---------- - tb_getter : Callable[[], TestbenchIO] | Callable[[Any], TestbenchIO] - Function to get the TestbenchIO of the mocked method. - enable : Callable[[], bool] | Callable[[Any], bool] - Function which decides if the method is enabled in a given clock cycle. - validate_arguments : Callable[..., bool] - Function which validates call arguments. This applies only to Adapters - with `with_validate_arguments` set to True. - delay : float - Simulation time delay for method mock calling. Used for synchronization - between different mocks and testbench processes. - - Example - ------- - ``` - @def_method_mock(lambda: m.target[k]) - def process(arg): - return {"data": arg["data"] + k} - ``` - or for class methods - ``` - @def_method_mock(lambda self: self.target[k]) - def process(self, data): - return {"data": data + k} - ``` - """ - - def decorator(func: Callable[..., Optional[RecordIntDict]]) -> Callable[[], MethodMock]: - @functools.wraps(func) - def mock(func_self=None, /) -> MethodMock: - f = func - getter: Any = tb_getter - kw = kwargs - if func_self is not None: - getter = getter.__get__(func_self) - f = f.__get__(func_self) - kw = {} - for k, v in kwargs.items(): - bind = getattr(v, "__get__", None) - kw[k] = bind(func_self) if bind else v - tb = getter() - assert isinstance(tb, TestbenchIO) - assert isinstance(tb.adapter, Adapter) - return MethodMock(tb.adapter, f, **kw) - - mock._transactron_method_mock = 1 # type: ignore - return mock - - return decorator diff --git a/transactron/testing/profiler.py b/transactron/testing/profiler.py deleted file mode 100644 index ace2b6327..000000000 --- a/transactron/testing/profiler.py +++ /dev/null @@ -1,46 +0,0 @@ -from amaranth import Cat -from amaranth.lib.data import StructLayout, View -from amaranth.sim._async import ProcessContext -from transactron.core import TransactionManager -from transactron.core.manager import MethodMap -from transactron.profiler import CycleProfile, MethodSamples, Profile, ProfileData, ProfileSamples, TransactionSamples - -__all__ = ["profiler_process"] - - -def profiler_process(transaction_manager: TransactionManager, profile: Profile): - async def process(sim: ProcessContext) -> None: - profile_data, get_id = ProfileData.make(transaction_manager) - method_map = MethodMap(transaction_manager.transactions) - profile.transactions_and_methods = profile_data.transactions_and_methods - - transaction_sample_layout = StructLayout({"request": 1, "runnable": 1, "grant": 1}) - - async for _, _, *data in ( - sim.tick() - .sample( - *( - View(transaction_sample_layout, Cat(transaction.request, transaction.runnable, transaction.grant)) - for transaction in method_map.transactions - ) - ) - .sample(*(method.run for method in 
method_map.methods)) - ): - transaction_data = data[: len(method_map.transactions)] - method_data = data[len(method_map.transactions) :] - samples = ProfileSamples() - - for transaction, tsample in zip(method_map.transactions, transaction_data): - samples.transactions[get_id(transaction)] = TransactionSamples( - bool(tsample.request), - bool(tsample.runnable), - bool(tsample.grant), - ) - - for method, run in zip(method_map.methods, method_data): - samples.methods[get_id(method)] = MethodSamples(bool(run)) - - cprof = CycleProfile.make(samples, profile_data) - profile.cycles.append(cprof) - - return process diff --git a/transactron/testing/testbenchio.py b/transactron/testing/testbenchio.py deleted file mode 100644 index 05531842c..000000000 --- a/transactron/testing/testbenchio.py +++ /dev/null @@ -1,206 +0,0 @@ -from collections.abc import Generator, Iterable -from amaranth import * -from amaranth.lib.data import View, StructLayout -from amaranth.sim._async import SimulatorContext, TestbenchContext -from typing import Any, Optional - -from transactron.lib import AdapterBase -from transactron.utils import ValueLike -from .functions import MethodData - - -__all__ = ["CallTrigger", "TestbenchIO"] - - -class CallTrigger: - """A trigger which allows to call multiple methods and sample signals. - - The `call()` and `call_try()` methods on a `TestbenchIO` always wait at least one clock cycle. It follows - that these methods can't be used to perform calls to multiple methods in a single clock cycle. Usually - this is not a problem, as different methods can be called from different simulation processes. But in cases - when more control over the time when different calls happen is needed, this trigger class allows to call - many methods in a single clock cycle. - """ - - def __init__( - self, - sim: SimulatorContext, - _calls: Iterable[ValueLike | tuple["TestbenchIO", Optional[dict[str, Any]]]] = (), - ): - """ - Parameters - ---------- - sim: SimulatorContext - Amaranth simulator context. - """ - self.sim = sim - self.calls_and_values: list[ValueLike | tuple[TestbenchIO, Optional[dict[str, Any]]]] = list(_calls) - - def sample(self, *values: "ValueLike | TestbenchIO"): - """Sample a signal or a method result on a clock edge. - - Values are sampled like in standard Amaranth `TickTrigger`. Sampling a method result works like `call()`, - but the method is not called - another process can do that instead. If the method was not called, the - sampled value is `None`. - - Parameters - ---------- - *values: ValueLike | TestbenchIO - Value or method to sample. - """ - new_calls_and_values: list[ValueLike | tuple["TestbenchIO", None]] = [] - for value in values: - if isinstance(value, TestbenchIO): - new_calls_and_values.append((value, None)) - else: - new_calls_and_values.append(value) - return CallTrigger(self.sim, (*self.calls_and_values, *new_calls_and_values)) - - def call(self, tbio: "TestbenchIO", data: dict[str, Any] = {}, /, **kwdata): - """Call a method and sample its result. - - Adds a method call to the trigger. The method result is sampled on a clock edge. If the call did not - succeed, the sampled value is `None`. - - Parameters - ---------- - tbio: TestbenchIO - The method to call. - data: dict[str, Any] - Method call arguments stored in a dict. - **kwdata: Any - Method call arguments passed as keyword arguments. If keyword arguments are used, - the `data` argument should not be provided. 
-        """
-        if data and kwdata:
-            raise TypeError("call() takes either a single dict or keyword arguments")
-        return CallTrigger(self.sim, (*self.calls_and_values, (tbio, data or kwdata)))
-
-    async def until_done(self) -> Any:
-        """Wait until at least one of the calls succeeds.
-
-        The `CallTrigger` normally acts like `TickTrigger`, i.e. awaiting on it advances the clock to the next
-        clock edge. It is possible that none of the calls could be performed, for example because the called
-        methods were not enabled. When only the cycles in which one of the calls succeeded are of interest,
-        `until_done` can be used. This works like `until()` in `TickTrigger`.
-        """
-        async for results in self:
-            if any(res is not None for res in results):
-                return results
-
-    def __await__(self) -> Generator:
-        only_calls = [t for t in self.calls_and_values if isinstance(t, tuple)]
-        only_values = [t for t in self.calls_and_values if not isinstance(t, tuple)]
-
-        for tbio, data in only_calls:
-            if data is not None:
-                tbio.call_init(self.sim, data)
-
-        def layout_for(tbio: TestbenchIO):
-            return StructLayout({"outputs": tbio.adapter.data_out.shape(), "done": 1})
-
-        trigger = (
-            self.sim.tick()
-            .sample(*(View(layout_for(tbio), Cat(tbio.outputs, tbio.done)) for tbio, _ in only_calls))
-            .sample(*only_values)
-        )
-        _, _, *results = yield from trigger.__await__()
-
-        for tbio, data in only_calls:
-            if data is not None:
-                tbio.disable(self.sim)
-
-        values_it = iter(results[len(only_calls) :])
-        calls_it = (s.outputs if s.done else None for s in results[: len(only_calls)])
-
-        def ret():
-            for v in self.calls_and_values:
-                if isinstance(v, tuple):
-                    yield next(calls_it)
-                else:
-                    yield next(values_it)
-
-        return tuple(ret())
-
-    async def __aiter__(self):
-        while True:
-            yield await self
-
-
-class TestbenchIO(Elaboratable):
-    def __init__(self, adapter: AdapterBase):
-        self.adapter = adapter
-
-    def elaborate(self, platform):
-        m = Module()
-        m.submodules += self.adapter
-        return m
-
-    # Low-level operations
-
-    def set_enable(self, sim: SimulatorContext, en):
-        sim.set(self.adapter.en, 1 if en else 0)
-
-    def enable(self, sim: SimulatorContext):
-        self.set_enable(sim, True)
-
-    def disable(self, sim: SimulatorContext):
-        self.set_enable(sim, False)
-
-    @property
-    def done(self):
-        return self.adapter.done
-
-    @property
-    def outputs(self):
-        return self.adapter.data_out
-
-    def set_inputs(self, sim: SimulatorContext, data):
-        sim.set(self.adapter.data_in, data)
-
-    def get_done(self, sim: TestbenchContext):
-        return sim.get(self.adapter.done)
-
-    def get_outputs(self, sim: TestbenchContext) -> MethodData:
-        return sim.get(self.adapter.data_out)
-
-    def sample_outputs(self, sim: SimulatorContext):
-        return sim.tick().sample(self.adapter.data_out)
-
-    def sample_outputs_until_done(self, sim: SimulatorContext):
-        return self.sample_outputs(sim).until(self.adapter.done)
-
-    def sample_outputs_done(self, sim: SimulatorContext):
-        return sim.tick().sample(self.adapter.data_out, self.adapter.done)
-
-    # Operations for AdapterTrans
-
-    def call_init(self, sim: SimulatorContext, data={}, /, **kwdata):
-        if data and kwdata:
-            raise TypeError("call_init() takes either a single dict or keyword arguments")
-        if not data:
-            data = kwdata
-        self.enable(sim)
-        self.set_inputs(sim, data)
-
-    def get_call_result(self, sim: TestbenchContext) -> Optional[MethodData]:
-        if self.get_done(sim):
-            return self.get_outputs(sim)
-        return None
-
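
A minimal sketch of the same-cycle pattern `CallTrigger` enables; the `push`/`pop` `TestbenchIO` objects, the `level` signal and the `data` field are assumptions:

```python
from transactron.testing import CallTrigger

async def testbench(sim):
    # Call two methods in the same clock cycle and sample a signal alongside;
    # results come back in the order the calls and samples were added.
    pushed, popped, level_val = await CallTrigger(sim).call(push, data=42).call(pop).sample(level).until_done()
    if popped is not None:  # the pop call succeeded in this cycle
        assert popped["data"] == 42  # assumes the method returns a `data` field
```
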
-    async def call_result(self, sim: SimulatorContext) -> Optional[MethodData]:
-        *_, data, done = await self.sample_outputs_done(sim)
-        if done:
-            return data
-        return None
-
-    async def call_do(self, sim: SimulatorContext) -> MethodData:
-        *_, outputs = await self.sample_outputs_until_done(sim)
-        self.disable(sim)
-        return outputs
-
-    async def call_try(self, sim: SimulatorContext, data={}, /, **kwdata) -> Optional[MethodData]:
-        return (await CallTrigger(sim).call(self, data, **kwdata))[0]
-
-    async def call(self, sim: SimulatorContext, data={}, /, **kwdata) -> MethodData:
-        return (await CallTrigger(sim).call(self, data, **kwdata).until_done())[0]
diff --git a/transactron/testing/tick_count.py b/transactron/testing/tick_count.py
deleted file mode 100644
index a2d3828d6..000000000
--- a/transactron/testing/tick_count.py
+++ /dev/null
@@ -1,25 +0,0 @@
-from dataclasses import dataclass
-
-from amaranth import Signal
-from amaranth.sim._async import ProcessContext
-
-from transactron.utils.dependencies import DependencyContext, SimpleKey
-
-
-__all__ = ["TicksKey", "make_tick_count_process"]
-
-
-@dataclass(frozen=True)
-class TicksKey(SimpleKey[Signal]):
-    pass
-
-
-def make_tick_count_process():
-    ticks = Signal(64)
-    DependencyContext.get().add_dependency(TicksKey(), ticks)
-
-    async def process(sim: ProcessContext):
-        async for _, _, ticks_val in sim.tick().sample(ticks):
-            sim.set(ticks, ticks_val + 1)
-
-    return process
diff --git a/transactron/tracing.py b/transactron/tracing.py
deleted file mode 100644
index 6f9c709f1..000000000
--- a/transactron/tracing.py
+++ /dev/null
@@ -1,159 +0,0 @@
-"""
-Utilities for extracting dependencies from Amaranth.
-"""
-
-import warnings
-
-from amaranth.hdl import Elaboratable, Fragment, Instance
-from amaranth.hdl._xfrm import FragmentTransformer
-from amaranth.hdl import _dsl, _ir, _mem, _xfrm
-from amaranth.lib import memory  # type: ignore
-from amaranth_types import SrcLoc
-from transactron.utils import HasElaborate
-from . import core
-
-
-# generic tuple because of aggressive monkey-patching
-modules_with_fragment: tuple = core, _ir, _dsl, _mem, _xfrm
-# List of Fragment subclasses which should be patched to inherit from TracingFragment.
-# The first element of each tuple is the name of the subclass to patch, and the second
-# element is a tuple of modules in which the patched subclass should be installed.
-fragment_subclasses_to_patch = [("MemoryInstance", (memory, _mem, _xfrm))]
-
-DIAGNOSTICS = False
-orig_on_fragment = FragmentTransformer.on_fragment
-
-
-class TracingEnabler:
-    def __enter__(self):
-        self.orig_fragment_get = Fragment.get
-        self.orig_on_fragment = FragmentTransformer.on_fragment
-        self.orig_fragment_class = _ir.Fragment
-        self.orig_instance_class = _ir.Instance
-        self.orig_patched_fragment_subclasses = []
-        Fragment.get = TracingFragment.get
-        FragmentTransformer.on_fragment = TracingFragmentTransformer.on_fragment
-        for mod in modules_with_fragment:
-            mod.Fragment = TracingFragment
-            mod.Instance = TracingInstance
-        for class_name, modules in fragment_subclasses_to_patch:
-            orig_fragment_subclass = getattr(modules[0], class_name)
-            # `type` is used to declare the new class dynamically. `orig_fragment_subclass` is passed
-            # as the first base class to allow `super()` to work. Calls to `super` without arguments
-            # are syntax sugar expanded at compile/interpretation (not execution!) time to
-            # `super(OriginalClass, self)`, so at execution time they are hardcoded to look for the
-            # original class (see: https://docs.python.org/3/library/functions.html#super).
-            # This means that OriginalClass has to be in the `__mro__` of the newly created class,
-            # otherwise a TypeError will be raised (see: https://stackoverflow.com/a/40819403).
-            # Adding OriginalClass to the bases of the patched class lets us avoid the TypeError.
-            # Everything works correctly because `super` starts searching `__mro__` from the class
-            # right after its first argument; in our case the first checked class will be
-            # `TracingFragment`, as we want.
-            newclass = type(
-                class_name,
-                (
-                    orig_fragment_subclass,
-                    TracingFragment,
-                ),
-                dict(orig_fragment_subclass.__dict__),
-            )
-            for mod in modules:
-                setattr(mod, class_name, newclass)
-            self.orig_patched_fragment_subclasses.append((class_name, orig_fragment_subclass, modules))
-
-    def __exit__(self, tp, val, tb):
-        Fragment.get = self.orig_fragment_get
-        FragmentTransformer.on_fragment = self.orig_on_fragment
-        for mod in modules_with_fragment:
-            mod.Fragment = self.orig_fragment_class
-            mod.Instance = self.orig_instance_class
-        for class_name, orig_fragment_subclass, modules in self.orig_patched_fragment_subclasses:
-            for mod in modules:
-                setattr(mod, class_name, orig_fragment_subclass)
-
-
-class TracingFragmentTransformer(FragmentTransformer):
-    def on_fragment(self: FragmentTransformer, fragment):
-        ret = orig_on_fragment(self, fragment)
-        ret._tracing_original = fragment
-        fragment._elaborated = ret
-        return ret
-
-
-class TracingFragment(Fragment):
-    _tracing_original: Elaboratable
-    subfragments: list[tuple[Elaboratable, str, SrcLoc]]
-
-    if DIAGNOSTICS:
-
-        def __init__(self, *args, **kwargs):
-            import sys
-            import traceback
-
-            self.created = traceback.format_stack(sys._getframe(1))
-            super().__init__(*args, **kwargs)
-
-        def __del__(self):
-            if not hasattr(self, "_tracing_original"):
-                print("Missing tracing hook:")
-                for line in self.created:
-                    print(line, end="")
-
-    @staticmethod
-    def get(obj: HasElaborate, platform) -> "TracingFragment":
-        """
-        This function is based on Amaranth code which originally loses all this information.
-        It was too difficult to hook into, so this has to be a near-exact copy.
-
-        Relevant copyrights apply.
- """ - with TracingEnabler(): - code = None - old_obj = None - while True: - if isinstance(obj, TracingFragment): - return obj - elif isinstance(obj, Fragment): - raise NotImplementedError(f"Monkey-patching missed some Fragment in {old_obj}.elaborate()?") - # This is literally taken from Amaranth {{ - elif isinstance(obj, Elaboratable): - code = obj.elaborate.__code__ - obj._MustUse__used = True # type: ignore - new_obj = obj.elaborate(platform) - elif hasattr(obj, "elaborate"): - warnings.warn( - message="Class {!r} is an elaboratable that does not explicitly inherit from " - "Elaboratable; doing so would improve diagnostics".format(type(obj)), - category=RuntimeWarning, - stacklevel=2, - ) - code = obj.elaborate.__code__ - new_obj = obj.elaborate(platform) - else: - raise AttributeError("Object {!r} cannot be elaborated".format(obj)) - if new_obj is obj: - raise RecursionError("Object {!r} elaborates to itself".format(obj)) - if new_obj is None and code is not None: - warnings.warn_explicit( - message=".elaborate() returned None; missing return statement?", - category=UserWarning, - filename=code.co_filename, - lineno=code.co_firstlineno, - ) - # }} (taken from Amaranth) - new_obj._tracing_original = obj # type: ignore - obj._elaborated = new_obj # type: ignore - - old_obj = obj - obj = new_obj - - def prepare(self, *args, **kwargs) -> "TracingFragment": - with TracingEnabler(): - ret = super().prepare(*args, **kwargs) - ret._tracing_original = self - self._elaborated = ret - return ret - - -class TracingInstance(Instance, TracingFragment): - _tracing_original: Elaboratable - get = TracingFragment.get diff --git a/transactron/utils/__init__.py b/transactron/utils/__init__.py deleted file mode 100644 index ebf845b7d..000000000 --- a/transactron/utils/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -from .data_repr import * # noqa: F401 -from ._typing import * # noqa: F401 -from .debug_signals import * # noqa: F401 -from .assign import * # noqa: F401 -from .amaranth_ext import * # noqa: F401 -from .transactron_helpers import * # noqa: F401 -from .dependencies import * # noqa: F401 -from .depcache import * # noqa: F401 -from .idgen import * # noqa: F401 diff --git a/transactron/utils/_typing.py b/transactron/utils/_typing.py deleted file mode 100644 index 1a264527b..000000000 --- a/transactron/utils/_typing.py +++ /dev/null @@ -1,82 +0,0 @@ -from typing import ( - Callable, - Concatenate, - ParamSpec, - Protocol, - TypeAlias, - TypeVar, - cast, - runtime_checkable, - Union, - Any, -) -from collections.abc import Iterable, Mapping -from amaranth import * -from amaranth.lib.data import StructLayout, View -from amaranth_types import * -from amaranth_types import _ModuleBuilderDomainsLike - -__all__ = [ - "FragmentLike", - "ValueLike", - "ShapeLike", - "StatementLike", - "SwitchKey", - "SrcLoc", - "MethodLayout", - "MethodStruct", - "SignalBundle", - "LayoutListField", - "LayoutList", - "LayoutIterable", - "RecordIntDict", - "RecordIntDictRet", - "RecordValueDict", - "RecordDict", - "ROGraph", - "Graph", - "GraphCC", - "_ModuleBuilderDomainsLike", - "ModuleLike", - "HasElaborate", - "HasDebugSignals", -] - -# Internal Coreblocks types -SignalBundle: TypeAlias = Signal | Record | View | Iterable["SignalBundle"] | Mapping[str, "SignalBundle"] -LayoutListField: TypeAlias = tuple[str, "ShapeLike | LayoutList"] -LayoutList: TypeAlias = list[LayoutListField] -LayoutIterable: TypeAlias = Iterable[LayoutListField] -MethodLayout: TypeAlias = StructLayout | LayoutIterable -MethodStruct: TypeAlias = 
"View[StructLayout]" - -RecordIntDict: TypeAlias = Mapping[str, Union[int, "RecordIntDict"]] -RecordIntDictRet: TypeAlias = Mapping[str, Any] # full typing hard to work with -RecordValueDict: TypeAlias = Mapping[str, Union[ValueLike, "RecordValueDict"]] -RecordDict: TypeAlias = ValueLike | Mapping[str, "RecordDict"] - -T = TypeVar("T") -U = TypeVar("U") -P = ParamSpec("P") - -ROGraph: TypeAlias = Mapping[T, Iterable[T]] -Graph: TypeAlias = dict[T, set[T]] -GraphCC: TypeAlias = set[T] - - -@runtime_checkable -class HasDebugSignals(Protocol): - def debug_signals(self) -> SignalBundle: ... - - -def type_self_kwargs_as(as_func: Callable[Concatenate[Any, P], Any]): - """ - Decorator used to annotate `**kwargs` type to be the same as named arguments from `as_func` method. - - Works only with methods with (self, **kwargs) signature. `self` parameter is also required in `as_func`. - """ - - def return_func(func: Callable[Concatenate[Any, ...], T]) -> Callable[Concatenate[Any, P], T]: - return cast(Callable[Concatenate[Any, P], T], func) - - return return_func diff --git a/transactron/utils/amaranth_ext/__init__.py b/transactron/utils/amaranth_ext/__init__.py deleted file mode 100644 index 05df9e85f..000000000 --- a/transactron/utils/amaranth_ext/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .functions import * # noqa: F401 -from .elaboratables import * # noqa: F401 -from .coding import * # noqa: F401 diff --git a/transactron/utils/amaranth_ext/coding.py b/transactron/utils/amaranth_ext/coding.py deleted file mode 100644 index 5360579e8..000000000 --- a/transactron/utils/amaranth_ext/coding.py +++ /dev/null @@ -1,196 +0,0 @@ -# This module was copied from Amaranth because it is deprecated there. -# Copyright (C) 2019-2024 Amaranth HDL contributors - -from amaranth import * - - -__all__ = [ - "Encoder", - "Decoder", - "PriorityEncoder", - "PriorityDecoder", - "GrayEncoder", - "GrayDecoder", -] - - -class Encoder(Elaboratable): - """Encode one-hot to binary. - - If one bit in ``i`` is asserted, ``n`` is low and ``o`` indicates the asserted bit. - Otherwise, ``n`` is high and ``o`` is ``0``. - - Parameters - ---------- - width : int - Bit width of the input - - Attributes - ---------- - i : Signal(width), in - One-hot input. - o : Signal(range(width)), out - Encoded natural binary. - n : Signal, out - Invalid: either none or multiple input bits are asserted. - """ - - def __init__(self, width: int): - self.width = width - - self.i = Signal(width) - self.o = Signal(range(width)) - self.n = Signal() - - def elaborate(self, platform): - m = Module() - with m.Switch(self.i): - for j in range(self.width): - with m.Case(1 << j): - m.d.comb += self.o.eq(j) - with m.Default(): - m.d.comb += self.n.eq(1) - return m - - -class PriorityEncoder(Elaboratable): - """Priority encode requests to binary. - - If any bit in ``i`` is asserted, ``n`` is low and ``o`` indicates the least significant - asserted bit. - Otherwise, ``n`` is high and ``o`` is ``0``. - - Parameters - ---------- - width : int - Bit width of the input. - - Attributes - ---------- - i : Signal(width), in - Input requests. - o : Signal(range(width)), out - Encoded natural binary. - n : Signal, out - Invalid: no input bits are asserted. 
- """ - - def __init__(self, width: int): - self.width = width - - self.i = Signal(width) - self.o = Signal(range(width)) - self.n = Signal() - - def elaborate(self, platform): - m = Module() - for j in reversed(range(self.width)): - with m.If(self.i[j]): - m.d.comb += self.o.eq(j) - m.d.comb += self.n.eq(self.i == 0) - return m - - -class Decoder(Elaboratable): - """Decode binary to one-hot. - - If ``n`` is low, only the ``i``-th bit in ``o`` is asserted. - If ``n`` is high, ``o`` is ``0``. - - Parameters - ---------- - width : int - Bit width of the output. - - Attributes - ---------- - i : Signal(range(width)), in - Input binary. - o : Signal(width), out - Decoded one-hot. - n : Signal, in - Invalid, no output bits are to be asserted. - """ - - def __init__(self, width: int): - self.width = width - - self.i = Signal(range(width)) - self.n = Signal() - self.o = Signal(width) - - def elaborate(self, platform): - m = Module() - with m.Switch(self.i): - for j in range(len(self.o)): - with m.Case(j): - m.d.comb += self.o.eq(1 << j) - with m.If(self.n): - m.d.comb += self.o.eq(0) - return m - - -class PriorityDecoder(Decoder): - """Decode binary to priority request. - - Identical to :class:`Decoder`. - """ - - -class GrayEncoder(Elaboratable): - """Encode binary to Gray code. - - Parameters - ---------- - width : int - Bit width. - - Attributes - ---------- - i : Signal(width), in - Natural binary input. - o : Signal(width), out - Encoded Gray code. - """ - - def __init__(self, width: int): - self.width = width - - self.i = Signal(width) - self.o = Signal(width) - - def elaborate(self, platform): - m = Module() - m.d.comb += self.o.eq(self.i ^ self.i[1:]) - return m - - -class GrayDecoder(Elaboratable): - """Decode Gray code to binary. - - Parameters - ---------- - width : int - Bit width. - - Attributes - ---------- - i : Signal(width), in - Gray code input. - o : Signal(width), out - Decoded natural binary. - """ - - def __init__(self, width: int): - self.width = width - - self.i = Signal(width) - self.o = Signal(width) - - def elaborate(self, platform): - m = Module() - rhs = Const(0) - for i in reversed(range(self.width)): - rhs = rhs ^ self.i[i] - m.d.comb += self.o[i].eq(rhs) - return m diff --git a/transactron/utils/amaranth_ext/elaboratables.py b/transactron/utils/amaranth_ext/elaboratables.py deleted file mode 100644 index ed6b57122..000000000 --- a/transactron/utils/amaranth_ext/elaboratables.py +++ /dev/null @@ -1,532 +0,0 @@ -import itertools -from contextlib import contextmanager -from typing import Literal, Optional, overload -from collections.abc import Iterable -from amaranth import * -from transactron.utils._typing import HasElaborate, ModuleLike, ValueLike - -__all__ = [ - "OneHotSwitchDynamic", - "OneHotSwitch", - "ModuleConnector", - "Scheduler", - "RoundRobin", - "MultiPriorityEncoder", - "RingMultiPriorityEncoder", -] - - -@contextmanager -def OneHotSwitch(m: ModuleLike, test: Value): - """One-hot switch. - - This function allows one-hot matching in the style similar to the standard - Amaranth `Switch`. This allows to get the performance benefit of using - the one-hot representation. - - Example:: - - with OneHotSwitch(m, sig) as OneHotCase: - with OneHotCase(0b01): - ... - with OneHotCase(0b10): - ... - # optional default case - with OneHotCase(): - ... - - Parameters - ---------- - m : Module - The module for which the matching is defined. - test : Signal - The signal being tested. 
- """ - - @contextmanager - def case(n: Optional[int] = None): - if n is None: - with m.Default(): - yield - else: - # find the index of the least significant bit set - i = (n & -n).bit_length() - 1 - if n - (1 << i) != 0: - raise ValueError("%d not in one-hot representation" % n) - with m.Case(n): - yield - - with m.Switch(test): - yield case - - -@overload -def OneHotSwitchDynamic(m: ModuleLike, test: Value, *, default: Literal[True]) -> Iterable[Optional[int]]: ... - - -@overload -def OneHotSwitchDynamic(m: ModuleLike, test: Value, *, default: Literal[False] = False) -> Iterable[int]: ... - - -def OneHotSwitchDynamic(m: ModuleLike, test: Value, *, default: bool = False) -> Iterable[Optional[int]]: - """Dynamic one-hot switch. - - This function allows simple one-hot matching on signals which can have - variable bit widths. - - Example:: - - for i in OneHotSwitchDynamic(m, sig): - # code dependent on the bit index i - ... - - Parameters - ---------- - m : Module - The module for which the matching is defined. - test : Signal - The signal being tested. - default : bool, optional - Whether the matching includes a default case (signified by a None). - """ - count = len(test) - with OneHotSwitch(m, test) as OneHotCase: - for i in range(count): - with OneHotCase(1 << i): - yield i - if default: - with OneHotCase(): - yield None - return - - -class ModuleConnector(Elaboratable): - """ - An Elaboratable to create a new module, which will have all arguments - added as its submodules. - """ - - def __init__(self, *args: HasElaborate, **kwargs: HasElaborate): - """ - Parameters - ---------- - *args - Modules which should be added as anonymous submodules. - **kwargs - Modules which will be added as named submodules. - """ - self.args = args - self.kwargs = kwargs - - def elaborate(self, platform): - m = Module() - - for elem in self.args: - m.submodules += elem - - for name, elem in self.kwargs.items(): - m.submodules[name] = elem - - return m - - -class Scheduler(Elaboratable): - """Scheduler - - An implementation of a round-robin scheduler, which is used in the - transaction subsystem. It is based on Amaranth's round-robin scheduler - but instead of using binary numbers, it uses one-hot encoding for the - `grant` output signal. - - Attributes - ---------- - requests: Signal(count), in - Signals that something (e.g. a transaction) wants to run. When i-th - bit is high, then the i-th agent requests the grant signal. - grant: Signal(count), out - Signals that something (e.g. transaction) is granted to run. It uses - one-hot encoding. - valid : Signal(1), out - Signal that `grant` signals are valid. - """ - - def __init__(self, count: int): - """ - Parameters - ---------- - count : int - Number of agents between which the scheduler should arbitrate. 
-        """
-        if not isinstance(count, int) or count < 0:
-            raise ValueError("Count must be a non-negative integer, not {!r}".format(count))
-        self.count = count
-
-        self.requests = Signal(count)
-        self.grant = Signal(count, init=1)
-        self.valid = Signal()
-
-    def elaborate(self, platform):
-        m = Module()
-
-        grant_reg = Signal.like(self.grant)
-
-        for i in OneHotSwitchDynamic(m, grant_reg, default=True):
-            if i is not None:
-                m.d.comb += self.grant.eq(grant_reg)
-                for j in itertools.chain(reversed(range(i)), reversed(range(i + 1, self.count))):
-                    with m.If(self.requests[j]):
-                        m.d.comb += self.grant.eq(1 << j)
-            else:
-                m.d.comb += self.grant.eq(0)
-
-        m.d.comb += self.valid.eq(self.requests.any())
-
-        m.d.sync += grant_reg.eq(self.grant)
-
-        return m
-
-
-class RoundRobin(Elaboratable):
-    """Round-robin scheduler.
-
-    For a given set of requests, the round-robin scheduler will
-    grant one request. Once it grants a request, if any other
-    requests are active, it grants the next active request with
-    a greater number, restarting from zero once it reaches the
-    highest one.
-
-    Use :class:`EnableInserter` to control when the scheduler
-    is updated.
-
-    Implementation ported from amaranth lib.
-
-    Parameters
-    ----------
-    count : int
-        Number of requests.
-
-    Attributes
-    ----------
-    requests : Signal(count), in
-        Set of requests.
-    grant : Signal(range(count)), out
-        Number of the granted request. Does not change if there are no
-        active requests.
-    valid : Signal(), out
-        Asserted if grant corresponds to an active request. Deasserted
-        otherwise, i.e. if no requests are active.
-    """
-
-    def __init__(self, *, count):
-        if not isinstance(count, int) or count < 0:
-            raise ValueError("Count must be a non-negative integer, not {!r}".format(count))
-        self.count = count
-
-        self.requests = Signal(count)
-        self.grant = Signal(range(count))
-        self.valid = Signal()
-
-    def elaborate(self, platform):
-        m = Module()
-
-        with m.Switch(self.grant):
-            for i in range(self.count):
-                with m.Case(i):
-                    for pred in reversed(range(i)):
-                        with m.If(self.requests[pred]):
-                            m.d.sync += self.grant.eq(pred)
-                    for succ in reversed(range(i + 1, self.count)):
-                        with m.If(self.requests[succ]):
-                            m.d.sync += self.grant.eq(succ)
-
-        m.d.sync += self.valid.eq(self.requests.any())
-
-        return m
-
-
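
A minimal sketch of driving the one-hot `Scheduler` defined above; the 4-bit `reqs` signal is illustrative and the import path assumes the pre-removal layout:

```python
from amaranth import Module, Signal
from transactron.utils.amaranth_ext.elaboratables import Scheduler

m = Module()
reqs = Signal(4)
m.submodules.sched = sched = Scheduler(4)
m.d.comb += sched.requests.eq(reqs)
with m.If(sched.valid):
    pass  # sched.grant holds the one-hot winner for this cycle
```
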
-
-
-class MultiPriorityEncoder(Elaboratable):
-    """Priority encoder with multiple outputs
-
-    This is an extension of the `PriorityEncoder` from Amaranth that supports
-    more than one output from an input signal. In other words, it decodes a
-    multi-hot encoded signal into a list of binary-encoded signals, each
-    holding the index of a different high bit of the input.
-
-    Attributes
-    ----------
-    input_width : int
-        Width of the input signal.
-    outputs_count : int
-        Number of outputs to generate at once.
-    input : Signal, in
-        Signal with a 1 on the `i`-th bit if index `i` can be selected by the encoder.
-    outputs : list[Signal], out
-        Signals with the selected indices, sorted in ascending order.
-        If fewer than `outputs_count` input bits are set, the valid
-        outputs are grouped at the beginning of the list.
-    valids : list[Signal], out
-        One bit for each output signal, indicating whether the output is valid.
-    """
-
-    def __init__(self, input_width: int, outputs_count: int):
-        self.input_width = input_width
-        self.outputs_count = outputs_count
-
-        self.input = Signal(self.input_width)
-        self.outputs = [Signal(range(self.input_width), name=f"output_{i}") for i in range(self.outputs_count)]
-        self.valids = [Signal(name=f"valid_{i}") for i in range(self.outputs_count)]
-
-    @staticmethod
-    def create(
-        m: Module, input_width: int, input: ValueLike, outputs_count: int = 1, name: Optional[str] = None
-    ) -> list[tuple[Signal, Signal]]:
-        """Syntax sugar for creating MultiPriorityEncoder
-
-        This static method allows using MultiPriorityEncoder in a more functional
-        way. Instead of creating the instance manually, connecting all the signals and
-        adding a submodule, you can call this function to do it automatically.
-
-        This function is equivalent to:
-
-        .. highlight:: python
-        .. code-block:: python
-
-            prio_encoder = MultiPriorityEncoder(input_width, outputs_count)
-            m.submodules += prio_encoder
-            m.d.comb += prio_encoder.input.eq(one_hot_signal)
-            idx = prio_encoder.outputs
-            valid = prio_encoder.valids
-
-        Parameters
-        ----------
-        m: Module
-            Module to add the MultiPriorityEncoder to.
-        input_width : int
-            Width of the multi-hot signal.
-        input : ValueLike
-            The multi-hot signal to decode.
-        outputs_count : int
-            Number of different decoder outputs to generate at once. Default: 1.
-        name : Optional[str]
-            Name to use when adding MultiPriorityEncoder to submodules.
-            If None, it will be added as an anonymous submodule. The given name
-            cannot be the same as that of a submodule which has already been
-            added. Default: None.
-
-        Returns
-        -------
-        return : list[tuple[Signal, Signal]]
-            A list of length `outputs_count`. Each tuple contains a decoded
-            index in the first position and a valid signal in the second.
-        """
-        prio_encoder = MultiPriorityEncoder(input_width, outputs_count)
-        if name is None:
-            m.submodules += prio_encoder
-        else:
-            try:
-                getattr(m.submodules, name)
-                raise ValueError(f"Name: {name} is already in use, so MultiPriorityEncoder cannot be added with it.")
-            except AttributeError:
-                setattr(m.submodules, name, prio_encoder)
-        m.d.comb += prio_encoder.input.eq(input)
-        return list(zip(prio_encoder.outputs, prio_encoder.valids))
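
The `create` helper above can be called directly from an `elaborate` method; a short sketch (the `PickTwo` module and its port names are invented):

```python
from amaranth import Elaboratable, Module, Signal


class PickTwo(Elaboratable):
    """Hypothetical module: report the two lowest set bits of an 8-bit mask."""

    def __init__(self):
        self.mask = Signal(8)
        self.idx = [Signal(range(8)) for _ in range(2)]
        self.vld = [Signal() for _ in range(2)]

    def elaborate(self, platform):
        m = Module()
        # Decode up to two set bits of `mask` at once; `create` instantiates
        # and connects the encoder, returning (index, valid) pairs.
        outs = MultiPriorityEncoder.create(m, 8, self.mask, outputs_count=2, name="pick")
        for (o, v), idx, vld in zip(outs, self.idx, self.vld):
            m.d.comb += [idx.eq(o), vld.eq(v)]
        return m
```
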
-    @staticmethod
-    def create_simple(
-        m: Module, input_width: int, input: ValueLike, name: Optional[str] = None
-    ) -> tuple[Signal, Signal]:
-        """Syntax sugar for creating MultiPriorityEncoder
-
-        This is the same as the `create` function, but with `outputs_count` hardcoded to 1.
-        """
-        lst = MultiPriorityEncoder.create(m, input_width, input, outputs_count=1, name=name)
-        return lst[0]
-
-    def build_tree(self, m: Module, in_sig: Signal, start_idx: int):
-        assert len(in_sig) > 0
-        level_outputs = [
-            Signal(range(self.input_width), name=f"_lvl_out_idx{start_idx}_{i}") for i in range(self.outputs_count)
-        ]
-        level_valids = [Signal(name=f"_lvl_val_idx{start_idx}_{i}") for i in range(self.outputs_count)]
-        if len(in_sig) == 1:
-            # Leaf node: a single bit produces its own index when set.
-            with m.If(in_sig):
-                m.d.comb += level_outputs[0].eq(start_idx)
-                m.d.comb += level_valids[0].eq(1)
-        else:
-            # Split the input in half and encode both halves recursively.
-            middle = len(in_sig) // 2
-            r_in = Signal(middle, name=f"_r_in_idx{start_idx}")
-            l_in = Signal(len(in_sig) - middle, name=f"_l_in_idx{start_idx}")
-            m.d.comb += r_in.eq(in_sig[0:middle])
-            m.d.comb += l_in.eq(in_sig[middle:])
-            r_out, r_val = self.build_tree(m, r_in, start_idx)
-            l_out, l_val = self.build_tree(m, l_in, start_idx + middle)
-
-            # Merge: take the valid results of the lower half first,
-            # then fill the remaining outputs from the upper half.
-            with m.Switch(Cat(r_val)):
-                for i in range(self.outputs_count + 1):
-                    with m.Case((1 << i) - 1):
-                        for j in range(i):
-                            m.d.comb += level_outputs[j].eq(r_out[j])
-                            m.d.comb += level_valids[j].eq(r_val[j])
-                        for j in range(i, self.outputs_count):
-                            m.d.comb += level_outputs[j].eq(l_out[j - i])
-                            m.d.comb += level_valids[j].eq(l_val[j - i])
-        return level_outputs, level_valids
-
-    def elaborate(self, platform):
-        m = Module()
-
-        level_outputs, level_valids = self.build_tree(m, self.input, 0)
-
-        for k in range(self.outputs_count):
-            m.d.comb += self.outputs[k].eq(level_outputs[k])
-            m.d.comb += self.valids[k].eq(level_valids[k])
-
-        return m
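
The recursive merge above is easiest to check against a plain-Python reference model of the documented contract: set-bit indices in ascending order, with valid outputs grouped first (the function name below is invented):

```python
def multi_priority_encode_ref(mask: int, width: int, outputs_count: int):
    """Software model: indices of set bits, lowest first, with valid flags."""
    idxs = [i for i in range(width) if (mask >> i) & 1][:outputs_count]
    valids = [1] * len(idxs) + [0] * (outputs_count - len(idxs))
    return idxs, valids


# 0b10110 has set bits at indices 1, 2 and 4; with two outputs the
# encoder reports the two lowest ones.
assert multi_priority_encode_ref(0b10110, 5, 2) == ([1, 2], [1, 1])
assert multi_priority_encode_ref(0b00000, 5, 2) == ([], [0, 0])
```
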
-
-
-class RingMultiPriorityEncoder(Elaboratable):
-    """Priority encoder with one or more outputs and a flexible start
-
-    This is an extension of the `MultiPriorityEncoder` that supports
-    flexible start and end indexes. In the standard `MultiPriorityEncoder`,
-    the first bit is always at position 0 and the last is the last bit of
-    the input signal. In this extended implementation, both can be
-    selected at runtime.
-
-    This implementation is intended for selection from circular buffers,
-    so if `last < first` the encoder will first select bits from
-    [first, input_width) and then from [0, last).
-
-    Attributes
-    ----------
-    input_width : int
-        Width of the input signal.
-    outputs_count : int
-        Number of outputs to generate at once.
-    input : Signal, in
-        Signal with a 1 on the `i`-th bit if index `i` can be selected by the encoder.
-    first : Signal, in
-        Index of the first bit in the `input`. Inclusive.
-    last : Signal, in
-        Index of the last bit in the `input`. Exclusive.
-    outputs : list[Signal], out
-        Signals with the selected indices, sorted in ascending order.
-        If fewer than `outputs_count` input bits are set, the valid
-        outputs are grouped at the beginning of the list.
-    valids : list[Signal], out
-        One bit for each output signal, indicating whether the output is valid.
-    """
-
-    def __init__(self, input_width: int, outputs_count: int):
-        self.input_width = input_width
-        self.outputs_count = outputs_count
-
-        self.input = Signal(self.input_width)
-        self.first = Signal(range(self.input_width))
-        self.last = Signal(range(self.input_width))
-        self.outputs = [Signal(range(self.input_width), name=f"output_{i}") for i in range(self.outputs_count)]
-        self.valids = [Signal(name=f"valid_{i}") for i in range(self.outputs_count)]
-
-    @staticmethod
-    def create(
-        m: Module,
-        input_width: int,
-        input: ValueLike,
-        first: ValueLike,
-        last: ValueLike,
-        outputs_count: int = 1,
-        name: Optional[str] = None,
-    ) -> list[tuple[Signal, Signal]]:
-        """Syntax sugar for creating RingMultiPriorityEncoder
-
-        This static method allows using RingMultiPriorityEncoder in a more functional
-        way. Instead of creating the instance manually, connecting all the signals and
-        adding a submodule, you can call this function to do it automatically.
-
-        This function is equivalent to:
-
-        .. highlight:: python
-        .. code-block:: python
-
-            prio_encoder = RingMultiPriorityEncoder(input_width, outputs_count)
-            m.submodules += prio_encoder
-            m.d.comb += prio_encoder.input.eq(one_hot_signal)
-            m.d.comb += prio_encoder.first.eq(first)
-            m.d.comb += prio_encoder.last.eq(last)
-            idx = prio_encoder.outputs
-            valid = prio_encoder.valids
-
-        Parameters
-        ----------
-        m: Module
-            Module to add the RingMultiPriorityEncoder to.
-        input_width : int
-            Width of the multi-hot signal.
-        input : ValueLike
-            The multi-hot signal to decode.
-        first : ValueLike
-            Index of the first bit in the `input`. Inclusive.
-        last : ValueLike
-            Index of the last bit in the `input`. Exclusive.
-        outputs_count : int
-            Number of different decoder outputs to generate at once. Default: 1.
-        name : Optional[str]
-            Name to use when adding RingMultiPriorityEncoder to submodules.
-            If None, it will be added as an anonymous submodule. The given name
-            cannot be the same as that of a submodule which has already been
-            added. Default: None.
-
-        Returns
-        -------
-        return : list[tuple[Signal, Signal]]
-            A list of length `outputs_count`. Each tuple contains a decoded
-            index in the first position and a valid signal in the second.
-        """
-        prio_encoder = RingMultiPriorityEncoder(input_width, outputs_count)
-        if name is None:
-            m.submodules += prio_encoder
-        else:
-            try:
-                getattr(m.submodules, name)
-                raise ValueError(
-                    f"Name: {name} is already in use, so RingMultiPriorityEncoder cannot be added with it."
-                )
-            except AttributeError:
-                setattr(m.submodules, name, prio_encoder)
-        m.d.comb += prio_encoder.input.eq(input)
-        m.d.comb += prio_encoder.first.eq(first)
-        m.d.comb += prio_encoder.last.eq(last)
-        return list(zip(prio_encoder.outputs, prio_encoder.valids))
-
-    @staticmethod
-    def create_simple(
-        m: Module, input_width: int, input: ValueLike, first: ValueLike, last: ValueLike, name: Optional[str] = None
-    ) -> tuple[Signal, Signal]:
-        """Syntax sugar for creating RingMultiPriorityEncoder
-
-        This is the same as the `create` function, but with `outputs_count` hardcoded to 1.
-        """
-        lst = RingMultiPriorityEncoder.create(m, input_width, input, first, last, outputs_count=1, name=name)
-        return lst[0]
-
-    def elaborate(self, platform):
-        m = Module()
-        # Duplicate the input, so that a window which wraps around the end of
-        # the buffer becomes one contiguous bit range in `double_input`.
-        double_input = Signal(2 * self.input_width)
-        m.d.comb += double_input.eq(Cat(self.input, self.input))
-
-        # If the window wraps (first > last), the exclusive end index lands in the upper copy.
-        last_corrected = Signal(range(self.input_width * 2))
-        with m.If(self.first > self.last):
-            m.d.comb += last_corrected.eq(self.input_width + self.last)
-        with m.Else():
-            m.d.comb += last_corrected.eq(self.last)
-
-        # Mask out everything past `last`, then shift so the window starts at bit 0.
-        mask = Signal.like(double_input)
-        m.d.comb += mask.eq((1 << last_corrected) - 1)
-
-        multi_enc_input = (double_input & mask) >> self.first
-
-        m.submodules.multi_enc = multi_enc = MultiPriorityEncoder(self.input_width, self.outputs_count)
-        m.d.comb += multi_enc.input.eq(multi_enc_input)
-        for k in range(self.outputs_count):
-            # Translate window-relative indices back into the [0, input_width) range.
-            moved_out = Signal(range(2 * self.input_width))
-            m.d.comb += moved_out.eq(multi_enc.outputs[k] + self.first)
-            corrected_out = Mux(moved_out >= self.input_width, moved_out - self.input_width, moved_out)
-
-            m.d.comb += self.outputs[k].eq(corrected_out)
-            m.d.comb += self.valids[k].eq(multi_enc.valids[k])
-        return m
diff --git a/transactron/utils/amaranth_ext/functions.py b/transactron/utils/amaranth_ext/functions.py
deleted file mode 100644
index d09c7b53b..000000000
--- a/transactron/utils/amaranth_ext/functions.py
+++ /dev/null
@@ -1,99 +0,0 @@
-from amaranth import *
-from amaranth.utils import bits_for, exact_log2
-from amaranth.lib import data
-from collections.abc import Iterable, Mapping
-from transactron.utils._typing import SignalBundle
-
-__all__ = [
-    "mod_incr",
-    "popcount",
-    "count_leading_zeros",
-    "count_trailing_zeros",
-    "flatten_signals",
-]
-
-
-def mod_incr(sig: Value, mod: int) -> Value:
-    """
-    Perform the `(sig + 1) % mod` operation.
-    """
-    if mod == 2 ** len(sig):
-        return sig + 1
-    return Mux(sig == mod - 1, 0, sig + 1)
-
-
-def popcount(s: Value):
-    # Reduce pairwise, adding adjacent partial sums until one result remains.
-    sum_layers = [s[i] for i in range(len(s))]
-
-    while len(sum_layers) > 1:
-        if len(sum_layers) % 2:
-            sum_layers.append(C(0))
-        sum_layers = [a + b for a, b in zip(sum_layers[::2], sum_layers[1::2])]
-
-    return sum_layers[0][0 : bits_for(len(s))]
-
-
-def count_leading_zeros(s: Value) -> Value:
-    def iter(s: Value, step: int) -> Value:
-        # if no bits are left - return an empty value
-        if step == 0:
-            return C(0)
-
-        # boundaries of the upper and lower halves of the value
-        partition = 2 ** (step - 1)
-        current_bit = 1 << (step - 1)
-
-        # recursive call
-        upper_value = iter(s[partition:], step - 1)
-        lower_value = iter(s[:partition], step - 1)
-
-        # if there are set bits in the upper half - take the result directly from the recursive value;
-        # otherwise add 1 << (step - 1) to the lower value and return
-        result = Mux(s[partition:].any(), upper_value, lower_value | current_bit)
-
-        return result
-
-    try:
-        xlen_log = exact_log2(len(s))
-    except ValueError:
-        raise NotImplementedError("CountLeadingZeros - only sizes aligned to a power of 2 are supported")
-
-    value = iter(s, xlen_log)
-
-    # zero-input edge case:
-    # if s == 0, then iter() returns a value off by 1,
-    # so this mux compensates for that effect
-    high_bit = 1 << xlen_log
-
-    result = Mux(s.any(), value, high_bit)
-    return result
-
-
-def count_trailing_zeros(s: Value) -> Value:
-    try:
-        exact_log2(len(s))
-    except ValueError:
-        raise NotImplementedError("CountTrailingZeros - only sizes aligned to a power of 2 are supported")
-
-    # Counting trailing zeros is counting leading zeros of the bit-reversed value.
-    return count_leading_zeros(s[::-1])
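
These helpers mirror simple integer identities, which can be pinned down with software reference models (plain Python, not hardware; the `_ref` names are invented):

```python
def popcount_ref(x: int) -> int:
    # Number of set bits in x.
    return bin(x).count("1")


def clz_ref(x: int, width: int) -> int:
    # Leading zeros of x in a `width`-bit representation (requires x < 2**width).
    return width - x.bit_length()


assert popcount_ref(0b1011) == 3
assert clz_ref(0b0001_0000, 8) == 3
assert clz_ref(0, 8) == 8  # the zero-input edge case handled explicitly above
```
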
-
-
-def flatten_signals(signals: SignalBundle) -> Iterable[Signal]:
-    """
-    Flattens the input data, which can be a signal, a record, or a list (or a dict) of SignalBundle items.
-    """
-    if isinstance(signals, Mapping):
-        for x in signals.values():
-            yield from flatten_signals(x)
-    elif isinstance(signals, Iterable):
-        for x in signals:
-            yield from flatten_signals(x)
-    elif isinstance(signals, Record):
-        for x in signals.fields.values():
-            yield from flatten_signals(x)
-    elif isinstance(signals, data.View):
-        for x, _ in signals.shape():
-            yield from flatten_signals(signals[x])
-    else:
-        yield signals
diff --git a/transactron/utils/assign.py b/transactron/utils/assign.py
deleted file mode 100644
index 4257d2df6..000000000
--- a/transactron/utils/assign.py
+++ /dev/null
@@ -1,227 +0,0 @@
-from enum import Enum
-from typing import Optional, TypeAlias, cast, TYPE_CHECKING
-from collections.abc import Sequence, Iterable, Mapping
-from amaranth import *
-from amaranth.hdl import ShapeLike, ValueCastable
-from amaranth.hdl._ast import ArrayProxy, Slice
-from amaranth.lib import data
-from ._typing import ValueLike
-
-if TYPE_CHECKING:
-    from amaranth.hdl._ast import Assign
-
-__all__ = [
-    "AssignType",
-    "assign",
-]
-
-
-class AssignType(Enum):
-    COMMON = 1
-    LHS = 2
-    RHS = 3
-    ALL = 4
-
-
-AssignFields: TypeAlias = AssignType | Iterable[str | int] | Mapping[str | int, "AssignFields"]
-AssignArg: TypeAlias = ValueLike | Mapping[str, "AssignArg"] | Mapping[int, "AssignArg"] | Sequence["AssignArg"]
-
-
-def arrayproxy_fields(proxy: ArrayProxy) -> Optional[set[str | int]]:
-    def flatten_elems(proxy: ArrayProxy):
-        for elem in proxy.elems:
-            if isinstance(elem, ArrayProxy):
-                yield from flatten_elems(elem)
-            else:
-                yield elem
-
-    elems = list(flatten_elems(proxy))
-    if elems and all(isinstance(el, data.View) for el in elems):
-        return set.intersection(*[set(cast(data.View, el).shape().members.keys()) for el in elems])
-
-
-def assign_arg_fields(val: AssignArg) -> Optional[set[str | int]]:
-    if isinstance(val, ArrayProxy):
-        return arrayproxy_fields(val)
-    elif isinstance(val, data.View):
-        layout = val.shape()
-        if isinstance(layout, data.StructLayout):
-            return set(k for k in layout.members)
-        if isinstance(layout, data.ArrayLayout):
-            return set(range(layout.length))
-    elif isinstance(val, dict):
-        return set(val.keys())
-    elif isinstance(val, list):
-        return set(range(len(val)))
-
-
-def valuelike_shape(val: ValueLike) -> ShapeLike:
-    if isinstance(val, Value) or isinstance(val, ValueCastable):
-        return val.shape()
-    else:
-        return Value.cast(val).shape()
-
-
-def is_union(val: AssignArg):
-    return isinstance(val, data.View) and isinstance(val.shape(), data.UnionLayout)
-
-
-def assign(
-    lhs: AssignArg, rhs: AssignArg, *, fields: AssignFields = AssignType.RHS, lhs_strict=False, rhs_strict=False
-) -> Iterable["Assign"]:
-    """Safe structured assignment.
-
-    This function recursively generates assignment statements for
-    field-containing structures. This includes Amaranth `View`\\s using
-    `StructLayout` and Python `dict`\\s. In case of mismatched fields or
-    bit widths, an error is raised.
-
-    When both `lhs` and `rhs` are field-containing, `assign` generates
-    assignment statements according to the value of the `fields` parameter.
-    If either of `lhs` or `rhs` is not field-containing, `assign` checks for
-    the same bit width and generates a single assignment statement.
-
-    The bit width check is performed if:
-
-    - Either of `lhs` or `rhs` is a `View`.
-    - Both `lhs` and `rhs` have an explicitly defined shape (e.g. are a
-      `Signal`, a field of a `View`).
-
-    Parameters
-    ----------
-    lhs : View or Value-castable or dict
-        View, signal or dict being assigned.
-    rhs : View or Value-castable or dict
-        View, signal or dict containing assigned values.
-    fields : AssignType or Iterable or Mapping, optional
-        Determines which fields will be assigned. Possible values:
-
-        AssignType.COMMON
-            Only fields common to `lhs` and `rhs` are assigned.
-        AssignType.LHS
-            All fields in `lhs` are assigned. If one of them is not present
-            in `rhs`, an exception is raised.
-        AssignType.RHS
-            All fields in `rhs` are assigned. If one of them is not present
-            in `lhs`, an exception is raised.
-        AssignType.ALL
-            Assume that both structures have the same layouts. All fields present
-            in `lhs` or `rhs` are assigned.
-        Mapping
-            Keys are field names, values follow the format for `fields`.
-        Iterable
-            Items are field names. For subfields, AssignType.ALL is assumed.
-
-    Returns
-    -------
-    Iterable[Assign]
-        Generated assignment statements.
-
-    Raises
-    ------
-    ValueError
-        If the assignment can't be safely performed.
-    """
-    lhs_fields = assign_arg_fields(lhs)
-    rhs_fields = assign_arg_fields(rhs)
-
-    def rec_call(name: str | int):
-        subfields = fields
-        if isinstance(fields, Mapping):
-            subfields = fields[name]
-        elif isinstance(fields, Iterable):
-            subfields = AssignType.ALL
-
-        return assign(
-            lhs[name],  # type: ignore
-            rhs[name],  # type: ignore
-            fields=subfields,
-            lhs_strict=isinstance(lhs, ValueLike),
-            rhs_strict=isinstance(rhs, ValueLike),
-        )
-
-    if lhs_fields is not None and rhs_fields is not None:
-        # asserts for type checking
-        assert (
-            isinstance(lhs, ArrayProxy)
-            or isinstance(lhs, Mapping)
-            or isinstance(lhs, Sequence)
-            or isinstance(lhs, data.View)
-        )
-        assert (
-            isinstance(rhs, ArrayProxy)
-            or isinstance(rhs, Mapping)
-            or isinstance(rhs, Sequence)
-            or isinstance(rhs, data.View)
-        )
-
-        if fields is AssignType.COMMON:
-            names = lhs_fields & rhs_fields
-        elif fields is AssignType.LHS:
-            names = lhs_fields
-        elif fields is AssignType.RHS:
-            names = rhs_fields
-        elif fields is AssignType.ALL:
-            names = lhs_fields | rhs_fields
-        else:
-            names = set(fields)
-
-        if not names and (lhs_fields or rhs_fields):
-            raise ValueError("There are no common fields in assignment lhs: {} rhs: {}".format(lhs_fields, rhs_fields))
-
-        for name in names:
-            if name not in lhs_fields:
-                raise KeyError("Field {} not present in lhs".format(name))
-            if name not in rhs_fields:
-                raise KeyError("Field {} not present in rhs".format(name))
-
-            yield from rec_call(name)
-    elif is_union(lhs) and isinstance(rhs, Mapping) or isinstance(lhs, Mapping) and is_union(rhs):
-        mapping, union = (lhs, rhs) if isinstance(lhs, Mapping) else (rhs, lhs)
-
-        # asserts for type checking
-        assert isinstance(mapping, Mapping)
-        assert isinstance(union, data.View)
-
-        if len(mapping) != 1:
-            raise ValueError(f"Non-singleton mapping on union assignment lhs: {lhs} rhs: {rhs}")
-        name = next(iter(mapping))
-
-        if name not in union.shape().members:
-            raise ValueError(f"Field {name} not present in union {union}")
-
-        yield from rec_call(name)
-    else:
-        if not isinstance(fields, AssignType):
-            raise ValueError("Fields on assigning non-structures lhs: {} rhs: {}".format(lhs, rhs))
-        if not isinstance(lhs, ValueLike) or not isinstance(rhs, ValueLike):
-            raise TypeError("Unsupported assignment lhs: {} rhs: {}".format(lhs, rhs))
-
-        # If a single-value structure, assign its only field
-        while lhs_fields is not None and len(lhs_fields) == 1:
-            lhs = lhs[next(iter(lhs_fields))]  # type: ignore
-            lhs_fields = assign_arg_fields(lhs)
-        while rhs_fields is not None and len(rhs_fields) == 1:
-            rhs = rhs[next(iter(rhs_fields))]  # type: ignore
-            rhs_fields = assign_arg_fields(rhs)
-
-        def has_explicit_shape(val: ValueLike):
-            return isinstance(val, (Signal, ArrayProxy, Slice, ValueCastable))
-
-        if (
-            isinstance(lhs, ValueCastable)
-            or isinstance(rhs, ValueCastable)
-            or (lhs_strict or has_explicit_shape(lhs))
-            and (rhs_strict or has_explicit_shape(rhs))
-        ):
-            if valuelike_shape(lhs) != valuelike_shape(rhs):
-                raise ValueError(
-                    "Shapes not matching: lhs: {} {} rhs: {} {}".format(
-                        valuelike_shape(lhs), repr(lhs), valuelike_shape(rhs), repr(rhs)
-                    )
-                )
-
-        lhs_val = Value.cast(lhs)
-        rhs_val = Value.cast(rhs)
-
-        yield lhs_val.eq(rhs_val)
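
A sketch of using `assign` inside `elaborate`, assuming the definitions above; the layout and signal names are invented. Since `assign` yields `Assign` statements, its result can be added to a domain directly:

```python
from amaranth import Module, Signal
from amaranth.lib import data

layout = data.StructLayout({"addr": 8, "data": 16})

m = Module()
src = Signal(layout)
dst = Signal(layout)

# Copy every field present on the right-hand side; mismatched field
# names or widths raise instead of silently truncating.
m.d.comb += assign(dst, src, fields=AssignType.RHS)

# Dicts work as the right-hand side too: assign only the `addr` field.
m.d.comb += assign(dst, {"addr": 0x2A})
```
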
diff --git a/transactron/utils/data_repr.py b/transactron/utils/data_repr.py
deleted file mode 100644
index acd7c7505..000000000
--- a/transactron/utils/data_repr.py
+++ /dev/null
@@ -1,143 +0,0 @@
-from collections.abc import Iterable, Mapping
-from ._typing import ShapeLike, MethodLayout
-from typing import Any, Sized
-from statistics import fmean
-from amaranth.lib.data import StructLayout
-
-
-__all__ = [
-    "make_hashable",
-    "align_to_power_of_two",
-    "align_down_to_power_of_two",
-    "bits_from_int",
-    "layout_subset",
-    "data_layout",
-    "signed_to_int",
-    "int_to_signed",
-    "neg",
-    "average_dict_of_lists",
-]
-
-
-def layout_subset(layout: StructLayout, *, fields: set[str]) -> StructLayout:
-    return StructLayout({item: value for item, value in layout.members.items() if item in fields})
-
-
-def make_hashable(val):
-    if isinstance(val, Mapping):
-        return frozenset(((k, make_hashable(v)) for k, v in val.items()))
-    elif isinstance(val, Iterable):
-        # A tuple (not a generator) is needed here so that equal iterables hash equally.
-        return tuple(make_hashable(v) for v in val)
-    else:
-        return val
-
-
-def align_to_power_of_two(num: int, power: int) -> int:
-    """Rounds up a number to the given power of two.
-
-    Parameters
-    ----------
-    num : int
-        The number to align.
-    power : int
-        The power of two to align to.
-
-    Returns
-    -------
-    int
-        The aligned number.
-    """
-    mask = 2**power - 1
-    if num & mask == 0:
-        return num
-    return (num & ~mask) + 2**power
-
-
-def align_down_to_power_of_two(num: int, power: int) -> int:
-    """Rounds down a number to the given power of two.
-
-    Parameters
-    ----------
-    num : int
-        The number to align.
-    power : int
-        The power of two to align to.
-
-    Returns
-    -------
-    int
-        The aligned number.
-    """
-    mask = 2**power - 1
-
-    return num & ~mask
-
-
-def bits_from_int(num: int, lower: int, length: int):
-    """Returns the [`lower`:`lower`+`length`) bits from integer `num`."""
-    return (num >> lower) & ((1 << (length)) - 1)
-
-
-def data_layout(val: ShapeLike) -> MethodLayout:
-    return [("data", val)]
-
-
-def neg(x: int, xlen: int) -> int:
-    """
-    Computes the negation of a number in the U2 (two's complement) system.
-
-    Parameters
-    ----------
-    x: int
-        Number in the U2 system.
-    xlen : int
-        Bit width of x.
-
-    Returns
-    -------
-    return : int
-        Negation of x in the U2 system.
-    """
-    return (-x) & (2**xlen - 1)
-
-
-def int_to_signed(x: int, xlen: int) -> int:
-    """
-    Converts a Python integer into its U2 representation.
-
-    Parameters
-    ----------
-    x: int
-        Signed Python integer.
-    xlen : int
-        Bit width of x.
-
-    Returns
-    -------
-    return : int
-        Representation of x in the U2 system.
-    """
-    return x & (2**xlen - 1)
-
-
-def signed_to_int(x: int, xlen: int) -> int:
-    """
-    Converts a U2 representation into a Python integer.
-
-    Parameters
-    ----------
-    x: int
-        Number in the U2 system.
-    xlen : int
-        Bit width of x.
-
-    Returns
-    -------
-    return : int
-        Representation of x as a signed Python integer.
-    """
-    return x | -(x & (2 ** (xlen - 1)))
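
The three U2 helpers are mutually consistent, which a few runnable checks make concrete (plain Python, nothing hardware-specific):

```python
assert int_to_signed(-1, 8) == 0xFF
assert signed_to_int(0xFF, 8) == -1
assert neg(1, 8) == 0xFF
assert signed_to_int(int_to_signed(-42, 16), 16) == -42
```
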
-
-
-def average_dict_of_lists(d: Mapping[Any, Sized]) -> float:
-    return fmean(map(lambda xs: len(xs), d.values()))
diff --git a/transactron/utils/debug_signals.py b/transactron/utils/debug_signals.py
deleted file mode 100644
index 4442e4dd4..000000000
--- a/transactron/utils/debug_signals.py
+++ /dev/null
@@ -1,77 +0,0 @@
-from typing import Optional
-from amaranth import *
-from ._typing import SignalBundle, HasDebugSignals
-from collections.abc import Collection, Mapping
-
-
-def auto_debug_signals(thing) -> SignalBundle:
-    """Automatic debug signal generation.
-
-    Exposes class attributes with debug signals (Amaranth `Signal`\\s,
-    `Record`\\s, `Array`\\s, `Elaboratable`\\s, `Method`\\s, and classes
-    which define `debug_signals`). Used for generating ``gtkw`` files in
-    tests, for use in ``gtkwave``.
-    """
-
-    def auto_debug_signals_internal(thing, *, _visited: set) -> Optional[SignalBundle]:
-        # Note that the set `_visited` is used to memorise visited elements
-        # in order to break reference cycles. There is only one instance of this
-        # set for the whole `auto_debug_signals` recursion stack. It is mutated by
-        # adding the ids of visited elements, so that the caller knows what was
-        # visited by the callee.
-        smap: dict[str, SignalBundle] = {}
-
-        # Check for reference cycles, e.g. Amaranth's MustUse
-        if id(thing) in _visited:
-            return None
-        _visited.add(id(thing))
-
-        match thing:
-            case HasDebugSignals():
-                return thing.debug_signals()
-            # avoid infinite recursion (strings are `Collection`s of strings)
-            case str():
-                return None
-            case Collection() | Mapping():
-                match thing:
-                    case Collection():
-                        f_iter = enumerate(thing)
-                    case Mapping():
-                        f_iter = thing.items()
-                for i, e in f_iter:
-                    sublist = auto_debug_signals_internal(e, _visited=_visited)
-                    if sublist is not None:
-                        smap[f"[{i}]"] = sublist
-                if smap:
-                    return smap
-                return None
-            case Array():
-                for i, e in enumerate(thing):
-                    if isinstance(e, Record):
-                        e.name = f"[{i}]"
-                return thing
-            case Signal() | Record():
-                return thing
-            case _:
-                try:
-                    vs = vars(thing)
-                except (KeyError, AttributeError, TypeError):
-                    return None
-
-                for v in vs:
-                    a = getattr(thing, v)
-
-                    # ignore private fields (mostly to ignore _MustUse_context to get a pretty print)
-                    if v[0] == "_":
-                        continue
-
-                    dsignals = auto_debug_signals_internal(a, _visited=_visited)
-                    if dsignals is not None:
-                        smap[v] = dsignals
-                if smap:
-                    return smap
-                return None
-
-    ret = auto_debug_signals_internal(thing, _visited=set())
-    if ret is None:
-        return []
-    return ret
diff --git a/transactron/utils/depcache.py b/transactron/utils/depcache.py
deleted file mode 100644
index 0fbe356c3..000000000
--- a/transactron/utils/depcache.py
+++ /dev/null
@@ -1,44 +0,0 @@
-from typing import TypeVar, Type, Any
-
-from transactron.utils import make_hashable
-
-__all__ = ["DependentCache"]
-
-T = TypeVar("T")
-
-
-class DependentCache:
-    """
-    Cache for classes that depend on the `DependentCache` class itself.
-
-    Cached classes may accept one positional argument in the constructor, where this `DependentCache`
-    instance will be passed. Classes may define any number of keyword arguments in the constructor;
-    a separate cache entry will be created for each set of arguments.
-
-    Methods
-    -------
-    get: T, **kwargs -> T
-        Gets an instance of class `cls` from the cache. Constructs and caches it if this is the
-        first call for it. Optionally accepts `kwargs` for additional arguments in the `cls` constructor.
-
-    """
-
-    def __init__(self):
-        self._depcache: dict[tuple[Type, Any], Type] = {}
-
-    def get(self, cls: Type[T], **kwargs) -> T:
-        cache_key = make_hashable(kwargs)
-        v = self._depcache.get((cls, cache_key), None)
-        if v is None:
-            positional_count = cls.__init__.__code__.co_argcount
-
-            # the first positional arg is the `self` field, the second may be `DependentCache`
-            if positional_count > 2:
-                raise KeyError(f"Too many positional arguments in the {cls!r} constructor")
-
-            if positional_count > 1:
-                v = cls(self, **kwargs)
-            else:
-                v = cls(**kwargs)
-            self._depcache[(cls, cache_key)] = v
-        return v
diff --git a/transactron/utils/dependencies.py b/transactron/utils/dependencies.py
deleted file mode 100644
index 683ff58c2..000000000
--- a/transactron/utils/dependencies.py
+++ /dev/null
@@ -1,164 +0,0 @@
-from collections import defaultdict
-
-from abc import abstractmethod, ABC
-from typing import Any, Generic, TypeVar
-
-
-__all__ = ["DependencyManager", "DependencyKey", "DependencyContext", "SimpleKey", "ListKey"]
-
-T = TypeVar("T")
-U = TypeVar("U")
-
-
-class DependencyKey(Generic[T, U], ABC):
-    """Base class for dependency keys.
-
-    Dependency keys are used to access dependencies in the `DependencyManager`.
-    Concrete instances of dependency keys should be frozen data classes.
-
-    Parameters
-    ----------
-    lock_on_get: bool, default: True
-        If true, the key is locked once it has been read by `get_dependency`;
-        adding new dependencies to a locked key raises an error.
-    cache: bool, default: True
-        If true, the result of the `combine` method is cached and subsequent calls
-        to `get_dependency` will return the cached value. Adding a new dependency
-        clears the cache.
-    empty_valid: bool, default: False
-        Specifies whether reading a key without any added dependencies is valid.
-        If set to `False`, doing so raises a `KeyError`.
-    """
-
-    @abstractmethod
-    def combine(self, data: list[T]) -> U:
-        """Combine multiple dependencies with the same key.
-
-        This method is used to generate the value returned from `get_dependency`
-        in the `DependencyManager`. It takes dependencies added to the key
-        using `add_dependency` and combines them to a single result.
-
-        Different implementations of `combine` give different combining behavior
-        for different kinds of keys.
-        """
-        raise NotImplementedError()
-
-    @abstractmethod
-    def __hash__(self) -> int:
-        """The `__hash__` method is made abstract so that only concrete keys
-        can be instantiated. It is automatically overridden in frozen data
-        classes.
-        """
-        raise NotImplementedError()
-
-    lock_on_get: bool = True
-    cache: bool = True
-    empty_valid: bool = False
-
-
-class SimpleKey(Generic[T], DependencyKey[T, T]):
-    """Base class for simple dependency keys.
-
-    Simple dependency keys are used when there is a one-to-one relation between
-    keys and dependencies. If more than one dependency is added to a simple key,
-    an error is raised.
-
-    Parameters
-    ----------
-    default_value: T
-        Specifies the default value returned when no dependencies are added. To
-        enable it, `empty_valid` must be True.
-    """
-
-    default_value: T
-
-    def combine(self, data: list[T]) -> T:
-        if len(data) == 0:
-            return self.default_value
-        if len(data) != 1:
-            raise RuntimeError(f"Key {self} assigned {len(data)} values, expected 1")
-        return data[0]
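
A minimal sketch of defining and using a concrete key, assuming the classes in this file (`DependencyManager` is defined just below); `ClockFreqKey` and the value are invented. Frozen dataclasses supply the required `__hash__`:

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ClockFreqKey(SimpleKey[int]):
    pass


manager = DependencyManager()
manager.add_dependency(ClockFreqKey(), 100_000_000)
assert manager.get_dependency(ClockFreqKey()) == 100_000_000
```
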
-
-
-class ListKey(Generic[T], DependencyKey[T, list[T]]):
-    """Base class for list keys.
-
-    List keys are used when there is a one-to-many relation between keys
-    and dependencies. Provides a list of dependencies.
-    """
-
-    empty_valid = True
-
-    def combine(self, data: list[T]) -> list[T]:
-        return data
-
-
-class DependencyManager:
-    """Dependency manager.
-
-    Tracks dependencies across the core.
-    """
-
-    def __init__(self):
-        self.dependencies: defaultdict[DependencyKey, list] = defaultdict(list)
-        self.cache: dict[DependencyKey, Any] = {}
-        self.locked_dependencies: set[DependencyKey] = set()
-
-    def add_dependency(self, key: DependencyKey[T, Any], dependency: T) -> None:
-        """Adds a new dependency to a key.
-
-        Depending on the key type, a key can have a single dependency or
-        multiple dependencies added to it.
-        """
-
-        if key in self.locked_dependencies:
-            raise KeyError(f"Trying to add a dependency to {key}, which was already read and is locked")
-
-        self.dependencies[key].append(dependency)
-
-        if key in self.cache:
-            del self.cache[key]
-
-    def get_dependency(self, key: DependencyKey[Any, U]) -> U:
-        """Gets the dependency for a key.
-
-        The way dependencies are interpreted depends on the key type.
-        """
-        if not key.empty_valid and key not in self.dependencies:
-            raise KeyError(f"Dependency {key} not provided")
-
-        if key in self.cache:
-            return self.cache[key]
-
-        if key.lock_on_get:
-            self.locked_dependencies.add(key)
-
-        val = key.combine(self.dependencies[key])
-
-        if key.cache:
-            self.cache[key] = val
-
-        return val
-
-    def dependency_provided(self, key: DependencyKey) -> bool:
-        """Checks if any dependency for a key is provided (ignores the `empty_valid` parameter)."""
-        return key in self.dependencies
-
-
-class DependencyContext:
-    stack: list[DependencyManager] = []
-
-    def __init__(self, manager: DependencyManager):
-        self.manager = manager
-
-    def __enter__(self):
-        self.stack.append(self.manager)
-        return self
-
-    def __exit__(self, exc_type, exc_value, exc_tb):
-        top = self.stack.pop()
-        assert self.manager is top
-
-    @classmethod
-    def get(cls) -> DependencyManager:
-        if not cls.stack:
-            raise RuntimeError("DependencyContext stack is empty")
-        return cls.stack[-1]
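
A short sketch of how `DependencyContext` scopes an active `DependencyManager` (only the two classes above are assumed):

```python
manager = DependencyManager()
with DependencyContext(manager):
    # Anywhere inside the `with` block, code can reach the active manager
    # without it being passed around explicitly.
    assert DependencyContext.get() is manager
```
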
diff --git a/transactron/utils/gen.py b/transactron/utils/gen.py
deleted file mode 100644
index 780e151cd..000000000
--- a/transactron/utils/gen.py
+++ /dev/null
@@ -1,258 +0,0 @@
-from dataclasses import dataclass, field
-from dataclasses_json import dataclass_json
-from typing import Optional, TypeAlias
-
-from amaranth import *
-from amaranth.back import verilog
-from amaranth.hdl import Fragment
-
-from transactron.core import TransactionManager
-from transactron.core.keys import TransactionManagerKey
-from transactron.core.manager import MethodMap
-from transactron.lib.metrics import HardwareMetricsManager
-from transactron.lib import logging
-from transactron.utils.dependencies import DependencyContext
-from transactron.utils.idgen import IdGenerator
-from transactron.utils._typing import AbstractInterface
-from transactron.profiler import ProfileData
-
-from typing import TYPE_CHECKING
-
-if TYPE_CHECKING:
-    from amaranth.hdl._ast import SignalDict
-
-
-__all__ = [
-    "MetricLocation",
-    "GeneratedLog",
-    "GenerationInfo",
-    "generate_verilog",
-]
-
-SignalHandle: TypeAlias = list[str]
-"""The location of a signal is a list of Verilog identifiers that denote a path
-consisting of module names (and the signal name at the end) leading
-to the signal wire."""
-
-
-@dataclass_json
-@dataclass
-class MetricLocation:
-    """Information about the location of a metric in the generated Verilog code.
-
-    Attributes
-    ----------
-    regs : dict[str, SignalHandle]
-        The location of each register of that metric.
-    """
-
-    regs: dict[str, SignalHandle] = field(default_factory=dict)
-
-
-@dataclass_json
-@dataclass
-class TransactionSignalsLocation:
-    """Information about transaction control signals in the generated Verilog code.
-
-    Attributes
-    ----------
-    request: list[str]
-        The location of the ``request`` signal.
-    runnable: list[str]
-        The location of the ``runnable`` signal.
-    grant: list[str]
-        The location of the ``grant`` signal.
-    """
-
-    request: list[str]
-    runnable: list[str]
-    grant: list[str]
-
-
-@dataclass_json
-@dataclass
-class MethodSignalsLocation:
-    """Information about method control signals in the generated Verilog code.
-
-    Attributes
-    ----------
-    run: list[str]
-        The location of the ``run`` signal.
-    """
-
-    run: list[str]
-
-
-@dataclass_json
-@dataclass
-class GeneratedLog(logging.LogRecordInfo):
-    """Information about a log record in the generated Verilog code.
-
-    Attributes
-    ----------
-    trigger_location : SignalHandle
-        The location of the trigger signal.
-    fields_location : list[SignalHandle]
-        Locations of the log fields.
-    """
-
-    trigger_location: SignalHandle
-    fields_location: list[SignalHandle]
-
-
-@dataclass_json
-@dataclass
-class GenerationInfo:
-    """Various information about the generated circuit.
-
-    Attributes
-    ----------
-    metrics_location : dict[str, MetricLocation]
-        Mapping from a metric name to an object storing the Verilog locations
-        of its registers.
-    transaction_signals_location : dict[int, TransactionSignalsLocation]
-        Locations of the control signals of each transaction, keyed by transaction id.
-    method_signals_location : dict[int, MethodSignalsLocation]
-        Locations of the control signals of each method, keyed by method id.
-    profile_data : ProfileData
-        Profile information for the transaction manager.
-    logs : list[GeneratedLog]
-        Locations and metadata for all log records.
-    """
-
-    metrics_location: dict[str, MetricLocation]
-    transaction_signals_location: dict[int, TransactionSignalsLocation]
-    method_signals_location: dict[int, MethodSignalsLocation]
-    profile_data: ProfileData
-    logs: list[GeneratedLog]
-
-    def encode(self, file_name: str):
-        """
-        Encodes the generation information as JSON and saves it to a file.
-        """
-        with open(file_name, "w") as fp:
-            fp.write(self.to_json())  # type: ignore
-
-    @staticmethod
-    def decode(file_name: str) -> "GenerationInfo":
-        """
-        Loads the generation information from a JSON file.
-        """
-        with open(file_name, "r") as fp:
-            return GenerationInfo.from_json(fp.read())  # type: ignore
-
-
-def escape_verilog_identifier(identifier: str) -> str:
-    """
-    Escapes a Verilog identifier according to the language standard.
-
-    From IEEE Std 1364-2001 (IEEE Standard Verilog® Hardware Description Language):
-
-    "2.7.1 Escaped identifiers
-
-    Escaped identifiers shall start with the backslash character and end with white
-    space (space, tab, newline). They provide a means of including any of the printable ASCII
-    characters in an identifier (the decimal values 33 through 126, or 21 through 7E in hexadecimal)."
-    """
-
-    # The standard says how to escape an identifier, but not when. So this is
-    # a non-exhaustive list of characters that Yosys escapes (it is used
-    # by Amaranth when generating Verilog code).
-    characters_to_escape = [".", "$", "-"]
-
-    for char in characters_to_escape:
-        if char in identifier:
-            return f"\\{identifier} "
-
-    return identifier
-
-
-def get_signal_location(signal: Signal, name_map: "SignalDict") -> SignalHandle:
-    raw_location = name_map[signal]
-    return raw_location
-
-
-def collect_metric_locations(name_map: "SignalDict") -> dict[str, MetricLocation]:
-    metrics_location: dict[str, MetricLocation] = {}
-
-    # Collect information about the location of metric registers in the generated code.
- metrics_manager = HardwareMetricsManager() - for metric_name, metric in metrics_manager.get_metrics().items(): - metric_loc = MetricLocation() - for reg_name in metric.regs: - metric_loc.regs[reg_name] = get_signal_location( - metrics_manager.get_register_value(metric_name, reg_name), name_map - ) - - metrics_location[metric_name] = metric_loc - - return metrics_location - - -def collect_transaction_method_signals( - transaction_manager: TransactionManager, name_map: "SignalDict" -) -> tuple[dict[int, TransactionSignalsLocation], dict[int, MethodSignalsLocation]]: - transaction_signals_location: dict[int, TransactionSignalsLocation] = {} - method_signals_location: dict[int, MethodSignalsLocation] = {} - - method_map = MethodMap(transaction_manager.transactions) - get_id = IdGenerator() - - for transaction in method_map.transactions: - request_loc = get_signal_location(transaction.request, name_map) - runnable_loc = get_signal_location(transaction.runnable, name_map) - grant_loc = get_signal_location(transaction.grant, name_map) - transaction_signals_location[get_id(transaction)] = TransactionSignalsLocation( - request_loc, runnable_loc, grant_loc - ) - - for method in method_map.methods: - run_loc = get_signal_location(method.run, name_map) - method_signals_location[get_id(method)] = MethodSignalsLocation(run_loc) - - return (transaction_signals_location, method_signals_location) - - -def collect_logs(name_map: "SignalDict") -> list[GeneratedLog]: - logs: list[GeneratedLog] = [] - - # Get all records. - for record in logging.get_log_records(0): - trigger_loc = get_signal_location(record.trigger, name_map) - fields_loc = [get_signal_location(field, name_map) for field in record.fields] - log = GeneratedLog( - logger_name=record.logger_name, - level=record.level, - format_str=record.format_str, - location=record.location, - trigger_location=trigger_loc, - fields_location=fields_loc, - ) - logs.append(log) - - return logs - - -def generate_verilog( - elaboratable: Elaboratable, ports: Optional[list[Value]] = None, top_name: str = "top" -) -> tuple[str, GenerationInfo]: - # The ports logic is copied (and simplified) from amaranth.back.verilog.convert. - # Unfortunately, the convert function doesn't return the name map. 
- if ports is None and isinstance(elaboratable, AbstractInterface): - ports = [] - for _, _, value in elaboratable.signature.flatten(elaboratable): - ports.append(Value.cast(value)) - elif ports is None: - raise TypeError("The `generate_verilog()` function requires a `ports=` argument") - - fragment = Fragment.get(elaboratable, platform=None).prepare(ports=ports) - verilog_text, name_map = verilog.convert_fragment(fragment, name=top_name, emit_src=True, strip_internal_attrs=True) - - transaction_manager = DependencyContext.get().get_dependency(TransactionManagerKey()) - transaction_signals, method_signals = collect_transaction_method_signals( - transaction_manager, name_map # type: ignore - ) - profile_data, _ = ProfileData.make(transaction_manager) - gen_info = GenerationInfo( - metrics_location=collect_metric_locations(name_map), # type: ignore - transaction_signals_location=transaction_signals, - method_signals_location=method_signals, - profile_data=profile_data, - logs=collect_logs(name_map), - ) - - return verilog_text, gen_info diff --git a/transactron/utils/idgen.py b/transactron/utils/idgen.py deleted file mode 100644 index 459f3160e..000000000 --- a/transactron/utils/idgen.py +++ /dev/null @@ -1,15 +0,0 @@ -__all__ = ["IdGenerator"] - - -class IdGenerator: - def __init__(self): - self.id_map = dict[int, int]() - self.id_seq = 0 - - def __call__(self, obj): - try: - return self.id_map[id(obj)] - except KeyError: - self.id_seq += 1 - self.id_map[id(obj)] = self.id_seq - return self.id_seq diff --git a/transactron/utils/transactron_helpers.py b/transactron/utils/transactron_helpers.py deleted file mode 100644 index 9cb23cd17..000000000 --- a/transactron/utils/transactron_helpers.py +++ /dev/null @@ -1,169 +0,0 @@ -import sys -from contextlib import contextmanager -from typing import Optional, Any, Concatenate, TypeGuard, TypeVar -from collections.abc import Callable, Mapping, Sequence -from ._typing import ROGraph, GraphCC, SrcLoc, MethodLayout, MethodStruct, ShapeLike, LayoutList, LayoutListField -from inspect import Parameter, signature -from itertools import count -from amaranth import * -from amaranth import tracer -from amaranth.lib.data import StructLayout -import amaranth.lib.data as data - - -__all__ = [ - "longest_common_prefix", - "silence_mustuse", - "get_caller_class_name", - "def_helper", - "method_def_helper", - "mock_def_helper", - "async_mock_def_helper", - "get_src_loc", - "from_method_layout", - "make_layout", - "extend_layout", -] - -T = TypeVar("T") -U = TypeVar("U") - - -def _graph_ccs(gr: ROGraph[T]) -> list[GraphCC[T]]: - """_graph_ccs - - Find connected components in a graph. - - Parameters - ---------- - gr : Mapping[T, Iterable[T]] - Graph in which we should find connected components. Encoded using - adjacency lists. - - Returns - ------- - ccs : List[Set[T]] - Connected components of the graph `gr`. 
- """ - ccs = [] - cc = set() - visited = set() - - for v in gr.keys(): - q = [v] - while q: - w = q.pop() - if w in visited: - continue - visited.add(w) - cc.add(w) - q.extend(gr[w]) - if cc: - ccs.append(cc) - cc = set() - - return ccs - - -def longest_common_prefix(*seqs: Sequence[T]) -> Sequence[T]: - if not seqs: - raise ValueError("no arguments") - for i, letter_group in enumerate(zip(*seqs)): - if len(set(letter_group)) > 1: - return seqs[0][:i] - return min(seqs, key=lambda s: len(s)) - - -def has_first_param(func: Callable[..., T], name: str, tp: type[U]) -> TypeGuard[Callable[Concatenate[U, ...], T]]: - parameters = signature(func).parameters - return ( - len(parameters) >= 1 - and next(iter(parameters)) == name - and parameters[name].kind in {Parameter.POSITIONAL_OR_KEYWORD, Parameter.POSITIONAL_ONLY} - and parameters[name].annotation in {Parameter.empty, tp} - ) - - -def def_helper(description, func: Callable[..., T], tp: type[U], arg: U, /, **kwargs) -> T: - try: - parameters = signature(func).parameters - except ValueError: - raise TypeError(f"Invalid python method signature for {func} (missing `self` for class-level mock?)") - - kw_parameters = set( - n for n, p in parameters.items() if p.kind in {Parameter.POSITIONAL_OR_KEYWORD, Parameter.KEYWORD_ONLY} - ) - if len(parameters) == 1 and has_first_param(func, "arg", tp): - return func(arg) - elif kw_parameters <= kwargs.keys(): - return func(**kwargs) - else: - raise TypeError(f"Invalid {description}: {func}") - - -def mock_def_helper(tb, func: Callable[..., T], arg: Mapping[str, Any]) -> T: - return def_helper(f"mock definition for {tb}", func, Mapping[str, Any], arg, **arg) - - -def async_mock_def_helper(tb, func: Callable[..., T], arg: "data.Const[StructLayout]") -> T: - marg = {} - for k, _ in arg.shape(): - marg[k] = arg[k] - return def_helper(f"mock definition for {tb}", func, Mapping[str, Any], marg, **marg) - - -def method_def_helper(method, func: Callable[..., T], arg: MethodStruct) -> T: - kwargs = {k: arg[k] for k in arg.shape().members} - return def_helper(f"method definition for {method}", func, MethodStruct, arg, **kwargs) - - -def get_caller_class_name(default: Optional[str] = None) -> tuple[Optional[Elaboratable], str]: - try: - for d in count(2): - caller_frame = sys._getframe(d) - if "self" in caller_frame.f_locals: - owner = caller_frame.f_locals["self"] - if isinstance(owner, Elaboratable): - return owner, owner.__class__.__name__ - except ValueError: - pass - - if default is not None: - return None, default - else: - raise RuntimeError("Not called from a method") - - -@contextmanager -def silence_mustuse(elaboratable: Elaboratable): - try: - yield - except Exception: - elaboratable._MustUse__silence = True # type: ignore - raise - - -def get_src_loc(src_loc: int | SrcLoc) -> SrcLoc: - return tracer.get_src_loc(1 + src_loc) if isinstance(src_loc, int) else src_loc - - -def from_layout_field(shape: ShapeLike | LayoutList) -> ShapeLike: - if isinstance(shape, list): - return from_method_layout(shape) - else: - return shape - - -def make_layout(*fields: LayoutListField) -> StructLayout: - return from_method_layout(fields) - - -def extend_layout(layout: StructLayout, *fields: LayoutListField) -> StructLayout: - return StructLayout(layout.members | from_method_layout(fields).members) - - -def from_method_layout(layout: MethodLayout) -> StructLayout: - if isinstance(layout, StructLayout): - return layout - else: - return StructLayout({k: from_layout_field(v) for k, v in layout})
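
To make the layout helpers at the end of `transactron_helpers.py` concrete, a small runnable sketch (field names invented):

```python
from amaranth.lib.data import StructLayout

# Build a method layout from (name, shape) tuples; both calls are equivalent.
layout = make_layout(("addr", 8), ("data", 16))
assert isinstance(layout, StructLayout)
assert layout == from_method_layout([("addr", 8), ("data", 16)])

# extend_layout adds fields to an existing StructLayout non-destructively.
wider = extend_layout(layout, ("valid", 1))
assert "valid" in wider.members
```
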