From 6b4dbf32f348a3cf1a7f4ad54eae79a5c6762447 Mon Sep 17 00:00:00 2001 From: Marek Materzok Date: Tue, 28 Nov 2023 10:57:57 +0100 Subject: [PATCH] Simpler way of using method transformations (#525) --- test/transactions/test_transaction_lib.py | 28 ++++----- transactron/lib/transformers.py | 71 ++++++++++++++++------- 2 files changed, 61 insertions(+), 38 deletions(-) diff --git a/test/transactions/test_transaction_lib.py b/test/transactions/test_transaction_lib.py index d43540860..6daa6517c 100644 --- a/test/transactions/test_transaction_lib.py +++ b/test/transactions/test_transaction_lib.py @@ -351,7 +351,7 @@ def test_many_out(self): sim.add_sync_process(self.generate_producer(i)) -class MethodTransformerTestCircuit(Elaboratable): +class MethodMapTestCircuit(Elaboratable): def __init__(self, iosize: int, use_methods: bool, use_dicts: bool): self.iosize = iosize self.use_methods = use_methods @@ -399,25 +399,21 @@ def _(arg: Record): def _(arg: Record): return otransform(m, arg) - trans = MethodTransformer( - self.target.adapter.iface, i_transform=(layout, imeth), o_transform=(layout, ometh) - ) + trans = MethodMap(self.target.adapter.iface, i_transform=(layout, imeth), o_transform=(layout, ometh)) else: - trans = MethodTransformer( + trans = MethodMap( self.target.adapter.iface, i_transform=(layout, itransform), o_transform=(layout, otransform), ) - m.submodules.trans = trans - - m.submodules.source = self.source = TestbenchIO(AdapterTrans(trans.method)) + m.submodules.source = self.source = TestbenchIO(AdapterTrans(trans.use(m))) return m class TestMethodTransformer(TestCaseWithSimulator): - m: MethodTransformerTestCircuit + m: MethodMapTestCircuit def source(self): for i in range(2**self.m.iosize): @@ -430,19 +426,19 @@ def target(self, data): return {"data": (data << 1) | (data >> (self.m.iosize - 1))} def test_method_transformer(self): - self.m = MethodTransformerTestCircuit(4, False, False) + self.m = MethodMapTestCircuit(4, False, False) with self.run_simulation(self.m) as sim: sim.add_sync_process(self.source) sim.add_sync_process(self.target) def test_method_transformer_dicts(self): - self.m = MethodTransformerTestCircuit(4, False, True) + self.m = MethodMapTestCircuit(4, False, True) with self.run_simulation(self.m) as sim: sim.add_sync_process(self.source) sim.add_sync_process(self.target) def test_method_transformer_with_methods(self): - self.m = MethodTransformerTestCircuit(4, True, True) + self.m = MethodMapTestCircuit(4, True, True) with self.run_simulation(self.m) as sim: sim.add_sync_process(self.source) sim.add_sync_process(self.target) @@ -517,9 +513,9 @@ def elaborate(self, platform): if self.add_combiner: combiner = (layout, lambda _, vs: {"data": sum(vs)}) - m.submodules.product = product = MethodProduct(methods, combiner) + product = MethodProduct(methods, combiner) - m.submodules.method = self.method = TestbenchIO(AdapterTrans(product.method)) + m.submodules.method = self.method = TestbenchIO(AdapterTrans(product.use(m))) return m @@ -704,9 +700,9 @@ def elaborate(self, platform): if self.add_combiner: combiner = (layout, lambda _, vs: {"data": sum(Mux(s, r, 0) for (s, r) in vs)}) - m.submodules.product = product = MethodTryProduct(methods, combiner) + product = MethodTryProduct(methods, combiner) - m.submodules.method = self.method = TestbenchIO(AdapterTrans(product.method)) + m.submodules.method = self.method = TestbenchIO(AdapterTrans(product.use(m))) return m diff --git a/transactron/lib/transformers.py b/transactron/lib/transformers.py index b3d1e2470..18c3ac73a 100644 --- a/transactron/lib/transformers.py +++ b/transactron/lib/transformers.py @@ -1,28 +1,56 @@ +from abc import ABC from amaranth import * from ..core import * from ..core import RecordDict from typing import Optional from collections.abc import Callable -from transactron.utils import ValueLike, assign, AssignType +from transactron.utils import ValueLike, assign, AssignType, ModuleLike from .connectors import Forwarder, ManyToOneConnectTrans, ConnectTrans __all__ = [ - "MethodTransformer", + "Transformer", + "MethodMap", "MethodFilter", "MethodProduct", "MethodTryProduct", "Collector", "CatTrans", - "ConnectAndTransformTrans", + "ConnectAndMapTrans", ] -class MethodTransformer(Elaboratable): - """Method transformer. +class Transformer(ABC): + """Method transformer abstract class. + + Method transformers construct a new method which utilizes other methods. + + Attributes + ---------- + method: Method + The method. + """ + + method: Method + + def use(self, m: ModuleLike): + """ + Returns the method and adds the transformer to a module. + + Parameters + ---------- + m: Module or TModule + The module to which this transformer is added as a submodule. + """ + m.submodules += self + return self.method + + +class MethodMap(Transformer, Elaboratable): + """Bidirectional map for methods. Takes a target method and creates a transformed method which calls the - original target method, transforming the input and output values. - The transformation functions take two parameters, a `Module` and the + original target method, mapping the input and output values with + functions. The mapping functions take two parameters, a `Module` and the `Record` being transformed. Alternatively, a `Method` can be passed. @@ -45,13 +73,13 @@ def __init__( target: Method The target method. i_transform: (record layout, function or Method), optional - Input transformation. If specified, it should be a pair of a + Input mapping function. If specified, it should be a pair of a function and a input layout for the transformed method. - If not present, input is not transformed. + If not present, input is passed unmodified. o_transform: (record layout, function or Method), optional - Output transformation. If specified, it should be a pair of a + Output mapping function. If specified, it should be a pair of a function and a output layout for the transformed method. - If not present, output is not transformed. + If not present, output is passed unmodified. """ if i_transform is None: i_transform = (target.data_in.layout, lambda _, x: x) @@ -73,7 +101,7 @@ def _(arg): return m -class MethodFilter(Elaboratable): +class MethodFilter(Transformer, Elaboratable): """Method filter. Takes a target method and creates a method which calls the target method @@ -129,7 +157,7 @@ def _(arg): return m -class MethodProduct(Elaboratable): +class MethodProduct(Transformer, Elaboratable): def __init__( self, targets: list[Method], @@ -177,7 +205,7 @@ def _(arg): return m -class MethodTryProduct(Elaboratable): +class MethodTryProduct(Transformer, Elaboratable): def __init__( self, targets: list[Method], @@ -229,7 +257,7 @@ def _(arg): return m -class Collector(Elaboratable): +class Collector(Transformer, Elaboratable): """Single result collector. Creates method that collects results of many methods with identical @@ -308,14 +336,13 @@ def elaborate(self, platform): return m -class ConnectAndTransformTrans(Elaboratable): - """Connecting transaction with transformations. +class ConnectAndMapTrans(Elaboratable): + """Connecting transaction with mapping functions. Behaves like `ConnectTrans`, but modifies the transferred data using - functions or `Method`s. Equivalent to a combination of - `ConnectTrans` and `MethodTransformer`. The transformation - functions take two parameters, a `Module` and the `Record` being - transformed. + functions or `Method`s. Equivalent to a combination of `ConnectTrans` + and `MethodMap`. The mapping functions take two parameters, a `Module` + and the `Record` being transformed. """ def __init__( @@ -346,7 +373,7 @@ def __init__( def elaborate(self, platform): m = TModule() - m.submodules.transformer = transformer = MethodTransformer( + m.submodules.transformer = transformer = MethodMap( self.method2, i_transform=(self.method1.data_out.layout, self.i_fun), o_transform=(self.method1.data_in.layout, self.o_fun),