diff --git a/transactron/core/manager.py b/transactron/core/manager.py index 5a7b165..20b10a8 100644 --- a/transactron/core/manager.py +++ b/transactron/core/manager.py @@ -3,6 +3,7 @@ from os import environ from graphlib import TopologicalSorter from amaranth import * +from amaranth.lib.wiring import Component, connect, flipped from itertools import chain, filterfalse, product from transactron.utils import * @@ -15,7 +16,7 @@ from .tmodule import TModule from .schedulers import eager_deterministic_cc_scheduler -__all__ = ["TransactionManager", "TransactionModule"] +__all__ = ["TransactionManager", "TransactionModule", "TransactionComponent"] TransactionGraph: TypeAlias = Graph["Transaction"] TransactionGraphCC: TypeAlias = GraphCC["Transaction"] @@ -479,3 +480,46 @@ def elaborate(self, platform): ) return m + + +class TransactionComponent(TransactionModule, Component): + """Top-level component for Transactron projects. + + The `TransactronComponent` is a wrapper on `Component` classes, + which adds Transactron support for the wrapped class. The use + case is to wrap a top-level module of the project, and pass the + wrapped module for simulation, HDL generation or synthesis. + The ports of the wrapped component are forwarded to the wrapper. + + It extends the functionality of `TransactionModule`. + """ + + def __init__( + self, + component: Component, + dependency_manager: Optional[DependencyManager] = None, + transaction_manager: Optional[TransactionManager] = None, + ): + """ + Parameters + ---------- + component: Component + The `Component` which should be wrapped to add support for + transactions and methods. + dependency_manager: DependencyManager, optional + The `DependencyManager` to use inside the transaction component. + If omitted, a new one is created. + transaction_manager: TransactionManager, optional + The `TransactionManager` to use inside the transaction component. + If omitted, a new one is created. + """ + TransactionModule.__init__(self, component, dependency_manager, transaction_manager) + Component.__init__(self, component.signature) + + def elaborate(self, platform): + m = super().elaborate(platform) + + for name in self.signature.members: + connect(m, flipped(getattr(self, name)), getattr(self.elaboratable, name)) + + return m diff --git a/transactron/lib/adapters.py b/transactron/lib/adapters.py index 0d76796..611ef4f 100644 --- a/transactron/lib/adapters.py +++ b/transactron/lib/adapters.py @@ -1,5 +1,6 @@ from typing import Optional from amaranth import * +from amaranth.lib.wiring import Component, In, Out from amaranth.lib.data import StructLayout, View from ..utils import SrcLoc, get_src_loc, MethodStruct @@ -13,14 +14,15 @@ ] -class AdapterBase(Elaboratable): +class AdapterBase(Component): data_in: MethodStruct data_out: MethodStruct + en: Signal + done: Signal - def __init__(self, iface: Method): + def __init__(self, iface: Method, layout_in: StructLayout, layout_out: StructLayout): + super().__init__({"data_in": In(layout_in), "data_out": Out(layout_out), "en": In(1), "done": Out(1)}) self.iface = iface - self.en = Signal() - self.done = Signal() def debug_signals(self) -> SignalBundle: return [self.en, self.done, self.data_in, self.data_out] @@ -55,10 +57,8 @@ def __init__(self, iface: Method, *, src_loc: int | SrcLoc = 0): How many stack frames deep the source location is taken from. Alternatively, the source location to use instead of the default. """ - super().__init__(iface) + super().__init__(iface, iface.layout_in, iface.layout_out) self.src_loc = get_src_loc(src_loc) - self.data_in = Signal.like(iface.data_in) - self.data_out = Signal.like(iface.data_out) def elaborate(self, platform): m = TModule() @@ -107,9 +107,8 @@ def __init__(self, **kwargs): kwargs["src_loc"] = get_src_loc(kwargs.setdefault("src_loc", 0)) - super().__init__(Method(**kwargs)) - self.data_in = Signal.like(self.iface.data_out) - self.data_out = Signal.like(self.iface.data_in) + iface = Method(**kwargs) + super().__init__(iface, iface.layout_out, iface.layout_in) self.validators: list[tuple[View[StructLayout], Signal]] = [] self.with_validate_arguments: bool = False diff --git a/transactron/utils/_typing.py b/transactron/utils/_typing.py index e8e3152..2be6f38 100644 --- a/transactron/utils/_typing.py +++ b/transactron/utils/_typing.py @@ -178,11 +178,13 @@ def create( def __repr__(self) -> str: ... -_T_AbstractSignature = TypeVar("_T_AbstractSignature", bound=AbstractSignature) +_T_AbstractSignature = TypeVar("_T_AbstractSignature", bound=AbstractSignature, covariant=True) +@runtime_checkable class AbstractInterface(Protocol, Generic[_T_AbstractSignature]): - signature: _T_AbstractSignature + @property + def signature(self) -> _T_AbstractSignature: ... class HasElaborate(Protocol): diff --git a/transactron/utils/gen.py b/transactron/utils/gen.py index 6837419..780e151 100644 --- a/transactron/utils/gen.py +++ b/transactron/utils/gen.py @@ -1,6 +1,6 @@ from dataclasses import dataclass, field from dataclasses_json import dataclass_json -from typing import TypeAlias +from typing import Optional, TypeAlias from amaranth import * from amaranth.back import verilog @@ -13,6 +13,7 @@ from transactron.lib import logging from transactron.utils.dependencies import DependencyContext from transactron.utils.idgen import IdGenerator +from transactron.utils._typing import AbstractInterface from transactron.profiler import ProfileData from typing import TYPE_CHECKING @@ -227,9 +228,18 @@ def collect_logs(name_map: "SignalDict") -> list[GeneratedLog]: def generate_verilog( - top_module: Elaboratable, ports: list[Signal], top_name: str = "top" + elaboratable: Elaboratable, ports: Optional[list[Value]] = None, top_name: str = "top" ) -> tuple[str, GenerationInfo]: - fragment = Fragment.get(top_module, platform=None).prepare(ports=ports) + # The ports logic is copied (and simplified) from amaranth.back.verilog.convert. + # Unfortunately, the convert function doesn't return the name map. + if ports is None and isinstance(elaboratable, AbstractInterface): + ports = [] + for _, _, value in elaboratable.signature.flatten(elaboratable): + ports.append(Value.cast(value)) + elif ports is None: + raise TypeError("The `generate_verilog()` function requires a `ports=` argument") + + fragment = Fragment.get(elaboratable, platform=None).prepare(ports=ports) verilog_text, name_map = verilog.convert_fragment(fragment, name=top_name, emit_src=True, strip_internal_attrs=True) transaction_manager = DependencyContext.get().get_dependency(TransactionManagerKey())