Skip to content

Commit

Permalink
Simpler way of using method transformations (kuznia-rdzeni/coreblocks…
Browse files Browse the repository at this point in the history
  • Loading branch information
tilk authored Nov 28, 2023
1 parent 2831e9f commit c47d864
Showing 1 changed file with 49 additions and 22 deletions.
71 changes: 49 additions & 22 deletions transactron/lib/transformers.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,56 @@
from abc import ABC
from amaranth import *
from ..core import *
from ..core import RecordDict
from typing import Optional
from collections.abc import Callable
from transactron.utils import ValueLike, assign, AssignType
from transactron.utils import ValueLike, assign, AssignType, ModuleLike
from .connectors import Forwarder, ManyToOneConnectTrans, ConnectTrans

__all__ = [
"MethodTransformer",
"Transformer",
"MethodMap",
"MethodFilter",
"MethodProduct",
"MethodTryProduct",
"Collector",
"CatTrans",
"ConnectAndTransformTrans",
"ConnectAndMapTrans",
]


class MethodTransformer(Elaboratable):
"""Method transformer.
class Transformer(ABC):
"""Method transformer abstract class.
Method transformers construct a new method which utilizes other methods.
Attributes
----------
method: Method
The method.
"""

method: Method

def use(self, m: ModuleLike):
"""
Returns the method and adds the transformer to a module.
Parameters
----------
m: Module or TModule
The module to which this transformer is added as a submodule.
"""
m.submodules += self
return self.method


class MethodMap(Transformer, Elaboratable):
"""Bidirectional map for methods.
Takes a target method and creates a transformed method which calls the
original target method, transforming the input and output values.
The transformation functions take two parameters, a `Module` and the
original target method, mapping the input and output values with
functions. The mapping functions take two parameters, a `Module` and the
`Record` being transformed. Alternatively, a `Method` can be
passed.
Expand All @@ -45,13 +73,13 @@ def __init__(
target: Method
The target method.
i_transform: (record layout, function or Method), optional
Input transformation. If specified, it should be a pair of a
Input mapping function. If specified, it should be a pair of a
function and a input layout for the transformed method.
If not present, input is not transformed.
If not present, input is passed unmodified.
o_transform: (record layout, function or Method), optional
Output transformation. If specified, it should be a pair of a
Output mapping function. If specified, it should be a pair of a
function and a output layout for the transformed method.
If not present, output is not transformed.
If not present, output is passed unmodified.
"""
if i_transform is None:
i_transform = (target.data_in.layout, lambda _, x: x)
Expand All @@ -73,7 +101,7 @@ def _(arg):
return m


class MethodFilter(Elaboratable):
class MethodFilter(Transformer, Elaboratable):
"""Method filter.
Takes a target method and creates a method which calls the target method
Expand Down Expand Up @@ -129,7 +157,7 @@ def _(arg):
return m


class MethodProduct(Elaboratable):
class MethodProduct(Transformer, Elaboratable):
def __init__(
self,
targets: list[Method],
Expand Down Expand Up @@ -177,7 +205,7 @@ def _(arg):
return m


class MethodTryProduct(Elaboratable):
class MethodTryProduct(Transformer, Elaboratable):
def __init__(
self,
targets: list[Method],
Expand Down Expand Up @@ -229,7 +257,7 @@ def _(arg):
return m


class Collector(Elaboratable):
class Collector(Transformer, Elaboratable):
"""Single result collector.
Creates method that collects results of many methods with identical
Expand Down Expand Up @@ -308,14 +336,13 @@ def elaborate(self, platform):
return m


class ConnectAndTransformTrans(Elaboratable):
"""Connecting transaction with transformations.
class ConnectAndMapTrans(Elaboratable):
"""Connecting transaction with mapping functions.
Behaves like `ConnectTrans`, but modifies the transferred data using
functions or `Method`s. Equivalent to a combination of
`ConnectTrans` and `MethodTransformer`. The transformation
functions take two parameters, a `Module` and the `Record` being
transformed.
functions or `Method`s. Equivalent to a combination of `ConnectTrans`
and `MethodMap`. The mapping functions take two parameters, a `Module`
and the `Record` being transformed.
"""

def __init__(
Expand Down Expand Up @@ -346,7 +373,7 @@ def __init__(
def elaborate(self, platform):
m = TModule()

m.submodules.transformer = transformer = MethodTransformer(
m.submodules.transformer = transformer = MethodMap(
self.method2,
i_transform=(self.method1.data_out.layout, self.i_fun),
o_transform=(self.method1.data_in.layout, self.o_fun),
Expand Down

0 comments on commit c47d864

Please sign in to comment.