diff --git a/sam/onyx/hw_nodes/fiberaccess_node.py b/sam/onyx/hw_nodes/fiberaccess_node.py index 3b82d5c4..1983e22e 100644 --- a/sam/onyx/hw_nodes/fiberaccess_node.py +++ b/sam/onyx/hw_nodes/fiberaccess_node.py @@ -83,6 +83,7 @@ def connect(self, other, edge, kwargs=None): from sam.onyx.hw_nodes.merge_node import MergeNode from sam.onyx.hw_nodes.crdhold_node import CrdHoldNode from sam.onyx.hw_nodes.stream_arbiter_node import StreamArbiterNode + from sam.onyx.hw_nodes.pass_through_node import PassThroughNode new_conns = None other_type = type(other) @@ -230,6 +231,16 @@ def connect(self, other, edge, kwargs=None): print(init_conns) final_conns = self.remap_conns(init_conns, kwargs['flavor_this']) return final_conns + elif other_type == PassThroughNode: + assert kwargs is not None + assert 'flavor_this' in kwargs + this_flavor = self.get_flavor(kwargs['flavor_this']) + print(kwargs) + print("FIBER ACCESS TO Pass Through") + init_conns = this_flavor.connect(other, edge) + print(init_conns) + final_conns = self.remap_conns(init_conns, kwargs['flavor_this']) + return final_conns else: raise NotImplementedError(f'Cannot connect FiberAccessNode to {other_type}') diff --git a/sam/onyx/hw_nodes/intersect_node.py b/sam/onyx/hw_nodes/intersect_node.py index 6a7f8445..d52855e2 100644 --- a/sam/onyx/hw_nodes/intersect_node.py +++ b/sam/onyx/hw_nodes/intersect_node.py @@ -32,6 +32,7 @@ def connect(self, other, edge, kwargs=None): from sam.onyx.hw_nodes.repsiggen_node import RepSigGenNode from sam.onyx.hw_nodes.crdhold_node import CrdHoldNode from sam.onyx.hw_nodes.fiberaccess_node import FiberAccessNode + from sam.onyx.hw_nodes.pass_through_node import PassThroughNode new_conns = None isect = self.get_name() @@ -221,7 +222,32 @@ def connect(self, other, edge, kwargs=None): print(init_conns) final_conns = other.remap_conns(init_conns, kwargs['flavor_that']) return final_conns + elif other_type == PassThroughNode: + pass_through = other.get_name() + comment = edge.get_attributes()['comment'].strip('"') + try: + tensor = comment.split("-")[1] + except Exception: + try: + tensor = comment.split("_")[1] + except Exception: + tensor = comment + edge_type = edge.get_attributes()['type'].strip('"') + if 'crd' in edge_type: + new_conns = { + f'isect_to_isect': [ + ([(isect, f"coord_out"), (pass_through, "stream_in")], 17), + ] + } + elif 'ref' in edge_type: + isect_conn = self.get_connection_from_tensor(tensor) + new_conns = { + f'isect_to_isect': [ + ([(isect, f"pos_out_{isect_conn}"), (pass_through, "stream_in")], 17), + ] + } + return new_conns else: raise NotImplementedError(f'Cannot connect IntersectNode to {other_type}') diff --git a/sam/onyx/hw_nodes/merge_node.py b/sam/onyx/hw_nodes/merge_node.py index 3b2e5fd2..4bbbc2a6 100644 --- a/sam/onyx/hw_nodes/merge_node.py +++ b/sam/onyx/hw_nodes/merge_node.py @@ -33,6 +33,7 @@ def connect(self, other, edge, kwargs=None): from sam.onyx.hw_nodes.repsiggen_node import RepSigGenNode from sam.onyx.hw_nodes.crdhold_node import CrdHoldNode from sam.onyx.hw_nodes.fiberaccess_node import FiberAccessNode + from sam.onyx.hw_nodes.pass_through_node import PassThroughNode new_conns = None other_type = type(other) @@ -106,6 +107,25 @@ def connect(self, other, edge, kwargs=None): ([(merge, f"coord_out_{out_conn}"), (other_merge, f"coord_in_{in_conn}")], 17), ] } + return new_conns + elif other_type == PassThroughNode: + pass_through = other.get_name() + # Use inner to process outer + comment = edge.get_attributes()['comment'].strip('"') + tensor_lvl = None + if self.get_inner() in comment: + out_conn = 0 + tensor_lvl = self.get_inner() + else: + out_conn = 1 + tensor_lvl = self.get_outer() + + new_conns = { + f'merger_to_merger_{out_conn}_to_pass_through': [ + ([(merge, f"coord_out_{out_conn}"), (pass_through, "stream_in")], 17), + ] + } + return new_conns elif other_type == RepeatNode: raise NotImplementedError(f'Cannot connect MergeNode to {other_type}') diff --git a/sam/onyx/hw_nodes/pass_through_node.py b/sam/onyx/hw_nodes/pass_through_node.py index c6dbea2a..9debf377 100644 --- a/sam/onyx/hw_nodes/pass_through_node.py +++ b/sam/onyx/hw_nodes/pass_through_node.py @@ -2,9 +2,15 @@ class PassThroughNode(HWNode): - def __init__(self, name=None) -> None: + def __init__(self, name=None, conn_to_tensor=None) -> None: super().__init__(name=name) + self.conn_to_tensor = conn_to_tensor + self.tensor_to_conn = {} + if conn_to_tensor is not None: + for conn, tensor in self.conn_to_tensor.items(): + self.tensor_to_conn[tensor] = conn + def connect(self, other, edge, kwargs=None): from sam.onyx.hw_nodes.broadcast_node import BroadcastNode @@ -27,6 +33,7 @@ def connect(self, other, edge, kwargs=None): pass_through = self.get_name() other_type = type(other) + print(other_type) if other_type == WriteScannerNode: wr_scan = other.get_name() @@ -36,21 +43,117 @@ def connect(self, other, edge, kwargs=None): ] } return new_conns + elif other_type == ReadScannerNode: + print("PASSTHORUGH TO REPEAT EDGE!") + rd_scan = other.get_name() + new_conns = { + 'pass_through_to_rd_scan': [ + ([(pass_through, "stream_out"), (rd_scan, f"us_pos_in")], 17), + ] + } + return new_conns + elif other_type == RepeatNode: + repeat = other.get_name() + print("PASSTHROUGH TO REPEAT EDGE!") + new_conns = { + 'pass_through_to_repeat': [ + # send output to rd scanner + ([(pass_through, "stream_out"), (repeat, "proc_data_in")], 17), + ] + } + return new_conns + elif other_type == IntersectNode: + comment = edge.get_attributes()['comment'].strip('"') + try: + tensor = comment.split("-")[1] + except Exception: + try: + tensor = comment.split("_")[1] + except Exception: + tensor = comment + + other_isect = other.get_name() + isect_conn = self.get_connection_from_tensor(tensor) + other_isect_conn = other.get_connection_from_tensor(tensor) + + edge_type = edge.get_attributes()['type'].strip('"') + + if 'crd' in edge_type: + new_conns = { + f'pass_through_to_isect': [ + ([(pass_through, "stream_out"), (other_isect, f"coord_in_{other_isect_conn}")], 17), + ] + } + elif 'ref' in edge_type: + new_conns = { + f'pass_through_to_isect': [ + ([(pass_through, "stream_out"), (other_isect, f"pos_in_{other_isect_conn}")], 17), + ] + } + return new_conns + + elif other_type == MergeNode: + edge_attr = edge.get_attributes() + crddrop = other.get_name() + print("CHECKING READ TENSOR - CRDDROP") + print(edge) + crd_drop_outer = other.get_outer() + comment = edge_attr['comment'].strip('"') + conn = 0 + # okay this is dumb, stopgap until we can have super consistent output + try: + mapped_to_conn = comment.split("-")[1] + except Exception: + try: + mapped_to_conn = comment.split("_")[1] + except Exception: + mapped_to_conn = comment + if crd_drop_outer in mapped_to_conn: + conn = 1 + + if 'use_alt_out_port' in edge_attr: + out_conn = 'block_rd_out' + elif ('vector_reduce_mode' in edge_attr): + if (edge_attr['vector_reduce_mode']): + out_conn = 'pos_out' + else: + out_conn = 'coord_out' + + new_conns = { + f'rd_scan_to_crddrop_{conn}': [ + ([(pass_through, "stream_out"), (crddrop, f"coord_in_{conn}")], 17), + ] + } + + return new_conns + elif other_type == RepSigGenNode: + rsg = other.get_name() + new_conns = { + f'pass_through_to_rsg': [ + ([(pass_through, "stream_out"), (rsg, f"base_data_in")], 17), + ] + } elif other_type == FiberAccessNode: - # Only could be using the write scanner portion of the fiber access # fa = other.get_name() - conns_original = self.connect(other.get_write_scanner(), edge=edge) - print(conns_original) - conns_remapped = other.remap_conns(conns_original, "write_scanner") - print(conns_remapped) - - return conns_remapped + print("PASSTHROUGH TO FIBER ACCESS") + assert kwargs is not None + assert 'flavor_that' in kwargs + that_flavor = other.get_flavor(kwargs['flavor_that']) + print(kwargs) + init_conns = self.connect(that_flavor, edge) + print(init_conns) + final_conns = other.remap_conns(init_conns, kwargs['flavor_that']) + return final_conns else: - raise NotImplementedError(f'Cannot connect GLBNode to {other_type}') + raise NotImplementedError(f'Cannot connect Pass Through Node to {other_type}') return new_conns + def get_connection_from_tensor(self, tensor): + print(self.tensor_to_conn) + return self.tensor_to_conn[tensor] + def update_input_connections(self): self.num_inputs_connected += 1 diff --git a/sam/onyx/hw_nodes/read_scanner_node.py b/sam/onyx/hw_nodes/read_scanner_node.py index 3091f058..32e90f14 100644 --- a/sam/onyx/hw_nodes/read_scanner_node.py +++ b/sam/onyx/hw_nodes/read_scanner_node.py @@ -44,6 +44,7 @@ def connect(self, other, edge, kwargs=None): from sam.onyx.hw_nodes.repsiggen_node import RepSigGenNode from sam.onyx.hw_nodes.crdhold_node import CrdHoldNode from sam.onyx.hw_nodes.stream_arbiter_node import StreamArbiterNode + from sam.onyx.hw_nodes.pass_through_node import PassThroughNode new_conns = None rd_scan = self.get_name() @@ -268,6 +269,28 @@ def connect(self, other, edge, kwargs=None): } other.update_input_connections() return new_conns + elif other_type == PassThroughNode: + pass_through = other.get_name() + e_attr = edge.get_attributes() + e_type = e_attr['type'].strip('"') + if "crd" in e_type: + new_conns = { + f'rd_scan_to_pass_through_crd': [ + # send output to rd scanner + ([(rd_scan, "coord_out"), (pass_through, "stream_in")], 17), + ] + } + elif 'ref' in e_type: + rd_scan_out_port = "pos_out" + if 'val' in e_attr and e_attr['val'].strip('"') == 'true': + rd_scan_out_port = "coord_out" + new_conns = { + f'rd_scan_to_pass_through_pos': [ + # send output to rd scanner + ([(rd_scan, rd_scan_out_port), (pass_through, "stream_in")], 17), + ] + } + return new_conns else: raise NotImplementedError(f'Cannot connect ReadScannerNode to {other_type}') diff --git a/sam/onyx/parse_dot.py b/sam/onyx/parse_dot.py index f5562be2..ecf3f39f 100644 --- a/sam/onyx/parse_dot.py +++ b/sam/onyx/parse_dot.py @@ -8,6 +8,7 @@ from sam.onyx.hw_nodes.hw_node import HWNodeType from lassen.utils import float2bfbin import json +import math class SAMDotGraphLoweringError(Exception): @@ -38,6 +39,7 @@ def __init__(self, filename=None, local_mems=True, use_fork=False, self.shared_glb = {} self.shared_stream_arb = {} self.shared_stream_arb_glb_edge = [] # Key assuming, single level stream arbiter + self.stage2_count = {} self.annotate_IO_nodes() self.graph.write_png('mek.png') @@ -932,6 +934,10 @@ def rewrite_lookup(self, unroll): nodes_to_proc = [node for node in self.graph.get_nodes() if 'fiberlookup' in node.get_comment() or 'fiberwrite' in node.get_comment()] + # print comment for each node + for node in nodes_to_proc: + print(node.get_comment()) + for node in nodes_to_proc: if 'fiberlookup' in node.get_comment(): # Rewrite this node to a read @@ -1101,7 +1107,8 @@ def rewrite_lookup(self, unroll): else: in_edge = [edge for edge in self.graph.get_edges() if edge.get_destination() == node.get_name() and "crd" in edge.get_label()][0] - if unroll > 1: + + if unroll > 1 and unroll <= 4: # create shared stream arbiter stream_arb_mode = attrs['mode'].strip('"') stream_arb_label = f"stream_arb_{stream_arb_mode}" @@ -1118,6 +1125,59 @@ def rewrite_lookup(self, unroll): hwnode=f"{HWNodeType.StreamArbiter}") self.shared_stream_arb[stream_arb_label] = stream_arb self.graph.add_node(stream_arb) + elif unroll > 4: + assert unroll <= 16 + # need to create two stages of arbiters + stage2 = math.ceil(unroll / 4) + + # stream arbiter for stage 1 + stream_arb_mode = attrs['mode'].strip('"') + stream_arb_label = f"stream_arb_{stream_arb_mode}_stage1" + if stream_arb_label in self.shared_stream_arb: + stream_arb_stage1 = self.shared_stream_arb[stream_arb_label] + else: + stream_arb_attr = dict() + if stream_arb_mode == 'vals': + stream_arb_attr['seg_mode'] = 0 + else: + stream_arb_attr['seg_mode'] = 1 + stream_arb_stage1 = pydot.Node( + f"stream_arb_{self.get_next_seq()}_stage1", + **stream_arb_attr, + label=stream_arb_label, + comment=f"type=stream_arbiter,mode={stream_arb_mode}", + type="stream_arbiter", + hwnode=f"{HWNodeType.StreamArbiter}") + self.shared_stream_arb[stream_arb_label] = stream_arb_stage1 + self.graph.add_node(stream_arb_stage1) + + self.stage2_count[stream_arb_stage1] = 0 + + # add stream arbiters for stage 2 + stream_arb_mode = attrs['mode'].strip('"') + + for s2_stream_arb in range(stage2): + stream_arb_label = f"stream_arb_{stream_arb_mode}_stage2_{s2_stream_arb}" + stream_arb_attr = dict() + if stream_arb_mode == 'vals': + stream_arb_attr['seg_mode'] = 0 + else: + stream_arb_attr['seg_mode'] = 1 + stream_arb = pydot.Node( + stream_arb_label, + **stream_arb_attr, + label=stream_arb_label, + comment=f"type=stream_arbiter,mode={stream_arb_mode}", + type="stream_arbiter", + hwnode=f"{HWNodeType.StreamArbiter}") + self.graph.add_node(stream_arb) + # connect edge from stage 1 to stage 2 + stream_arb_to_stream_arb = pydot.Edge( + src=stream_arb, + dst=stream_arb_stage1, + label=f"stream_arb_to_stream_arb_{self.get_next_seq()}", + style="bold") + self.graph.add_edge(stream_arb_to_stream_arb) # Now add the nodes and move the edges... self.graph.add_node(rd_scan) @@ -1127,7 +1187,7 @@ def rewrite_lookup(self, unroll): if self.local_mems is False: self.graph.add_node(memory) - if unroll > 1: + if unroll > 1 and unroll <= 4: # RD to Stream Arb rd_to_stream_arb = pydot.Edge(src=rd_scan, dst=stream_arb, label=f"rd_to_stream_arb_{self.get_next_seq()}", style="bold") @@ -1139,6 +1199,23 @@ def rewrite_lookup(self, unroll): label=f"stream_arb_to_glb_{self.get_next_seq()}", style="bold") self.graph.add_edge(stream_arb_to_glb) self.shared_stream_arb_glb_edge.append((stream_arb, glb_read)) + elif unroll > 4: + # connect rd scan to stage 2 stream + + count = self.stage2_count[stream_arb_stage1] % stage2 + stream_arb_label = f"stream_arb_{stream_arb_mode}_stage2_{count}" + rd_to_stream_arb = pydot.Edge(src=rd_scan, dst=stream_arb_label, + label=f"rd_to_stream_arb_{self.get_next_seq()}", style="bold") + self.graph.add_edge(rd_to_stream_arb) + self.stage2_count[stream_arb_stage1] += 1 + + if (stream_arb_stage1, glb_read) not in self.shared_stream_arb_glb_edge: + # Stream Arb to GLB + stream_arb_to_glb = pydot.Edge(src=stream_arb_stage1, dst=glb_read, + label=f"stream_arb_to_glb_{self.get_next_seq()}", style="bold") + self.graph.add_edge(stream_arb_to_glb) + self.shared_stream_arb_glb_edge.append((stream_arb_stage1, glb_read)) + else: # RD to GLB rd_to_glb = pydot.Edge(src=rd_scan, dst=glb_read, label=f"glb_to_wr_{self.get_next_seq()}",