Skip to content

Commit

Permalink
Components at top level in Transactron (kuznia-rdzeni/coreblocks#708)
Browse files Browse the repository at this point in the history
  • Loading branch information
tilk authored May 24, 2024
1 parent 1483b7f commit b509086
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 16 deletions.
46 changes: 45 additions & 1 deletion transactron/core/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from os import environ
from graphlib import TopologicalSorter
from amaranth import *
from amaranth.lib.wiring import Component, connect, flipped
from itertools import chain, filterfalse, product

from transactron.utils import *
Expand All @@ -15,7 +16,7 @@
from .tmodule import TModule
from .schedulers import eager_deterministic_cc_scheduler

__all__ = ["TransactionManager", "TransactionModule"]
__all__ = ["TransactionManager", "TransactionModule", "TransactionComponent"]

TransactionGraph: TypeAlias = Graph["Transaction"]
TransactionGraphCC: TypeAlias = GraphCC["Transaction"]
Expand Down Expand Up @@ -479,3 +480,46 @@ def elaborate(self, platform):
)

return m


class TransactionComponent(TransactionModule, Component):
"""Top-level component for Transactron projects.
The `TransactronComponent` is a wrapper on `Component` classes,
which adds Transactron support for the wrapped class. The use
case is to wrap a top-level module of the project, and pass the
wrapped module for simulation, HDL generation or synthesis.
The ports of the wrapped component are forwarded to the wrapper.
It extends the functionality of `TransactionModule`.
"""

def __init__(
self,
component: Component,
dependency_manager: Optional[DependencyManager] = None,
transaction_manager: Optional[TransactionManager] = None,
):
"""
Parameters
----------
component: Component
The `Component` which should be wrapped to add support for
transactions and methods.
dependency_manager: DependencyManager, optional
The `DependencyManager` to use inside the transaction component.
If omitted, a new one is created.
transaction_manager: TransactionManager, optional
The `TransactionManager` to use inside the transaction component.
If omitted, a new one is created.
"""
TransactionModule.__init__(self, component, dependency_manager, transaction_manager)
Component.__init__(self, component.signature)

def elaborate(self, platform):
m = super().elaborate(platform)

for name in self.signature.members:
connect(m, flipped(getattr(self, name)), getattr(self.elaboratable, name))

return m
19 changes: 9 additions & 10 deletions transactron/lib/adapters.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Optional
from amaranth import *
from amaranth.lib.wiring import Component, In, Out
from amaranth.lib.data import StructLayout, View

from ..utils import SrcLoc, get_src_loc, MethodStruct
Expand All @@ -13,14 +14,15 @@
]


class AdapterBase(Elaboratable):
class AdapterBase(Component):
data_in: MethodStruct
data_out: MethodStruct
en: Signal
done: Signal

def __init__(self, iface: Method):
def __init__(self, iface: Method, layout_in: StructLayout, layout_out: StructLayout):
super().__init__({"data_in": In(layout_in), "data_out": Out(layout_out), "en": In(1), "done": Out(1)})
self.iface = iface
self.en = Signal()
self.done = Signal()

def debug_signals(self) -> SignalBundle:
return [self.en, self.done, self.data_in, self.data_out]
Expand Down Expand Up @@ -55,10 +57,8 @@ def __init__(self, iface: Method, *, src_loc: int | SrcLoc = 0):
How many stack frames deep the source location is taken from.
Alternatively, the source location to use instead of the default.
"""
super().__init__(iface)
super().__init__(iface, iface.layout_in, iface.layout_out)
self.src_loc = get_src_loc(src_loc)
self.data_in = Signal.like(iface.data_in)
self.data_out = Signal.like(iface.data_out)

def elaborate(self, platform):
m = TModule()
Expand Down Expand Up @@ -107,9 +107,8 @@ def __init__(self, **kwargs):

kwargs["src_loc"] = get_src_loc(kwargs.setdefault("src_loc", 0))

super().__init__(Method(**kwargs))
self.data_in = Signal.like(self.iface.data_out)
self.data_out = Signal.like(self.iface.data_in)
iface = Method(**kwargs)
super().__init__(iface, iface.layout_out, iface.layout_in)
self.validators: list[tuple[View[StructLayout], Signal]] = []
self.with_validate_arguments: bool = False

Expand Down
6 changes: 4 additions & 2 deletions transactron/utils/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,11 +178,13 @@ def create(
def __repr__(self) -> str: ...


_T_AbstractSignature = TypeVar("_T_AbstractSignature", bound=AbstractSignature)
_T_AbstractSignature = TypeVar("_T_AbstractSignature", bound=AbstractSignature, covariant=True)


@runtime_checkable
class AbstractInterface(Protocol, Generic[_T_AbstractSignature]):
signature: _T_AbstractSignature
@property
def signature(self) -> _T_AbstractSignature: ...


class HasElaborate(Protocol):
Expand Down
16 changes: 13 additions & 3 deletions transactron/utils/gen.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from dataclasses import dataclass, field
from dataclasses_json import dataclass_json
from typing import TypeAlias
from typing import Optional, TypeAlias

from amaranth import *
from amaranth.back import verilog
Expand All @@ -13,6 +13,7 @@
from transactron.lib import logging
from transactron.utils.dependencies import DependencyContext
from transactron.utils.idgen import IdGenerator
from transactron.utils._typing import AbstractInterface
from transactron.profiler import ProfileData

from typing import TYPE_CHECKING
Expand Down Expand Up @@ -227,9 +228,18 @@ def collect_logs(name_map: "SignalDict") -> list[GeneratedLog]:


def generate_verilog(
top_module: Elaboratable, ports: list[Signal], top_name: str = "top"
elaboratable: Elaboratable, ports: Optional[list[Value]] = None, top_name: str = "top"
) -> tuple[str, GenerationInfo]:
fragment = Fragment.get(top_module, platform=None).prepare(ports=ports)
# The ports logic is copied (and simplified) from amaranth.back.verilog.convert.
# Unfortunately, the convert function doesn't return the name map.
if ports is None and isinstance(elaboratable, AbstractInterface):
ports = []
for _, _, value in elaboratable.signature.flatten(elaboratable):
ports.append(Value.cast(value))
elif ports is None:
raise TypeError("The `generate_verilog()` function requires a `ports=` argument")

fragment = Fragment.get(elaboratable, platform=None).prepare(ports=ports)
verilog_text, name_map = verilog.convert_fragment(fragment, name=top_name, emit_src=True, strip_internal_attrs=True)

transaction_manager = DependencyContext.get().get_dependency(TransactionManagerKey())
Expand Down

0 comments on commit b509086

Please sign in to comment.