From b627a3d9138721a67dcc527ef9dc35e4ad2643c3 Mon Sep 17 00:00:00 2001 From: Marek Materzok Date: Wed, 11 Dec 2024 10:57:38 +0100 Subject: [PATCH] Add unifiers key --- transactron/core/method.py | 13 +++++++++++- transactron/lib/dependencies.py | 37 +++++++++++++++++++++++++++------ transactron/lib/transformers.py | 19 ++++++++++------- 3 files changed, 54 insertions(+), 15 deletions(-) diff --git a/transactron/core/method.py b/transactron/core/method.py index afc678d..27ec621 100644 --- a/transactron/core/method.py +++ b/transactron/core/method.py @@ -1,4 +1,4 @@ -from collections.abc import Sequence +from collections.abc import Iterable, Sequence from transactron.utils import * from amaranth import * @@ -330,6 +330,17 @@ def __init__(self, count: int, **kwargs): kwargs["src_loc"] += 1 self._methods = [Method(**{**kwargs, "name": f"{self.name}{i}"}) for i in range(count)] + @staticmethod + def like(other: "Methods", *, name: Optional[str] = None, src_loc: int | SrcLoc = 0) -> "Methods": + return Methods(len(other), name=name, i=other.layout_in, o=other.layout_out, src_loc=get_src_loc(src_loc)) + + def proxy(self, m: "TModule", methods: Iterable[Method]): + methods = list(methods) + if len(methods) != len(self): + raise ValueError("number of methods not matching") + for lhs, rhs in zip(self, methods): + lhs.proxy(m, rhs) + @property def layout_in(self): return self._methods[0].layout_in diff --git a/transactron/lib/dependencies.py b/transactron/lib/dependencies.py index c7b099b..a1c6266 100644 --- a/transactron/lib/dependencies.py +++ b/transactron/lib/dependencies.py @@ -1,14 +1,14 @@ -from collections.abc import Callable +from collections.abc import Callable, Sequence from .. import Method from .transformers import Unifier from ..utils.dependencies import * -__all__ = ["DependencyManager", "DependencyKey", "SimpleKey", "ListKey", "UnifierKey"] +__all__ = ["DependencyManager", "DependencyKey", "SimpleKey", "ListKey", "UnifierKey", "UnifiersKey"] -class UnifierKey(DependencyKey["Method", tuple["Method", dict[str, "Unifier"]]]): +class UnifierKey(DependencyKey[Method, tuple[Method, dict[str, Unifier]]]): """Base class for method unifier dependency keys. Method unifier dependency keys are used to collect methods to be called by @@ -17,13 +17,13 @@ class UnifierKey(DependencyKey["Method", tuple["Method", dict[str, "Unifier"]]]) allows to customize the calling behavior. """ - unifier: Callable[[list["Method"]], "Unifier"] + unifier: Callable[[list[Method]], Unifier] - def __init_subclass__(cls, unifier: Callable[[list["Method"]], "Unifier"], **kwargs) -> None: + def __init_subclass__(cls, unifier: Callable[[list[Method]], Unifier], **kwargs) -> None: cls.unifier = unifier return super().__init_subclass__(**kwargs) - def combine(self, data: list["Method"]) -> tuple["Method", dict[str, "Unifier"]]: + def combine(self, data: list[Method]) -> tuple[Method, dict[str, Unifier]]: if len(data) == 1: return data[0], {} else: @@ -32,3 +32,28 @@ def combine(self, data: list["Method"]) -> tuple["Method", dict[str, "Unifier"]] unifiers[self.__class__.__name__ + "_unifier"] = unifier_inst method = unifier_inst.method return method, unifiers + + +class UnifiersKey(DependencyKey[Sequence[Method], tuple[Sequence[Method], dict[str, Unifier]]]): + """Base class for method unifier dependency keys. + + Method unifier dependency keys are used to collect methods to be called by + some part of the core. As multiple modules may wish to be called, a method + unifier is used to present a single method interface to the caller, which + allows to customize the calling behavior. + """ + + unifier: Callable[[list[Method]], Unifier] + + def __init_subclass__(cls, unifier: Callable[[list[Method]], Unifier], **kwargs) -> None: + cls.unifier = unifier + return super().__init_subclass__(**kwargs) + + def combine(self, data: list[Sequence[Method]]) -> tuple[Sequence[Method], dict[str, Unifier]]: + if len(data) == 1: + return data[0], {} + assert all(len(ms) == len(data[0]) for ms in data) + unifiers = [self.unifier(row) for row in zip(*data)] + unifiers_dict = {self.__class__.__name__ + f"_unifier{i}": u for i, u in enumerate(unifiers)} + methods = [unifier.method for unifier in unifiers] + return methods, unifiers_dict diff --git a/transactron/lib/transformers.py b/transactron/lib/transformers.py index d3198d7..47115ba 100644 --- a/transactron/lib/transformers.py +++ b/transactron/lib/transformers.py @@ -4,7 +4,7 @@ from ..core import * from ..utils import SrcLoc from typing import Optional, Protocol -from collections.abc import Callable +from collections.abc import Callable, Iterable from transactron.utils import ( ValueLike, assign, @@ -60,7 +60,7 @@ def use(self, m: ModuleLike): class Unifier(Transformer, Protocol): method: Method - def __init__(self, targets: list[Method]): ... + def __init__(self, targets: Iterable[Method]): ... class MethodMap(Elaboratable, Transformer): @@ -208,7 +208,7 @@ def _(arg): class MethodProduct(Elaboratable, Unifier): def __init__( self, - targets: list[Method], + targets: Iterable[Method], combiner: Optional[tuple[MethodLayout, Callable[[TModule, list[MethodStruct]], RecordDict]]] = None, *, src_loc: int | SrcLoc = 0, @@ -224,7 +224,7 @@ def __init__( Parameters ---------- - targets: list[Method] + targets: Iterable[Method] A list of methods to be called. combiner: (int or method layout, function), optional A pair of the output layout and the combiner function. The @@ -239,6 +239,7 @@ def __init__( method: Method The product method. """ + targets = list(targets) if combiner is None: combiner = (targets[0].layout_out, lambda _, x: x[0]) self.targets = targets @@ -262,7 +263,7 @@ def _(arg): class MethodTryProduct(Elaboratable, Unifier): def __init__( self, - targets: list[Method], + targets: Iterable[Method], combiner: Optional[ tuple[MethodLayout, Callable[[TModule, list[tuple[Value, MethodStruct]]], RecordDict]] ] = None, @@ -280,7 +281,7 @@ def __init__( Parameters ---------- - targets: list[Method] + targets: Iterable[Method] A list of methods to be called. combiner: (int or method layout, function), optional A pair of the output layout and the combiner function. The @@ -296,6 +297,7 @@ def __init__( method: Method The product method. """ + targets = list(targets) if combiner is None: combiner = ([], lambda _, __: {}) self.targets = targets @@ -332,16 +334,17 @@ class Collector(Elaboratable, Unifier): Method which returns single result of provided methods. """ - def __init__(self, targets: list[Method], *, src_loc: int | SrcLoc = 0): + def __init__(self, targets: Iterable[Method], *, src_loc: int | SrcLoc = 0): """ Parameters ---------- - method_list: list[Method] + targets: Iterable[Method] List of methods from which results will be collected. src_loc: int | SrcLoc How many stack frames deep the source location is taken from. Alternatively, the source location to use instead of the default. """ + targets = list(targets) self.method_list = targets layout = targets[0].layout_out self.src_loc = get_src_loc(src_loc)