diff --git a/test/transactron/test_connectors.py b/test/transactron/test_connectors.py new file mode 100644 index 000000000..2903397b6 --- /dev/null +++ b/test/transactron/test_connectors.py @@ -0,0 +1,46 @@ +import random +from parameterized import parameterized_class + +from amaranth.sim import Settle + +from transactron.lib import StableSelectingNetwork +from transactron.testing import TestCaseWithSimulator + + +@parameterized_class( + ("n"), + [(2,), (3,), (7,), (8,)], +) +class TestStableSelectingNetwork(TestCaseWithSimulator): + n: int + + def test(self): + m = StableSelectingNetwork(self.n, [("data", 8)]) + + random.seed(42) + + def process(): + for _ in range(100): + inputs = [random.randrange(2**8) for _ in range(self.n)] + valids = [random.randrange(2) for _ in range(self.n)] + total = sum(valids) + + expected_output_prefix = [] + for i in range(self.n): + yield m.valids[i].eq(valids[i]) + yield m.inputs[i].data.eq(inputs[i]) + + if valids[i]: + expected_output_prefix.append(inputs[i]) + + yield Settle() + + for i in range(total): + out = yield m.outputs[i].data + self.assertEqual(out, expected_output_prefix[i]) + + self.assertEqual((yield m.output_cnt), total) + yield + + with self.run_simulation(m) as sim: + sim.add_sync_process(process) diff --git a/transactron/lib/connectors.py b/transactron/lib/connectors.py index b9a6eb204..96620c7ef 100644 --- a/transactron/lib/connectors.py +++ b/transactron/lib/connectors.py @@ -11,6 +11,7 @@ "Connect", "ConnectTrans", "ManyToOneConnectTrans", + "StableSelectingNetwork", ] @@ -275,3 +276,75 @@ def elaborate(self, platform): ) return m + + +class StableSelectingNetwork(Elaboratable): + """A network that groups inputs with a valid bit set. + + The circuit takes `n` inputs with a valid signal each and + on the output returns a grouped and consecutive sequence of the provided + input signals. The order of valid inputs is preserved. + + For example for input (0 is an invalid input): + 0, a, 0, d, 0, 0, e + + The circuit will return: + a, d, e, 0, 0, 0, 0 + + """ + + def __init__(self, n: int, layout: MethodLayout): + self.n = n + self.layout = from_method_layout(layout) + + self.inputs = [Signal(self.layout) for _ in range(n)] + self.valids = [Signal() for _ in range(n)] + + self.outputs = [Signal(self.layout) for _ in range(n)] + self.output_cnt = Signal(range(n + 1)) + + def elaborate(self, platform): + m = TModule() + + current_level = [] + for i in range(self.n): + current_level.append((Array([self.inputs[i]]), self.valids[i])) + + # Create the network using the bottom-up approach. + while True: + if len(current_level) == 1: + break + + next_level = [] + while len(current_level) >= 2: + a, cnt_a = current_level.pop(0) + b, cnt_b = current_level.pop(0) + + total_cnt = Signal(max(len(cnt_a), len(cnt_b)) + 1) + m.d.comb += total_cnt.eq(cnt_a + cnt_b) + + total_len = len(a) + len(b) + merged = Array(Signal(self.layout) for _ in range(total_len)) + + for i in range(len(a)): + m.d.comb += merged[i].eq(Mux(cnt_a <= i, b[i - cnt_a], a[i])) + for i in range(len(b)): + m.d.comb += merged[len(a) + i].eq(Mux(len(a) + i - cnt_a >= len(b), 0, b[len(a) + i - cnt_a])) + + next_level.append((merged, total_cnt)) + + # If we had an odd number of elements on the current level, + # move the item left to the next level. + if len(current_level) == 1: + next_level.append(current_level.pop(0)) + + current_level = next_level + + last_level, total_cnt = current_level.pop(0) + + for i in range(self.n): + m.d.comb += self.outputs[i].eq(last_level[i]) + + m.d.comb += self.output_cnt.eq(total_cnt) + + return m