From 7d2222e9637c6dc1339a7cc710f38b212078e393 Mon Sep 17 00:00:00 2001 From: Marek Materzok Date: Mon, 5 Feb 2024 11:40:44 +0100 Subject: [PATCH] Use DependencyManager as the global object (#581) --- scripts/core_graph.py | 6 ++-- test/transactions/test_transactions.py | 6 ++-- transactron/__init__.py | 1 - transactron/core.py | 40 +++++++++----------------- transactron/testing/infrastructure.py | 6 ++-- transactron/utils/dependencies.py | 22 ++++++++++++++ 6 files changed, 45 insertions(+), 36 deletions(-) diff --git a/scripts/core_graph.py b/scripts/core_graph.py index a589c205a..6818f6dd0 100755 --- a/scripts/core_graph.py +++ b/scripts/core_graph.py @@ -17,7 +17,7 @@ from transactron.graph import TracingFragment # noqa: E402 from test.test_core import CoreTestElaboratable # noqa: E402 from coreblocks.params.configurations import basic_core_config # noqa: E402 -from transactron.core import TransactionModule # noqa: E402 +from transactron.core import TransactionManagerKey, TransactionModule # noqa: E402 gp = GenParams(basic_core_config) elaboratable = CoreTestElaboratable(gp) @@ -25,10 +25,10 @@ fragment = TracingFragment.get(tm, platform=None).prepare() core = fragment -while not hasattr(core, "transactionManager"): +while not hasattr(core, "manager"): core = core._tracing_original # type: ignore -mgr = core.transactionManager # type: ignore +mgr = core.manager.get_dependency(TransactionManagerKey()) # type: ignore with arg.ofile as fp: graph = mgr.visual_graph(fragment) diff --git a/test/transactions/test_transactions.py b/test/transactions/test_transactions.py index 9c7680e26..fe20edbee 100644 --- a/test/transactions/test_transactions.py +++ b/test/transactions/test_transactions.py @@ -317,7 +317,7 @@ def elaborate(self, platform): m = TModule() tm = TransactionModule(m) - with tm.transaction_context(): + with tm.context(): with Transaction().body(m, request=self.r1): m.d.comb += self.t1.eq(1) with Transaction().body(m, request=self.r2): @@ -342,7 +342,7 @@ def _(): def _(): m.d.comb += self.t2.eq(1) - with tm.transaction_context(): + with tm.context(): with Transaction().body(m): method1(m) @@ -389,7 +389,7 @@ def elaborate(self, platform): def _(): pass - with tm.transaction_context(): + with tm.context(): with (t1 := Transaction()).body(m, request=self.r1): method(m) m.d.comb += self.t1.eq(1) diff --git a/transactron/__init__.py b/transactron/__init__.py index ce1898da3..de27375ac 100644 --- a/transactron/__init__.py +++ b/transactron/__init__.py @@ -3,7 +3,6 @@ __all__ = [ "TModule", "TransactionManager", - "TransactionContext", "TransactionModule", "Transaction", "Method", diff --git a/transactron/core.py b/transactron/core.py index 6ef868c9a..598c97fd4 100644 --- a/transactron/core.py +++ b/transactron/core.py @@ -34,7 +34,7 @@ "Priority", "TModule", "TransactionManager", - "TransactionContext", + "TransactionManagerKey", "TransactionModule", "Transaction", "Method", @@ -553,25 +553,9 @@ def method_debug(m: Method): } -class TransactionContext: - stack: list[TransactionManager] = [] - - def __init__(self, manager: TransactionManager): - self.manager = manager - - def __enter__(self): - self.stack.append(self.manager) - return self - - def __exit__(self, exc_type, exc_value, exc_tb): - top = self.stack.pop() - assert self.manager is top - - @classmethod - def get(cls) -> TransactionManager: - if not cls.stack: - raise RuntimeError("TransactionContext stack is empty") - return cls.stack[-1] +@dataclass(frozen=True) +class TransactionManagerKey(SimpleKey[TransactionManager]): + pass class TransactionModule(Elaboratable): @@ -580,6 +564,7 @@ class TransactionModule(Elaboratable): which adds support for transactions. It creates a `TransactionManager` which will handle transaction scheduling and can be used in definition of `Method`\\s and `Transaction`\\s. + The `TransactionManager` is stored in a `DependencyManager`. """ def __init__(self, elaboratable: HasElaborate, manager: Optional[TransactionManager] = None): @@ -592,21 +577,22 @@ def __init__(self, elaboratable: HasElaborate, manager: Optional[TransactionMana """ if manager is None: manager = TransactionManager() - self.transactionManager = manager + self.manager = DependencyManager() + self.manager.add_dependency(TransactionManagerKey(), manager) self.elaboratable = elaboratable - def transaction_context(self) -> TransactionContext: - return TransactionContext(self.transactionManager) + def context(self) -> DependencyContext: + return DependencyContext(self.manager) def elaborate(self, platform): - with silence_mustuse(self.transactionManager): - with self.transaction_context(): + with silence_mustuse(self.manager.get_dependency(TransactionManagerKey())): + with self.context(): elaboratable = Fragment.get(self.elaboratable, platform) m = Module() m.submodules.main_module = elaboratable - m.submodules.transactionManager = self.transactionManager + m.submodules.transactionManager = self.manager.get_dependency(TransactionManagerKey()) return m @@ -1093,7 +1079,7 @@ def __init__( self.owner, owner_name = get_caller_class_name(default="$transaction") self.name = name or tracer.get_var_name(depth=2, default=owner_name) if manager is None: - manager = TransactionContext.get() + manager = DependencyContext.get().get_dependency(TransactionManagerKey()) manager.add_transaction(self) self.request = Signal(name=self.owned_name + "_request") self.runnable = Signal(name=self.owned_name + "_runnable") diff --git a/transactron/testing/infrastructure.py b/transactron/testing/infrastructure.py index 761bd7859..1790a5b09 100644 --- a/transactron/testing/infrastructure.py +++ b/transactron/testing/infrastructure.py @@ -14,7 +14,7 @@ from .gtkw_extension import write_vcd_ext from transactron import Method from transactron.lib import AdapterTrans -from transactron.core import TransactionModule +from transactron.core import TransactionManagerKey, TransactionModule from transactron.utils import ModuleConnector, HasElaborate, auto_debug_signals, HasDebugSignals T = TypeVar("T") @@ -227,7 +227,9 @@ def run_simulation(self, module: HasElaborate, max_cycles: float = 10e4, add_tra profile = None if "__TRANSACTRON_PROFILE" in os.environ and isinstance(sim.tested_module, TransactionModule): profile = Profile() - sim.add_sync_process(profiler_process(sim.tested_module.transactionManager, profile, clk_period)) + sim.add_sync_process( + profiler_process(sim.tested_module.manager.get_dependency(TransactionManagerKey()), profile, clk_period) + ) res = sim.run() diff --git a/transactron/utils/dependencies.py b/transactron/utils/dependencies.py index 2aa0c73df..f365e17b9 100644 --- a/transactron/utils/dependencies.py +++ b/transactron/utils/dependencies.py @@ -7,6 +7,7 @@ __all__ = [ "DependencyManager", "DependencyKey", + "DependencyContext", "SimpleKey", "ListKey" ] @@ -116,3 +117,24 @@ def get_dependency(self, key: DependencyKey[Any, U]) -> U: self.locked_dependencies.add(key) return key.combine(self.dependencies[key]) + + +class DependencyContext: + stack: list[DependencyManager] = [] + + def __init__(self, manager: DependencyManager): + self.manager = manager + + def __enter__(self): + self.stack.append(self.manager) + return self + + def __exit__(self, exc_type, exc_value, exc_tb): + top = self.stack.pop() + assert self.manager is top + + @classmethod + def get(cls) -> DependencyManager: + if not cls.stack: + raise RuntimeError("DependencyContext stack is empty") + return cls.stack[-1]