diff --git a/transactron/core.py b/transactron/core.py index c4c34f8cc..745b98b4d 100644 --- a/transactron/core.py +++ b/transactron/core.py @@ -2,7 +2,18 @@ from collections.abc import Sequence, Iterable, Callable, Mapping, Iterator from contextlib import contextmanager from enum import Enum, auto -from typing import ClassVar, NoReturn, TypeAlias, TypedDict, Union, Optional, Tuple +from typing import ( + ClassVar, + NoReturn, + TypeAlias, + TypedDict, + Union, + Optional, + Tuple, + TypeVar, + Protocol, + runtime_checkable, +) from graphlib import TopologicalSorter from typing_extensions import Self from amaranth import * @@ -38,6 +49,7 @@ TransactionScheduler: TypeAlias = Callable[["MethodMap", TransactionGraph, TransactionGraphCC, PriorityOrder], Module] RecordDict: TypeAlias = ValueLike | Mapping[str, "RecordDict"] TransactionOrMethod: TypeAlias = Union["Transaction", "Method"] +TransactionOrMethodBound = TypeVar("TransactionOrMethodBound", "Transaction", "Method") class Priority(Enum): @@ -670,15 +682,20 @@ def elaborate(self, platform): return self.main_module -class TransactionBase(Owned): +@runtime_checkable +class TransactionBase(Owned, Protocol): stack: ClassVar[list[Union["Transaction", "Method"]]] = [] def_counter: ClassVar[count] = count() def_order: int defined: bool = False name: str + method_uses: dict["Method", Tuple[ValueLike, ValueLike]] + relations: list[RelationBase] + simultaneous_list: list[TransactionOrMethod] + independent_list: list[TransactionOrMethod] def __init__(self): - self.method_uses: dict[Method, Tuple[ValueLike, ValueLike]] = dict() + self.method_uses: dict["Method", Tuple[ValueLike, ValueLike]] = dict() self.relations: list[RelationBase] = [] self.simultaneous_list: list[TransactionOrMethod] = [] self.independent_list: list[TransactionOrMethod] = [] @@ -769,9 +786,7 @@ def _independent(self, *others: TransactionOrMethod) -> None: self.independent_list += others @contextmanager - def context(self, m: TModule) -> Iterator[Self]: - assert isinstance(self, Transaction) or isinstance(self, Method) # for typing - + def context(self: TransactionOrMethodBound, m: TModule) -> Iterator[TransactionOrMethodBound]: parent = TransactionBase.peek() if parent is not None: parent.schedule_before(self) diff --git a/transactron/graph.py b/transactron/graph.py index 2deaf24a0..4cd51d067 100644 --- a/transactron/graph.py +++ b/transactron/graph.py @@ -3,15 +3,14 @@ """ from enum import IntFlag -from abc import ABC from collections import defaultdict -from typing import Literal, Optional +from typing import Literal, Optional, Protocol from amaranth.hdl.ir import Elaboratable, Fragment from .tracing import TracingFragment -class Owned(ABC): +class Owned(Protocol): name: str owner: Optional[Elaboratable]