Skip to content

Commit

Permalink
support unroll up to 16
Browse files Browse the repository at this point in the history
  • Loading branch information
kalhankoul96 committed Jul 19, 2024
1 parent cedc82c commit 6624fd4
Showing 1 changed file with 68 additions and 2 deletions.
70 changes: 68 additions & 2 deletions sam/onyx/parse_dot.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from sam.onyx.hw_nodes.hw_node import HWNodeType
from lassen.utils import float2bfbin
import json
import math


class SAMDotGraphLoweringError(Exception):
Expand Down Expand Up @@ -38,6 +39,7 @@ def __init__(self, filename=None, local_mems=True, use_fork=False,
self.shared_glb = {}
self.shared_stream_arb = {}
self.shared_stream_arb_glb_edge = [] # Key assuming, single level stream arbiter
self.stage2_count = {}

self.annotate_IO_nodes()
self.graph.write_png('mek.png')
Expand Down Expand Up @@ -932,6 +934,10 @@ def rewrite_lookup(self, unroll):
nodes_to_proc = [node for node in self.graph.get_nodes() if 'fiberlookup' in node.get_comment() or
'fiberwrite' in node.get_comment()]

# print comment for each node
for node in nodes_to_proc:
print(node.get_comment())

for node in nodes_to_proc:
if 'fiberlookup' in node.get_comment():
# Rewrite this node to a read
Expand Down Expand Up @@ -1101,7 +1107,8 @@ def rewrite_lookup(self, unroll):
else:
in_edge = [edge for edge in self.graph.get_edges() if edge.get_destination() == node.get_name() and
"crd" in edge.get_label()][0]
if unroll > 1:

if unroll > 1 and unroll <= 4:
# create shared stream arbiter
stream_arb_mode = attrs['mode'].strip('"')
stream_arb_label = f"stream_arb_{stream_arb_mode}"
Expand All @@ -1118,7 +1125,49 @@ def rewrite_lookup(self, unroll):
hwnode=f"{HWNodeType.StreamArbiter}")
self.shared_stream_arb[stream_arb_label] = stream_arb
self.graph.add_node(stream_arb)
elif unroll > 4:
assert unroll <= 16
# need to create two stages of arbiters
stage2 = math.ceil(unroll/4)

# stream arbiter for stage 1
stream_arb_mode = attrs['mode'].strip('"')
stream_arb_label = f"stream_arb_{stream_arb_mode}_stage1"
if stream_arb_label in self.shared_stream_arb:
stream_arb_stage1 = self.shared_stream_arb[stream_arb_label]
else:
stream_arb_attr = dict()
if stream_arb_mode == 'vals':
stream_arb_attr['seg_mode'] = 0
else:
stream_arb_attr['seg_mode'] = 1
stream_arb_stage1 = pydot.Node(f"stream_arb_{self.get_next_seq()}_stage1", **stream_arb_attr, label=stream_arb_label,
comment=f"type=stream_arbiter,mode={stream_arb_mode}", type="stream_arbiter",
hwnode=f"{HWNodeType.StreamArbiter}")
self.shared_stream_arb[stream_arb_label] = stream_arb_stage1
self.graph.add_node(stream_arb_stage1)

self.stage2_count[stream_arb_stage1] = 0

# add stream arbiters for stage 2
stream_arb_mode = attrs['mode'].strip('"')

for s2_stream_arb in range(stage2):
stream_arb_label = f"stream_arb_{stream_arb_mode}_stage2_{s2_stream_arb}"
stream_arb_attr = dict()
if stream_arb_mode == 'vals':
stream_arb_attr['seg_mode'] = 0
else:
stream_arb_attr['seg_mode'] = 1
stream_arb = pydot.Node(stream_arb_label, **stream_arb_attr, label=stream_arb_label,
comment=f"type=stream_arbiter,mode={stream_arb_mode}", type="stream_arbiter",
hwnode=f"{HWNodeType.StreamArbiter}")
self.graph.add_node(stream_arb)
# connect edge from stage 1 to stage 2
stream_arb_to_stream_arb = pydot.Edge(src=stream_arb, dst=stream_arb_stage1,
label=f"stream_arb_to_stream_arb_{self.get_next_seq()}", style="bold")
self.graph.add_edge(stream_arb_to_stream_arb)

# Now add the nodes and move the edges...
self.graph.add_node(rd_scan)
self.graph.add_node(wr_scan)
Expand All @@ -1127,7 +1176,7 @@ def rewrite_lookup(self, unroll):
if self.local_mems is False:
self.graph.add_node(memory)

if unroll > 1:
if unroll > 1 and unroll <= 4:
# RD to Stream Arb
rd_to_stream_arb = pydot.Edge(src=rd_scan, dst=stream_arb,
label=f"rd_to_stream_arb_{self.get_next_seq()}", style="bold")
Expand All @@ -1139,6 +1188,23 @@ def rewrite_lookup(self, unroll):
label=f"stream_arb_to_glb_{self.get_next_seq()}", style="bold")
self.graph.add_edge(stream_arb_to_glb)
self.shared_stream_arb_glb_edge.append((stream_arb, glb_read))
elif unroll > 4:
# connect rd scan to stage 2 stream

count = self.stage2_count[stream_arb_stage1] % stage2
stream_arb_label = f"stream_arb_{stream_arb_mode}_stage2_{count}"
rd_to_stream_arb = pydot.Edge(src=rd_scan, dst=stream_arb_label,
label=f"rd_to_stream_arb_{self.get_next_seq()}", style="bold")
self.graph.add_edge(rd_to_stream_arb)
self.stage2_count[stream_arb_stage1] += 1

if (stream_arb_stage1, glb_read) not in self.shared_stream_arb_glb_edge:
# Stream Arb to GLB
stream_arb_to_glb = pydot.Edge(src=stream_arb_stage1, dst=glb_read,
label=f"stream_arb_to_glb_{self.get_next_seq()}", style="bold")
self.graph.add_edge(stream_arb_to_glb)
self.shared_stream_arb_glb_edge.append((stream_arb_stage1, glb_read))

else:
# RD to GLB
rd_to_glb = pydot.Edge(src=rd_scan, dst=glb_read, label=f"glb_to_wr_{self.get_next_seq()}",
Expand Down

0 comments on commit 6624fd4

Please sign in to comment.