From 6624fd41042587849fa10e2e7547fa49bb7b9052 Mon Sep 17 00:00:00 2001 From: root Date: Fri, 19 Jul 2024 11:13:01 -0700 Subject: [PATCH] support unroll up to 16 --- sam/onyx/parse_dot.py | 70 +++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 68 insertions(+), 2 deletions(-) diff --git a/sam/onyx/parse_dot.py b/sam/onyx/parse_dot.py index f5562be2..f1a68845 100644 --- a/sam/onyx/parse_dot.py +++ b/sam/onyx/parse_dot.py @@ -8,6 +8,7 @@ from sam.onyx.hw_nodes.hw_node import HWNodeType from lassen.utils import float2bfbin import json +import math class SAMDotGraphLoweringError(Exception): @@ -38,6 +39,7 @@ def __init__(self, filename=None, local_mems=True, use_fork=False, self.shared_glb = {} self.shared_stream_arb = {} self.shared_stream_arb_glb_edge = [] # Key assuming, single level stream arbiter + self.stage2_count = {} self.annotate_IO_nodes() self.graph.write_png('mek.png') @@ -932,6 +934,10 @@ def rewrite_lookup(self, unroll): nodes_to_proc = [node for node in self.graph.get_nodes() if 'fiberlookup' in node.get_comment() or 'fiberwrite' in node.get_comment()] + # print comment for each node + for node in nodes_to_proc: + print(node.get_comment()) + for node in nodes_to_proc: if 'fiberlookup' in node.get_comment(): # Rewrite this node to a read @@ -1101,7 +1107,8 @@ def rewrite_lookup(self, unroll): else: in_edge = [edge for edge in self.graph.get_edges() if edge.get_destination() == node.get_name() and "crd" in edge.get_label()][0] - if unroll > 1: + + if unroll > 1 and unroll <= 4: # create shared stream arbiter stream_arb_mode = attrs['mode'].strip('"') stream_arb_label = f"stream_arb_{stream_arb_mode}" @@ -1118,7 +1125,49 @@ def rewrite_lookup(self, unroll): hwnode=f"{HWNodeType.StreamArbiter}") self.shared_stream_arb[stream_arb_label] = stream_arb self.graph.add_node(stream_arb) + elif unroll > 4: + assert unroll <= 16 + # need to create two stages of arbiters + stage2 = math.ceil(unroll/4) + # stream arbiter for stage 1 + stream_arb_mode = attrs['mode'].strip('"') + stream_arb_label = f"stream_arb_{stream_arb_mode}_stage1" + if stream_arb_label in self.shared_stream_arb: + stream_arb_stage1 = self.shared_stream_arb[stream_arb_label] + else: + stream_arb_attr = dict() + if stream_arb_mode == 'vals': + stream_arb_attr['seg_mode'] = 0 + else: + stream_arb_attr['seg_mode'] = 1 + stream_arb_stage1 = pydot.Node(f"stream_arb_{self.get_next_seq()}_stage1", **stream_arb_attr, label=stream_arb_label, + comment=f"type=stream_arbiter,mode={stream_arb_mode}", type="stream_arbiter", + hwnode=f"{HWNodeType.StreamArbiter}") + self.shared_stream_arb[stream_arb_label] = stream_arb_stage1 + self.graph.add_node(stream_arb_stage1) + + self.stage2_count[stream_arb_stage1] = 0 + + # add stream arbiters for stage 2 + stream_arb_mode = attrs['mode'].strip('"') + + for s2_stream_arb in range(stage2): + stream_arb_label = f"stream_arb_{stream_arb_mode}_stage2_{s2_stream_arb}" + stream_arb_attr = dict() + if stream_arb_mode == 'vals': + stream_arb_attr['seg_mode'] = 0 + else: + stream_arb_attr['seg_mode'] = 1 + stream_arb = pydot.Node(stream_arb_label, **stream_arb_attr, label=stream_arb_label, + comment=f"type=stream_arbiter,mode={stream_arb_mode}", type="stream_arbiter", + hwnode=f"{HWNodeType.StreamArbiter}") + self.graph.add_node(stream_arb) + # connect edge from stage 1 to stage 2 + stream_arb_to_stream_arb = pydot.Edge(src=stream_arb, dst=stream_arb_stage1, + label=f"stream_arb_to_stream_arb_{self.get_next_seq()}", style="bold") + self.graph.add_edge(stream_arb_to_stream_arb) + # Now add the nodes and move the edges... self.graph.add_node(rd_scan) self.graph.add_node(wr_scan) @@ -1127,7 +1176,7 @@ def rewrite_lookup(self, unroll): if self.local_mems is False: self.graph.add_node(memory) - if unroll > 1: + if unroll > 1 and unroll <= 4: # RD to Stream Arb rd_to_stream_arb = pydot.Edge(src=rd_scan, dst=stream_arb, label=f"rd_to_stream_arb_{self.get_next_seq()}", style="bold") @@ -1139,6 +1188,23 @@ def rewrite_lookup(self, unroll): label=f"stream_arb_to_glb_{self.get_next_seq()}", style="bold") self.graph.add_edge(stream_arb_to_glb) self.shared_stream_arb_glb_edge.append((stream_arb, glb_read)) + elif unroll > 4: + # connect rd scan to stage 2 stream + + count = self.stage2_count[stream_arb_stage1] % stage2 + stream_arb_label = f"stream_arb_{stream_arb_mode}_stage2_{count}" + rd_to_stream_arb = pydot.Edge(src=rd_scan, dst=stream_arb_label, + label=f"rd_to_stream_arb_{self.get_next_seq()}", style="bold") + self.graph.add_edge(rd_to_stream_arb) + self.stage2_count[stream_arb_stage1] += 1 + + if (stream_arb_stage1, glb_read) not in self.shared_stream_arb_glb_edge: + # Stream Arb to GLB + stream_arb_to_glb = pydot.Edge(src=stream_arb_stage1, dst=glb_read, + label=f"stream_arb_to_glb_{self.get_next_seq()}", style="bold") + self.graph.add_edge(stream_arb_to_glb) + self.shared_stream_arb_glb_edge.append((stream_arb_stage1, glb_read)) + else: # RD to GLB rd_to_glb = pydot.Edge(src=rd_scan, dst=glb_read, label=f"glb_to_wr_{self.get_next_seq()}",