-
Notifications
You must be signed in to change notification settings - Fork 16
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add selecting network circuit #621
Changes from all commits
fb2cf64
c8bfe67
afd2a42
469c3dc
e07e03a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
import random | ||
from parameterized import parameterized_class | ||
|
||
from amaranth.sim import Settle | ||
|
||
from transactron.lib import StableSelectingNetwork | ||
from transactron.testing import TestCaseWithSimulator | ||
|
||
|
||
@parameterized_class( | ||
("n"), | ||
[(2,), (3,), (7,), (8,)], | ||
) | ||
class TestStableSelectingNetwork(TestCaseWithSimulator): | ||
n: int | ||
|
||
def test(self): | ||
m = StableSelectingNetwork(self.n, [("data", 8)]) | ||
|
||
random.seed(42) | ||
|
||
def process(): | ||
for _ in range(100): | ||
inputs = [random.randrange(2**8) for _ in range(self.n)] | ||
valids = [random.randrange(2) for _ in range(self.n)] | ||
total = sum(valids) | ||
|
||
expected_output_prefix = [] | ||
for i in range(self.n): | ||
yield m.valids[i].eq(valids[i]) | ||
yield m.inputs[i].data.eq(inputs[i]) | ||
|
||
if valids[i]: | ||
expected_output_prefix.append(inputs[i]) | ||
|
||
yield Settle() | ||
|
||
for i in range(total): | ||
out = yield m.outputs[i].data | ||
self.assertEqual(out, expected_output_prefix[i]) | ||
|
||
self.assertEqual((yield m.output_cnt), total) | ||
yield | ||
|
||
with self.run_simulation(m) as sim: | ||
sim.add_sync_process(process) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,6 +11,7 @@ | |
"Connect", | ||
"ConnectTrans", | ||
"ManyToOneConnectTrans", | ||
"StableSelectingNetwork", | ||
] | ||
|
||
|
||
|
@@ -275,3 +276,82 @@ def elaborate(self, platform): | |
) | ||
|
||
return m | ||
|
||
|
||
class StableSelectingNetwork(Elaboratable): | ||
"""A network that groups inputs with a valid bit set. | ||
|
||
The circuit takes `n` inputs with a valid signal each and | ||
on the output returns a grouped and consecutive sequence of the provided | ||
input signals. The order of valid inputs is preserved. | ||
|
||
For example for input (0 is an invalid input): | ||
0, a, 0, d, 0, 0, e | ||
|
||
The circuit will return: | ||
a, d, e, 0, 0, 0, 0 | ||
|
||
xThaid marked this conversation as resolved.
Show resolved
Hide resolved
|
||
The circuit uses a divide and conquer algorithm. | ||
The recursive call takes two bit vectors and each of them | ||
is already properly sorted, for example: | ||
v1 = [a, b, 0, 0]; v2 = [c, d, e, 0] | ||
|
||
Now by shifting left v2 and merging it with v1, we get the result: | ||
v = [a, b, c, d, e, 0, 0, 0] | ||
|
||
Thus, the network has depth log_2(n). | ||
|
||
""" | ||
|
||
def __init__(self, n: int, layout: MethodLayout): | ||
self.n = n | ||
self.layout = from_method_layout(layout) | ||
|
||
self.inputs = [Signal(self.layout) for _ in range(n)] | ||
self.valids = [Signal() for _ in range(n)] | ||
|
||
self.outputs = [Signal(self.layout) for _ in range(n)] | ||
self.output_cnt = Signal(range(n + 1)) | ||
|
||
def elaborate(self, platform): | ||
m = TModule() | ||
|
||
current_level = [] | ||
for i in range(self.n): | ||
current_level.append((Array([self.inputs[i]]), self.valids[i])) | ||
|
||
# Create the network using the bottom-up approach. | ||
while len(current_level) >= 2: | ||
next_level = [] | ||
while len(current_level) >= 2: | ||
a, cnt_a = current_level.pop(0) | ||
b, cnt_b = current_level.pop(0) | ||
|
||
total_cnt = Signal(max(len(cnt_a), len(cnt_b)) + 1) | ||
m.d.comb += total_cnt.eq(cnt_a + cnt_b) | ||
|
||
total_len = len(a) + len(b) | ||
merged = Array(Signal(self.layout) for _ in range(total_len)) | ||
|
||
for i in range(len(a)): | ||
m.d.comb += merged[i].eq(Mux(cnt_a <= i, b[i - cnt_a], a[i])) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This generates a lot of subtractors. Why you doesn't use shifts? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think it is a problem. These are not proper substractions. In fact, the synthesis tool can permutate inputs of the mux and simply address it using Shifting by variable number of bits in particular won't be better than a mux. |
||
for i in range(len(b)): | ||
m.d.comb += merged[len(a) + i].eq(Mux(len(a) + i - cnt_a >= len(b), 0, b[len(a) + i - cnt_a])) | ||
|
||
next_level.append((merged, total_cnt)) | ||
|
||
# If we had an odd number of elements on the current level, | ||
# move the item left to the next level. | ||
if len(current_level) == 1: | ||
next_level.append(current_level.pop(0)) | ||
|
||
current_level = next_level | ||
|
||
last_level, total_cnt = current_level.pop(0) | ||
|
||
for i in range(self.n): | ||
m.d.comb += self.outputs[i].eq(last_level[i]) | ||
|
||
m.d.comb += self.output_cnt.eq(total_cnt) | ||
|
||
return m |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should this be
("n",)
?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for pointing this out - it should. However, apparently this decorator correctly handles this case: