Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
tilk committed Oct 25, 2023
1 parent c966514 commit 00fd20e
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 23 deletions.
28 changes: 14 additions & 14 deletions test/transactions/test_transaction_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,21 +362,21 @@ def elaborate(self, platform):

layout = data_layout(self.iosize)

def itransform_rec(m: ModuleLike, v: Record) -> Record:
s = Record.like(v)
m.d.comb += s.data.eq(v.data + 1)
def itransform_rec(m: TModule, arg: Record) -> Record:
s = Record.like(arg)
m.d.comb += s.data.eq(arg.data + 1)
return s

def otransform_rec(m: ModuleLike, v: Record) -> Record:
s = Record.like(v)
m.d.comb += s.data.eq(v.data - 1)
def otransform_rec(m: TModule, arg: Record) -> Record:
s = Record.like(arg)
m.d.comb += s.data.eq(arg.data - 1)
return s

def itransform_dict(_, v: Record) -> RecordDict:
return {"data": v.data + 1}
def itransform_dict(m: TModule, data: Value) -> RecordDict:
return {"data": data + 1}

def otransform_dict(_, v: Record) -> RecordDict:
return {"data": v.data - 1}
def otransform_dict(m: TModule, data: Value) -> RecordDict:
return {"data": data - 1}

if self.use_dicts:
itransform = itransform_dict
Expand Down Expand Up @@ -483,8 +483,8 @@ def test_method_filter_with_methods(self):
def test_method_filter(self):
self.initialize()

def condition(_, v):
return v[0]
def condition(data: Value):
return data[0]

self.tc = SimpleTestCircuit(MethodFilter(self.target.adapter.iface, condition))
m = ModuleConnector(test_circuit=self.tc, target=self.target)
Expand Down Expand Up @@ -515,7 +515,7 @@ def elaborate(self, platform):

combiner = None
if self.add_combiner:
combiner = (layout, lambda _, vs: {"data": sum(vs)})
combiner = (layout, lambda vs: {"data": sum(vs)})

m.submodules.product = product = MethodProduct(methods, combiner)

Expand Down Expand Up @@ -702,7 +702,7 @@ def elaborate(self, platform):

combiner = None
if self.add_combiner:
combiner = (layout, lambda _, vs: {"data": sum(Mux(s, r, 0) for (s, r) in vs)})
combiner = (layout, lambda vs: {"data": sum(Mux(s, r, 0) for (s, r) in vs)})

m.submodules.product = product = MethodTryProduct(methods, combiner)

Expand Down
18 changes: 9 additions & 9 deletions transactron/lib/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def bind_tmodule(m: TModule, func: CallableOptTModule[P, T]) -> Callable[P, T]:


def transformer_helper(tr, m: TModule, func: Callable[..., T], arg: Record) -> T:
return def_helper(f"function for {tr}", bind_tmodule(m, func), Record, arg)
return def_helper(f"function for {tr}", bind_tmodule(m, func), Record, arg, **arg.fields)


class MethodTransformer(Elaboratable):
Expand Down Expand Up @@ -68,9 +68,9 @@ def __init__(
If not present, output is not transformed.
"""
if i_transform is None:
i_transform = (target.data_in.layout, lambda _, x: x)
i_transform = (target.data_in.layout, lambda arg: arg)
if o_transform is None:
o_transform = (target.data_out.layout, lambda _, x: x)
o_transform = (target.data_out.layout, lambda arg: arg)

self.target = target
self.method = Method(i=i_transform[0], o=o_transform[0])
Expand Down Expand Up @@ -171,7 +171,7 @@ def __init__(
The product method.
"""
if combiner is None:
combiner = (targets[0].data_out.layout, lambda _, x: x[0])
combiner = (targets[0].data_out.layout, lambda x: x[0])
self.targets = targets
self.combiner = combiner
self.method = Method(i=targets[0].data_in.layout, o=combiner[0])
Expand All @@ -193,7 +193,7 @@ class MethodTryProduct(Elaboratable):
def __init__(
self,
targets: list[Method],
combiner: Optional[tuple[MethodLayout, Callable[[TModule, list[tuple[Value, Record]]], RecordDict]]] = None,
combiner: Optional[tuple[MethodLayout, CallableOptTModule[[list[tuple[Value, Record]]], RecordDict]]] = None,
):
"""Method product with optional calling.
Expand All @@ -220,7 +220,7 @@ def __init__(
The product method.
"""
if combiner is None:
combiner = ([], lambda _, __: {})
combiner = ([], lambda arg: {})
self.targets = targets
self.combiner = combiner
self.method = Method(i=targets[0].data_in.layout, o=combiner[0])
Expand All @@ -236,7 +236,7 @@ def _(arg):
with Transaction().body(m):
m.d.comb += success.eq(1)
results.append((success, target(m, arg)))
return self.combiner[1](m, results)
return bind_tmodule(m, self.combiner[1])(results)

return m

Expand Down Expand Up @@ -352,8 +352,8 @@ def __init__(
"""
self.method1 = method1
self.method2 = method2
self.i_fun = i_fun or (lambda _, x: x)
self.o_fun = o_fun or (lambda _, x: x)
self.i_fun = i_fun or (lambda arg: arg)
self.o_fun = o_fun or (lambda arg: arg)

def elaborate(self, platform):
m = TModule()
Expand Down

0 comments on commit 00fd20e

Please sign in to comment.