diff --git a/test/common/_test/__init__.py b/test/common/_test/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/test/common/_test/test_infrastructure.py b/test/common/_test/test_infrastructure.py new file mode 100644 index 000000000..ecf1c84d9 --- /dev/null +++ b/test/common/_test/test_infrastructure.py @@ -0,0 +1,31 @@ +from amaranth import * +from test.common import * + + +class EmptyCircuit(Elaboratable): + def __init__(self): + pass + + def elaborate(self, platform): + m = Module() + return m + + +class TestNow(TestCaseWithSimulator): + def setUp(self): + self.test_cycles = 10 + self.m = SimpleTestCircuit(EmptyCircuit()) + + def process(self): + for k in range(self.test_cycles): + now = yield Now() + assert k == now + # check if second call don't change the returned value + now = yield Now() + assert k == now + + yield + + def test_random(self): + with self.run_simulation(self.m, 50) as sim: + sim.add_sync_process(self.process) diff --git a/test/common/functions.py b/test/common/functions.py index 42b43cdf1..eb7abf886 100644 --- a/test/common/functions.py +++ b/test/common/functions.py @@ -1,11 +1,16 @@ from amaranth import * from amaranth.hdl.ast import Statement from amaranth.sim.core import Command -from typing import TypeVar, Any, Generator, TypeAlias +from typing import TypeVar, Any, Generator, TypeAlias, TYPE_CHECKING, Union from transactron.utils._typing import RecordValueDict, RecordIntDict + +if TYPE_CHECKING: + from .infrastructure import CoreblocksCommand + + T = TypeVar("T") -TestGen: TypeAlias = Generator[Command | Value | Statement | None, Any, T] +TestGen: TypeAlias = Generator[Union[Command, Value, Statement, "CoreblocksCommand", None], Any, T] def set_inputs(values: RecordValueDict, field: Record) -> TestGen[None]: diff --git a/test/common/infrastructure.py b/test/common/infrastructure.py index e1141a0eb..058d5b9ed 100644 --- a/test/common/infrastructure.py +++ b/test/common/infrastructure.py @@ -3,10 +3,12 @@ import unittest import functools from contextlib import contextmanager, nullcontext -from typing import TypeVar, Generic, Type, TypeGuard, Any, Union, Callable, cast +from typing import TypeVar, Generic, Type, TypeGuard, Any, Union, Callable, cast, TypeAlias +from abc import ABC from amaranth import * from amaranth.sim import * from .testbenchio import TestbenchIO +from .functions import TestGen from ..gtkw_extension import write_vcd_ext from transactron import Method from transactron.lib import AdapterTrans @@ -14,7 +16,7 @@ from transactron.utils import ModuleConnector, HasElaborate, auto_debug_signals, HasDebugSignals T = TypeVar("T") -_T_nested_collection = T | list["_T_nested_collection[T]"] | dict[str, "_T_nested_collection[T]"] +_T_nested_collection: TypeAlias = T | list["_T_nested_collection[T]"] | dict[str, "_T_nested_collection[T]"] def guard_nested_collection(cont: Any, t: Type[T]) -> TypeGuard[_T_nested_collection[T]]: @@ -99,6 +101,40 @@ def elaborate(self, platform) -> HasElaborate: return m +class CoreblocksCommand(ABC): + pass + + +class Now(CoreblocksCommand): + pass + + +class SyncProcessWrapper: + def __init__(self, f): + self.org_process = f + self.current_cycle = 0 + + def _wrapping_function(self): + response = None + org_coroutine = self.org_process() + try: + while True: + # call orginal test process and catch data yielded by it in `command` variable + command = org_coroutine.send(response) + # If process wait for new cycle + if command is None: + self.current_cycle += 1 + # forward to amaranth + yield + elif isinstance(command, Now): + response = self.current_cycle + # Pass everything else to amaranth simulator without modifications + else: + response = yield command + except StopIteration: + pass + + class PysimSimulator(Simulator): def __init__(self, module: HasElaborate, max_cycles: float = 10e4, add_transaction_module=True, traces_file=None): test_module = TestModule(module, add_transaction_module) @@ -133,6 +169,10 @@ def __init__(self, module: HasElaborate, max_cycles: float = 10e4, add_transacti self.deadline = clk_period * max_cycles + def add_sync_process(self, f: Callable[[], TestGen]): + f_wrapped = SyncProcessWrapper(f) + super().add_sync_process(f_wrapped._wrapping_function) + def run(self) -> bool: with self.ctx: self.run_until(self.deadline)