Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add unifiers key #27

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion transactron/core/method.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from collections.abc import Sequence
from collections.abc import Iterable, Sequence

from transactron.utils import *
from amaranth import *
Expand Down Expand Up @@ -330,6 +330,17 @@ def __init__(self, count: int, **kwargs):
kwargs["src_loc"] += 1
self._methods = [Method(**{**kwargs, "name": f"{self.name}{i}"}) for i in range(count)]

@staticmethod
def like(other: "Methods", *, name: Optional[str] = None, src_loc: int | SrcLoc = 0) -> "Methods":
return Methods(len(other), name=name, i=other.layout_in, o=other.layout_out, src_loc=get_src_loc(src_loc))

def proxy(self, m: "TModule", methods: Iterable[Method]):
methods = list(methods)
if len(methods) != len(self):
raise ValueError("number of methods not matching")
for lhs, rhs in zip(self, methods):
lhs.proxy(m, rhs)

@property
def layout_in(self):
return self._methods[0].layout_in
Expand Down
37 changes: 31 additions & 6 deletions transactron/lib/dependencies.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from collections.abc import Callable
from collections.abc import Callable, Sequence

from .. import Method
from .transformers import Unifier
from ..utils.dependencies import *


__all__ = ["DependencyManager", "DependencyKey", "SimpleKey", "ListKey", "UnifierKey"]
__all__ = ["DependencyManager", "DependencyKey", "SimpleKey", "ListKey", "UnifierKey", "UnifiersKey"]


class UnifierKey(DependencyKey["Method", tuple["Method", dict[str, "Unifier"]]]):
class UnifierKey(DependencyKey[Method, tuple[Method, dict[str, Unifier]]]):
"""Base class for method unifier dependency keys.

Method unifier dependency keys are used to collect methods to be called by
Expand All @@ -17,13 +17,13 @@ class UnifierKey(DependencyKey["Method", tuple["Method", dict[str, "Unifier"]]])
allows to customize the calling behavior.
"""

unifier: Callable[[list["Method"]], "Unifier"]
unifier: Callable[[list[Method]], Unifier]

def __init_subclass__(cls, unifier: Callable[[list["Method"]], "Unifier"], **kwargs) -> None:
def __init_subclass__(cls, unifier: Callable[[list[Method]], Unifier], **kwargs) -> None:
cls.unifier = unifier
return super().__init_subclass__(**kwargs)

def combine(self, data: list["Method"]) -> tuple["Method", dict[str, "Unifier"]]:
def combine(self, data: list[Method]) -> tuple[Method, dict[str, Unifier]]:
if len(data) == 1:
return data[0], {}
else:
Expand All @@ -32,3 +32,28 @@ def combine(self, data: list["Method"]) -> tuple["Method", dict[str, "Unifier"]]
unifiers[self.__class__.__name__ + "_unifier"] = unifier_inst
method = unifier_inst.method
return method, unifiers


class UnifiersKey(DependencyKey[Sequence[Method], tuple[Sequence[Method], dict[str, Unifier]]]):
"""Base class for method unifier dependency keys.

Method unifier dependency keys are used to collect methods to be called by
some part of the core. As multiple modules may wish to be called, a method
unifier is used to present a single method interface to the caller, which
allows to customize the calling behavior.
"""

unifier: Callable[[list[Method]], Unifier]

def __init_subclass__(cls, unifier: Callable[[list[Method]], Unifier], **kwargs) -> None:
cls.unifier = unifier
return super().__init_subclass__(**kwargs)

def combine(self, data: list[Sequence[Method]]) -> tuple[Sequence[Method], dict[str, Unifier]]:
if len(data) == 1:
return data[0], {}
assert all(len(ms) == len(data[0]) for ms in data)
unifiers = [self.unifier(row) for row in zip(*data)]
unifiers_dict = {self.__class__.__name__ + f"_unifier{i}": u for i, u in enumerate(unifiers)}
methods = [unifier.method for unifier in unifiers]
return methods, unifiers_dict
19 changes: 11 additions & 8 deletions transactron/lib/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from ..core import *
from ..utils import SrcLoc
from typing import Optional, Protocol
from collections.abc import Callable
from collections.abc import Callable, Iterable
from transactron.utils import (
ValueLike,
assign,
Expand Down Expand Up @@ -60,7 +60,7 @@ def use(self, m: ModuleLike):
class Unifier(Transformer, Protocol):
method: Method

def __init__(self, targets: list[Method]): ...
def __init__(self, targets: Iterable[Method]): ...


class MethodMap(Elaboratable, Transformer):
Expand Down Expand Up @@ -208,7 +208,7 @@ def _(arg):
class MethodProduct(Elaboratable, Unifier):
def __init__(
self,
targets: list[Method],
targets: Iterable[Method],
combiner: Optional[tuple[MethodLayout, Callable[[TModule, list[MethodStruct]], RecordDict]]] = None,
*,
src_loc: int | SrcLoc = 0,
Expand All @@ -224,7 +224,7 @@ def __init__(

Parameters
----------
targets: list[Method]
targets: Iterable[Method]
A list of methods to be called.
combiner: (int or method layout, function), optional
A pair of the output layout and the combiner function. The
Expand All @@ -239,6 +239,7 @@ def __init__(
method: Method
The product method.
"""
targets = list(targets)
if combiner is None:
combiner = (targets[0].layout_out, lambda _, x: x[0])
self.targets = targets
Expand All @@ -262,7 +263,7 @@ def _(arg):
class MethodTryProduct(Elaboratable, Unifier):
def __init__(
self,
targets: list[Method],
targets: Iterable[Method],
combiner: Optional[
tuple[MethodLayout, Callable[[TModule, list[tuple[Value, MethodStruct]]], RecordDict]]
] = None,
Expand All @@ -280,7 +281,7 @@ def __init__(

Parameters
----------
targets: list[Method]
targets: Iterable[Method]
A list of methods to be called.
combiner: (int or method layout, function), optional
A pair of the output layout and the combiner function. The
Expand All @@ -296,6 +297,7 @@ def __init__(
method: Method
The product method.
"""
targets = list(targets)
if combiner is None:
combiner = ([], lambda _, __: {})
self.targets = targets
Expand Down Expand Up @@ -332,16 +334,17 @@ class Collector(Elaboratable, Unifier):
Method which returns single result of provided methods.
"""

def __init__(self, targets: list[Method], *, src_loc: int | SrcLoc = 0):
def __init__(self, targets: Iterable[Method], *, src_loc: int | SrcLoc = 0):
"""
Parameters
----------
method_list: list[Method]
targets: Iterable[Method]
List of methods from which results will be collected.
src_loc: int | SrcLoc
How many stack frames deep the source location is taken from.
Alternatively, the source location to use instead of the default.
"""
targets = list(targets)
self.method_list = targets
layout = targets[0].layout_out
self.src_loc = get_src_loc(src_loc)
Expand Down
Loading