diff --git a/test/common/functions.py b/test/common/functions.py index 1f36e69fe..3c37aec61 100644 --- a/test/common/functions.py +++ b/test/common/functions.py @@ -1,9 +1,18 @@ from amaranth import * +from typing import TYPE_CHECKING, Generator, Any, TypeAlias, TypeVar, Union from coreblocks.utils._typing import RecordValueDict, RecordIntDict -from .infrastructure import TestGen +from amaranth.hdl.ast import Statement +from amaranth.sim.core import Command +if TYPE_CHECKING: + from .infrastructure import CoreblockCommand -def set_inputs(values: RecordValueDict, field: Record) -> TestGen[None]: + +T = TypeVar("T") +TestGen: TypeAlias = Generator[Union[Command, Value, Statement, "CoreblockCommand", None], Any, T] + + +def set_inputs(values: RecordValueDict, field: Record) -> "TestGen[None]": for name, value in values.items(): if isinstance(value, dict): yield from set_inputs(value, getattr(field, name)) @@ -11,7 +20,7 @@ def set_inputs(values: RecordValueDict, field: Record) -> TestGen[None]: yield getattr(field, name).eq(value) -def get_outputs(field: Record) -> TestGen[RecordIntDict]: +def get_outputs(field: Record) -> "TestGen[RecordIntDict]": # return dict of all signal values in a record because amaranth's simulator can't read all # values of a Record in a single yield - it can only read Values (Signals) result = {} diff --git a/test/common/infrastructure.py b/test/common/infrastructure.py index 6c0a1a271..4eb9e3eac 100644 --- a/test/common/infrastructure.py +++ b/test/common/infrastructure.py @@ -3,12 +3,11 @@ import unittest import functools from contextlib import contextmanager, nullcontext -from typing import TypeVar, Generic, Type, TypeGuard, Any, Union, Callable, cast, Generator, TypeAlias +from typing import TypeVar, Generic, Type, TypeGuard, Any, Union, Callable, cast, TypeAlias from amaranth import * from amaranth.sim import * -from amaranth.hdl.ast import Statement -from amaranth.sim.core import Command from .testbenchio import TestbenchIO +from .functions import TestGen from ..gtkw_extension import write_vcd_ext from transactron import Method from transactron.lib import AdapterTrans @@ -16,8 +15,7 @@ from coreblocks.utils import ModuleConnector, HasElaborate, auto_debug_signals, HasDebugSignals T = TypeVar("T") -_T_nested_collection = T | list["_T_nested_collection[T]"] | dict[str, "_T_nested_collection[T]"] -TestGen: TypeAlias = Generator[Command | Value | Statement | "CoreblockCommand" | None, Any, T] +_T_nested_collection: TypeAlias = T | list["_T_nested_collection[T]"] | dict[str, "_T_nested_collection[T]"] def guard_nested_collection(cont: Any, t: Type[T]) -> TypeGuard[_T_nested_collection[T]]: