Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add selecting network circuit #621

Merged
merged 5 commits into from
Mar 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 46 additions & 0 deletions test/transactron/test_connectors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import random
from parameterized import parameterized_class

from amaranth.sim import Settle

from transactron.lib import StableSelectingNetwork
from transactron.testing import TestCaseWithSimulator


@parameterized_class(
("n"),

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be ("n",)?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for pointing this out - it should. However, apparently this decorator correctly handles this case:

    if isinstance(attrs, string_types):
        attrs = [attrs]

[(2,), (3,), (7,), (8,)],
)
class TestStableSelectingNetwork(TestCaseWithSimulator):
n: int

def test(self):
m = StableSelectingNetwork(self.n, [("data", 8)])

random.seed(42)

def process():
for _ in range(100):
inputs = [random.randrange(2**8) for _ in range(self.n)]
valids = [random.randrange(2) for _ in range(self.n)]
total = sum(valids)

expected_output_prefix = []
for i in range(self.n):
yield m.valids[i].eq(valids[i])
yield m.inputs[i].data.eq(inputs[i])

if valids[i]:
expected_output_prefix.append(inputs[i])

yield Settle()

for i in range(total):
out = yield m.outputs[i].data
self.assertEqual(out, expected_output_prefix[i])

self.assertEqual((yield m.output_cnt), total)
yield

with self.run_simulation(m) as sim:
sim.add_sync_process(process)
80 changes: 80 additions & 0 deletions transactron/lib/connectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
"Connect",
"ConnectTrans",
"ManyToOneConnectTrans",
"StableSelectingNetwork",
]


Expand Down Expand Up @@ -275,3 +276,82 @@ def elaborate(self, platform):
)

return m


class StableSelectingNetwork(Elaboratable):
"""A network that groups inputs with a valid bit set.

The circuit takes `n` inputs with a valid signal each and
on the output returns a grouped and consecutive sequence of the provided
input signals. The order of valid inputs is preserved.

For example for input (0 is an invalid input):
0, a, 0, d, 0, 0, e

The circuit will return:
a, d, e, 0, 0, 0, 0

xThaid marked this conversation as resolved.
Show resolved Hide resolved
The circuit uses a divide and conquer algorithm.
The recursive call takes two bit vectors and each of them
is already properly sorted, for example:
v1 = [a, b, 0, 0]; v2 = [c, d, e, 0]

Now by shifting left v2 and merging it with v1, we get the result:
v = [a, b, c, d, e, 0, 0, 0]

Thus, the network has depth log_2(n).

"""

def __init__(self, n: int, layout: MethodLayout):
self.n = n
self.layout = from_method_layout(layout)

self.inputs = [Signal(self.layout) for _ in range(n)]
self.valids = [Signal() for _ in range(n)]

self.outputs = [Signal(self.layout) for _ in range(n)]
self.output_cnt = Signal(range(n + 1))

def elaborate(self, platform):
m = TModule()

current_level = []
for i in range(self.n):
current_level.append((Array([self.inputs[i]]), self.valids[i]))

# Create the network using the bottom-up approach.
while len(current_level) >= 2:
next_level = []
while len(current_level) >= 2:
a, cnt_a = current_level.pop(0)
b, cnt_b = current_level.pop(0)

total_cnt = Signal(max(len(cnt_a), len(cnt_b)) + 1)
m.d.comb += total_cnt.eq(cnt_a + cnt_b)

total_len = len(a) + len(b)
merged = Array(Signal(self.layout) for _ in range(total_len))

for i in range(len(a)):
m.d.comb += merged[i].eq(Mux(cnt_a <= i, b[i - cnt_a], a[i]))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This generates a lot of subtractors. Why you doesn't use shifts?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think it is a problem. These are not proper substractions. In fact, the synthesis tool can permutate inputs of the mux and simply address it using cnt_a.

Shifting by variable number of bits in particular won't be better than a mux.

for i in range(len(b)):
m.d.comb += merged[len(a) + i].eq(Mux(len(a) + i - cnt_a >= len(b), 0, b[len(a) + i - cnt_a]))

next_level.append((merged, total_cnt))

# If we had an odd number of elements on the current level,
# move the item left to the next level.
if len(current_level) == 1:
next_level.append(current_level.pop(0))

current_level = next_level

last_level, total_cnt = current_level.pop(0)

for i in range(self.n):
m.d.comb += self.outputs[i].eq(last_level[i])

m.d.comb += self.output_cnt.eq(total_cnt)

return m