Skip to content

Commit

Permalink
Add selecting network circuit (kuznia-rdzeni/coreblocks#621)
Browse files Browse the repository at this point in the history
  • Loading branch information
Jakub Urbańczyk authored Mar 27, 2024
1 parent d7c359a commit 95119bc
Show file tree
Hide file tree
Showing 2 changed files with 126 additions and 0 deletions.
46 changes: 46 additions & 0 deletions test/test_connectors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import random
from parameterized import parameterized_class

from amaranth.sim import Settle

from transactron.lib import StableSelectingNetwork
from transactron.testing import TestCaseWithSimulator


@parameterized_class(
("n"),
[(2,), (3,), (7,), (8,)],
)
class TestStableSelectingNetwork(TestCaseWithSimulator):
n: int

def test(self):
m = StableSelectingNetwork(self.n, [("data", 8)])

random.seed(42)

def process():
for _ in range(100):
inputs = [random.randrange(2**8) for _ in range(self.n)]
valids = [random.randrange(2) for _ in range(self.n)]
total = sum(valids)

expected_output_prefix = []
for i in range(self.n):
yield m.valids[i].eq(valids[i])
yield m.inputs[i].data.eq(inputs[i])

if valids[i]:
expected_output_prefix.append(inputs[i])

yield Settle()

for i in range(total):
out = yield m.outputs[i].data
self.assertEqual(out, expected_output_prefix[i])

self.assertEqual((yield m.output_cnt), total)
yield

with self.run_simulation(m) as sim:
sim.add_sync_process(process)
80 changes: 80 additions & 0 deletions transactron/lib/connectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
"Connect",
"ConnectTrans",
"ManyToOneConnectTrans",
"StableSelectingNetwork",
]


Expand Down Expand Up @@ -275,3 +276,82 @@ def elaborate(self, platform):
)

return m


class StableSelectingNetwork(Elaboratable):
"""A network that groups inputs with a valid bit set.
The circuit takes `n` inputs with a valid signal each and
on the output returns a grouped and consecutive sequence of the provided
input signals. The order of valid inputs is preserved.
For example for input (0 is an invalid input):
0, a, 0, d, 0, 0, e
The circuit will return:
a, d, e, 0, 0, 0, 0
The circuit uses a divide and conquer algorithm.
The recursive call takes two bit vectors and each of them
is already properly sorted, for example:
v1 = [a, b, 0, 0]; v2 = [c, d, e, 0]
Now by shifting left v2 and merging it with v1, we get the result:
v = [a, b, c, d, e, 0, 0, 0]
Thus, the network has depth log_2(n).
"""

def __init__(self, n: int, layout: MethodLayout):
self.n = n
self.layout = from_method_layout(layout)

self.inputs = [Signal(self.layout) for _ in range(n)]
self.valids = [Signal() for _ in range(n)]

self.outputs = [Signal(self.layout) for _ in range(n)]
self.output_cnt = Signal(range(n + 1))

def elaborate(self, platform):
m = TModule()

current_level = []
for i in range(self.n):
current_level.append((Array([self.inputs[i]]), self.valids[i]))

# Create the network using the bottom-up approach.
while len(current_level) >= 2:
next_level = []
while len(current_level) >= 2:
a, cnt_a = current_level.pop(0)
b, cnt_b = current_level.pop(0)

total_cnt = Signal(max(len(cnt_a), len(cnt_b)) + 1)
m.d.comb += total_cnt.eq(cnt_a + cnt_b)

total_len = len(a) + len(b)
merged = Array(Signal(self.layout) for _ in range(total_len))

for i in range(len(a)):
m.d.comb += merged[i].eq(Mux(cnt_a <= i, b[i - cnt_a], a[i]))
for i in range(len(b)):
m.d.comb += merged[len(a) + i].eq(Mux(len(a) + i - cnt_a >= len(b), 0, b[len(a) + i - cnt_a]))

next_level.append((merged, total_cnt))

# If we had an odd number of elements on the current level,
# move the item left to the next level.
if len(current_level) == 1:
next_level.append(current_level.pop(0))

current_level = next_level

last_level, total_cnt = current_level.pop(0)

for i in range(self.n):
m.d.comb += self.outputs[i].eq(last_level[i])

m.d.comb += self.output_cnt.eq(total_cnt)

return m

0 comments on commit 95119bc

Please sign in to comment.