diff --git a/sam/onyx/generate_matrices.py b/sam/onyx/generate_matrices.py
index 5fdeaa27..f77164b3 100644
--- a/sam/onyx/generate_matrices.py
+++ b/sam/onyx/generate_matrices.py
@@ -489,7 +489,7 @@ def create_matrix_from_point_list(name, pt_list, shape, use_fp=False) -> MatrixG
     return mg


-def convert_aha_glb_output_file(glbfile, output_dir, tiles, batches):
+def convert_aha_glb_output_file(glbfile, output_dir, tiles, batches, glb_mem_stride=500):

     glbfile_s = os.path.basename(glbfile).rstrip(".txt")

@@ -531,7 +531,9 @@ def convert_aha_glb_output_file(glbfile, output_dir, tiles, batches):
     tile = 0
     batch = 0
     block = 0
+    cur_base = 0
     for file_path in files:
+        assert sl_ptr < len(straightline)
         num_items = straightline[sl_ptr]
         sl_ptr += 1
         with open(file_path, "w+") as fh_:
@@ -546,11 +548,13 @@ def convert_aha_glb_output_file(glbfile, output_dir, tiles, batches):
         if block == num_blocks:
             block = 0
             tile += 1
+            sl_ptr = cur_base + tile * glb_mem_stride
         if tile == tiles:
             tile = 0
             batch = batch + 1
-            # TODO hardcoded value for now
+            # TODO hardcoded value for now, need to consider a larger case
             sl_ptr = 32768 * batch  # size of glb
+            cur_base = 32768 * batch


 def find_file_based_on_sub_string(files_dir, sub_string_list):
diff --git a/sam/onyx/hw_nodes/fiberaccess_node.py b/sam/onyx/hw_nodes/fiberaccess_node.py
index 7e7a3722..3b82d5c4 100644
--- a/sam/onyx/hw_nodes/fiberaccess_node.py
+++ b/sam/onyx/hw_nodes/fiberaccess_node.py
@@ -82,6 +82,7 @@ def connect(self, other, edge, kwargs=None):
         from sam.onyx.hw_nodes.repsiggen_node import RepSigGenNode
         from sam.onyx.hw_nodes.merge_node import MergeNode
         from sam.onyx.hw_nodes.crdhold_node import CrdHoldNode
+        from sam.onyx.hw_nodes.stream_arbiter_node import StreamArbiterNode

         new_conns = None
         other_type = type(other)
@@ -219,6 +220,16 @@ def connect(self, other, edge, kwargs=None):
             final_conns_2 = other.remap_conns(final_conns_1, kwargs['flavor_that'])
             print(final_conns_2)
             return final_conns_2
+        elif other_type == StreamArbiterNode:
+            assert kwargs is not None
+            assert 'flavor_this' in kwargs
+            this_flavor = self.get_flavor(kwargs['flavor_this'])
+            print(kwargs)
+            print("FIBER ACCESS TO Stream Arbiter")
+            init_conns = this_flavor.connect(other, edge)
+            print(init_conns)
+            final_conns = self.remap_conns(init_conns, kwargs['flavor_this'])
+            return final_conns
         else:
             raise NotImplementedError(f'Cannot connect FiberAccessNode to {other_type}')
diff --git a/sam/onyx/hw_nodes/glb_node.py b/sam/onyx/hw_nodes/glb_node.py
index ed0df936..2dc3b55b 100644
--- a/sam/onyx/hw_nodes/glb_node.py
+++ b/sam/onyx/hw_nodes/glb_node.py
@@ -3,7 +3,7 @@ class GLBNode(HWNode):
     def __init__(self, name=None, data=None, valid=None, ready=None,
-                 direction=None, num_blocks=None, file_number=None, tx_size=None, IO_id=0,
+                 direction=None, num_blocks=None, seg_mode=None, file_number=None, tx_size=None, IO_id=0,
                  bespoke=False, tensor=None, mode=None, format=None) -> None:
         super().__init__(name=name)
@@ -12,6 +12,7 @@ def __init__(self, name=None, data=None, valid=None, ready=None,
         self.ready = ready
         self.direction = direction
         self.num_blocks = num_blocks
+        self.seg_mode = seg_mode
         self.file_number = file_number
         self.tx_size = tx_size
         self.IO_id = IO_id
@@ -37,6 +38,9 @@ def get_tx_size(self):
     def get_num_blocks(self):
         return self.num_blocks

+    def get_seg_mode(self):
+        return self.seg_mode
+
     def get_data(self):
         return self.data

@@ -74,6 +78,7 @@ def connect(self, other, edge, kwargs=None):
         from sam.onyx.hw_nodes.repsiggen_node import RepSigGenNode
         from sam.onyx.hw_nodes.crdhold_node import CrdHoldNode
         from sam.onyx.hw_nodes.fiberaccess_node import FiberAccessNode
+        from sam.onyx.hw_nodes.pass_through_node import PassThroughNode

         other_type = type(other)

@@ -120,6 +125,14 @@ def connect(self, other, edge, kwargs=None):
             print(conns_remapped)
             return conns_remapped
+        elif other_type == PassThroughNode:
+            pass_through = other.get_name()
+            new_conns = {
+                'glb_to_pass_through': [
+                    ([(self.data, "io2f_17"), (pass_through, "stream_in")], 17),
+                ]
+            }
+            return new_conns
         else:
             raise NotImplementedError(f'Cannot connect GLBNode to {other_type}')
diff --git a/sam/onyx/hw_nodes/hw_node.py b/sam/onyx/hw_nodes/hw_node.py
index 9b7d8847..d6af9d13 100644
--- a/sam/onyx/hw_nodes/hw_node.py
+++ b/sam/onyx/hw_nodes/hw_node.py
@@ -18,6 +18,8 @@ class HWNodeType(Enum):
     CrdHold = 14
     VectorReducer = 15
     FiberAccess = 16
+    StreamArbiter = 17
+    PassThrough = 18


 class HWNode():
diff --git a/sam/onyx/hw_nodes/pass_through_node.py b/sam/onyx/hw_nodes/pass_through_node.py
new file mode 100644
index 00000000..c6dbea2a
--- /dev/null
+++ b/sam/onyx/hw_nodes/pass_through_node.py
@@ -0,0 +1,68 @@
+from sam.onyx.hw_nodes.hw_node import *
+
+
+class PassThroughNode(HWNode):
+    def __init__(self, name=None) -> None:
+        super().__init__(name=name)
+
+    def connect(self, other, edge, kwargs=None):
+
+        from sam.onyx.hw_nodes.broadcast_node import BroadcastNode
+        from sam.onyx.hw_nodes.compute_node import ComputeNode
+        from sam.onyx.hw_nodes.glb_node import GLBNode
+        from sam.onyx.hw_nodes.buffet_node import BuffetNode
+        from sam.onyx.hw_nodes.memory_node import MemoryNode
+        from sam.onyx.hw_nodes.read_scanner_node import ReadScannerNode
+        from sam.onyx.hw_nodes.write_scanner_node import WriteScannerNode
+        from sam.onyx.hw_nodes.intersect_node import IntersectNode
+        from sam.onyx.hw_nodes.reduce_node import ReduceNode
+        from sam.onyx.hw_nodes.lookup_node import LookupNode
+        from sam.onyx.hw_nodes.merge_node import MergeNode
+        from sam.onyx.hw_nodes.repeat_node import RepeatNode
+        from sam.onyx.hw_nodes.repsiggen_node import RepSigGenNode
+        from sam.onyx.hw_nodes.crdhold_node import CrdHoldNode
+        from sam.onyx.hw_nodes.fiberaccess_node import FiberAccessNode
+
+        new_conns = None
+        pass_through = self.get_name()
+
+        other_type = type(other)
+
+        if other_type == WriteScannerNode:
+            wr_scan = other.get_name()
+            new_conns = {
+                'pass_through_to_wr_scan': [
+                    ([(pass_through, "stream_out"), (wr_scan, "block_wr_in")], 17),
+                ]
+            }
+            return new_conns
+        elif other_type == FiberAccessNode:
+            # Only could be using the write scanner portion of the fiber access
+            # fa = other.get_name()
+            conns_original = self.connect(other.get_write_scanner(), edge=edge)
+            print(conns_original)
+            conns_remapped = other.remap_conns(conns_original, "write_scanner")
+            print(conns_remapped)
+
+            return conns_remapped
+
+        else:
+            raise NotImplementedError(f'Cannot connect PassThroughNode to {other_type}')
+
+        return new_conns
+
+    def update_input_connections(self):
+        self.num_inputs_connected += 1
+
+    def get_num_inputs(self):
+        return self.num_inputs_connected
+
+    def configure(self, attributes):
+        # print("Pass Through CONFIGURE")
+        # print(attributes)
+
+        placeholder = 1
+        cfg_kwargs = {
+            'placeholder': placeholder
+        }
+        return (placeholder), cfg_kwargs
diff --git a/sam/onyx/hw_nodes/read_scanner_node.py b/sam/onyx/hw_nodes/read_scanner_node.py
index 1b0b97a0..3091f058 100644
--- a/sam/onyx/hw_nodes/read_scanner_node.py
+++ b/sam/onyx/hw_nodes/read_scanner_node.py
@@ -43,6 +43,7 @@ def connect(self, other, edge, kwargs=None):
         from sam.onyx.hw_nodes.repeat_node import RepeatNode
         from sam.onyx.hw_nodes.repsiggen_node import RepSigGenNode
         from sam.onyx.hw_nodes.crdhold_node import CrdHoldNode
+        from sam.onyx.hw_nodes.stream_arbiter_node import StreamArbiterNode

         new_conns = None
         rd_scan = self.get_name()
@@ -255,6 +256,17 @@ def connect(self, other, edge, kwargs=None):
                 ]
             }

+            return new_conns
+        elif other_type == StreamArbiterNode:
+            cur_inputs = other.get_num_inputs()
+            assert cur_inputs <= other.max_num_inputs - 1, f"Cannot connect ReadScannerNode to {other_type}, too many inputs"
+            down_stream_arb = other.get_name()
+            new_conns = {
+                f'rd_scan_to_stream_arbiter_{cur_inputs}': [
+                    ([(rd_scan, "block_rd_out"), (down_stream_arb, f"stream_in_{cur_inputs}")], 17),
+                ]
+            }
+            other.update_input_connections()
             return new_conns
         else:
             raise NotImplementedError(f'Cannot connect ReadScannerNode to {other_type}')
@@ -313,6 +325,13 @@ def configure(self, attributes):
         else:
             vr_mode = 0

+        glb_addr_base = 0
+        glb_addr_stride = 0
+        if 'glb_addr_base' in attributes:
+            glb_addr_base = int(attributes['glb_addr_base'])
+        if 'glb_addr_stride' in attributes:
+            glb_addr_stride = int(attributes['glb_addr_stride'])
+
         cfg_kwargs = {
             'dense': dense,
             'dim_size': dim_size,
@@ -328,7 +347,9 @@ def configure(self, attributes):
             'block_mode': block_mode,
             'lookup': lookup,
             # 'spacc_mode': spacc_mode
-            'vr_mode': vr_mode
+            'vr_mode': vr_mode,
+            'glb_addr_base': glb_addr_base,
+            'glb_addr_stride': glb_addr_stride
         }

         return (inner_offset, max_outer_dim, strides, ranges, is_root, do_repeat,
diff --git a/sam/onyx/hw_nodes/stream_arbiter_node.py b/sam/onyx/hw_nodes/stream_arbiter_node.py
new file mode 100644
index 00000000..1ae6ed09
--- /dev/null
+++ b/sam/onyx/hw_nodes/stream_arbiter_node.py
@@ -0,0 +1,80 @@
+from sam.onyx.hw_nodes.hw_node import *
+
+
+class StreamArbiterNode(HWNode):
+    def __init__(self, name=None) -> None:
+        super().__init__(name=name)
+        self.max_num_inputs = 4
+        self.num_inputs_connected = 0
+        self.num_outputs = 1
+        self.num_outputs_connected = 0
+
+    def connect(self, other, edge, kwargs=None):
+
+        from sam.onyx.hw_nodes.broadcast_node import BroadcastNode
+        from sam.onyx.hw_nodes.compute_node import ComputeNode
+        from sam.onyx.hw_nodes.glb_node import GLBNode
+        from sam.onyx.hw_nodes.buffet_node import BuffetNode
+        from sam.onyx.hw_nodes.memory_node import MemoryNode
+        from sam.onyx.hw_nodes.read_scanner_node import ReadScannerNode
+        from sam.onyx.hw_nodes.write_scanner_node import WriteScannerNode
+        from sam.onyx.hw_nodes.intersect_node import IntersectNode
+        from sam.onyx.hw_nodes.reduce_node import ReduceNode
+        from sam.onyx.hw_nodes.lookup_node import LookupNode
+        from sam.onyx.hw_nodes.merge_node import MergeNode
+        from sam.onyx.hw_nodes.repeat_node import RepeatNode
+        from sam.onyx.hw_nodes.repsiggen_node import RepSigGenNode
+        from sam.onyx.hw_nodes.crdhold_node import CrdHoldNode
+        from sam.onyx.hw_nodes.fiberaccess_node import FiberAccessNode
+
+        new_conns = None
+        stream_arb = self.get_name()
+
+        other_type = type(other)
+
+        if other_type == GLBNode:
+            other_data = other.get_data()
+            other_ready = other.get_ready()
+            other_valid = other.get_valid()
+            new_conns = {
+                'stream_arbiter_to_glb': [
+                    ([(stream_arb, "stream_out"), (other_data, "f2io_17")], 17),
+                ]
+            }
+            return new_conns
+        elif other_type == StreamArbiterNode:
+            cur_inputs = other.get_num_inputs()
+            assert cur_inputs < self.max_num_inputs - 1, f"Cannot connect StreamArbiterNode to {other_type}, too many inputs"
+            down_stream_arb = other.get_name()
+            new_conns = {
+                f'stream_arbiter_to_stream_arbiter_{cur_inputs}': [
+                    ([(stream_arb, "stream_out"), (down_stream_arb, f"stream_in_{cur_inputs}")], 17),
+                ]
+            }
+            other.update_input_connections()
+            return new_conns
+        else:
+            raise NotImplementedError(f'Cannot connect StreamArbiterNode to {other_type}')
+
+        return new_conns
+
+    def update_input_connections(self):
+        self.num_inputs_connected += 1
+
+    def get_num_inputs(self):
+        return self.num_inputs_connected
+
+    def configure(self, attributes):
+        # print("STREAM ARBITER CONFIGURE")
+        # print(attributes)
+
+        seg_mode = attributes['seg_mode']
+        num_requests = self.num_inputs_connected
+        assert num_requests > 0, "StreamArbiterNode must have at least one input"
+        num_requests = num_requests - 1  # remap to the range of 0-3
+
+        cfg_kwargs = {
+            'num_requests': num_requests,
+            'seg_mode': seg_mode
+        }
+        return (num_requests, seg_mode), cfg_kwargs
diff --git a/sam/onyx/hw_nodes/write_scanner_node.py b/sam/onyx/hw_nodes/write_scanner_node.py
index 00cf11ff..6e5bc49f 100644
--- a/sam/onyx/hw_nodes/write_scanner_node.py
+++ b/sam/onyx/hw_nodes/write_scanner_node.py
@@ -106,6 +106,10 @@ def configure(self, attributes):
         else:
             vr_mode = 0

+        stream_id = 0
+        if 'stream_id' in attributes:
+            stream_id = int(attributes['stream_id'])
+
         cfg_tuple = (compressed, lowest_level, stop_lvl, block_mode, vr_mode, init_blank)
         cfg_kwargs = {
             'compressed': compressed,
@@ -113,6 +117,7 @@ def configure(self, attributes):
             'stop_lvl': stop_lvl,
             'block_mode': block_mode,
             'vr_mode': vr_mode,
-            'init_blank': init_blank
+            'init_blank': init_blank,
+            'stream_id': stream_id
         }
         return cfg_tuple, cfg_kwargs
diff --git a/sam/onyx/parse_dot.py b/sam/onyx/parse_dot.py
index ae5dd48f..f5562be2 100644
--- a/sam/onyx/parse_dot.py
+++ b/sam/onyx/parse_dot.py
@@ -18,7 +18,7 @@ def __init__(self, *args: object) -> None:

 class SAMDotGraph():

     def __init__(self, filename=None, local_mems=True, use_fork=False,
-                 use_fa=False, unroll=1, collat_dir=None, opal_workaround=False) -> None:
+                 use_fa=False, unroll=1, collat_dir=None, opal_workaround=False, mem_block_size=1000) -> None:
         assert filename is not None, "filename is None"
         self.graphs = pydot.graph_from_dot_file(filename)
         self.graph = self.graphs[0]
@@ -32,14 +32,14 @@ def __init__(self, filename=None, local_mems=True, use_fork=False,
         self.fa_color = 0
         self.collat_dir = collat_dir
         self.opal_workaround = opal_workaround
+        self.mem_block_size = mem_block_size
         self.alu_nodes = []
-        self.shared_writes = {}
+        self.shared_glb = {}
+        self.shared_stream_arb = {}
+        self.shared_stream_arb_glb_edge = []  # Key assumption: single-level stream arbiter

-        if unroll > 1:
-            self.duplicate_graph('B', unroll)
         self.annotate_IO_nodes()
-        # self.unroll_graph('b', 2)
         self.graph.write_png('mek.png')
         # exit()
         # print(self.graph)
@@ -48,8 +48,10 @@ def __init__(self, filename=None, local_mems=True, use_fork=False,
         self.rewrite_tri_to_binary()
         self.rewrite_VectorReducer()

+        self.duplicate_graph(unroll)  # duplicate the entire graph
+
         # Passes to lower to CGRA
-        self.rewrite_lookup()
+        self.rewrite_lookup(unroll)
         self.rewrite_arrays()
         # If using real fork, we don't rewrite the rsg broadcast in the same way
         if self.use_fork:
@@ -61,6 +63,11 @@ def __init__(self, filename=None, local_mems=True, use_fork=False,
         if len(self.alu_nodes) > 0:
             self.rewrite_complex_ops()

+        nodes = self.graph.get_nodes()
+        for node in nodes:
+            print(node.get_name())
+            print(node.get_attributes())
+
     def get_mode_map(self):
         sc = self.graph.get_comment().strip('"')
         for tensor in sc.split(","):
@@ -236,6 +243,10 @@ def map_nodes(self):
                 hw_nt = f"HWNodeType.CrdHold"
             elif n_type == "vectorreducer":
                 hw_nt = f"HWNodeType.VectorReducer"
+            elif n_type == "streamarbiter":
+                hw_nt = f"HWNodeType.StreamArbiter"
+            elif n_type == "pass_through":
+                hw_nt = f"HWNodeType.PassThrough"
             else:
                 # if the current node is not any of the primitives, it must be a compute
                 hw_nt = f"HWNodeType.Compute"
@@ -832,8 +843,8 @@ def rewrite_broadcast(self):
         # Now we have the broadcast node - want to find the incoming edge and redirect to the destinations
         for broadcast_node in nodes_to_proc:
             # broadcast_node = self.graph.get_node(broadcast_node)
-            attrs = node.get_attributes()
-            og_label = attrs['label']
+            # attrs = node.get_attributes()
+            # og_label = attrs['label']
             # del attrs['label']
             # Find the upstream broadcast node
             in_src = None
@@ -914,7 +925,7 @@ def rewrite_rsg_broadcast(self):
             self.graph.add_edge(og_to_rsg)
             self.graph.add_edge(rsg_to_branch)

-    def rewrite_lookup(self):
+    def rewrite_lookup(self, unroll):
         '''
         Rewrites the lookup nodes to become (wr_scan, rd_scan, buffet) triples
         '''
@@ -931,7 +942,14 @@ def rewrite_lookup(self):
             attrs = node.get_attributes()
             og_label = attrs['label']
+            og_label = og_label.split('_')
+            if len(og_label) > 1:
+                dup_id = int(og_label[-1])
+            else:
+                dup_id = 0
+            og_label = og_label[0]
             del attrs['label']
+            attrs['stream_id'] = dup_id
             rd_scan = pydot.Node(f"rd_scan_{self.get_next_seq()}", **attrs,
                                  label=f"{og_label}_rd_scan", hwnode=f"{HWNodeType.ReadScanner}",
@@ -953,16 +971,23 @@ def rewrite_lookup(self):
            # dense scanner is basically a counter that counts up to the dimension size
            # and does not rely on the GLB tile to supply any data
            glb_write = None
+           pass_through = None
            if not is_dense or not self.opal_workaround:
-               if f'{tensor}_{mode}_fiberlookup' in self.shared_writes and \
-                       self.shared_writes[f'{tensor}_{mode}_fiberlookup'][1] is not None:
-                   glb_write = self.shared_writes[f'{tensor}_{mode}_fiberlookup'][1]
+               if f'{tensor}_{mode}_fiberlookup' in self.shared_glb:
+                   (glb_write, pass_through) = self.shared_glb[f'{tensor}_{mode}_fiberlookup']
                else:
                    glb_write = pydot.Node(f"glb_write_{self.get_next_seq()}", **attrs,
                                           label=f"{og_label}_glb_write", hwnode=f"{HWNodeType.GLB}")
                    self.graph.add_node(glb_write)
-                   if f'{tensor}_{mode}_fiberlookup' in self.shared_writes:
-                       self.shared_writes[f'{tensor}_{mode}_fiberlookup'][1] = glb_write
+
+                   pass_through = pydot.Node(f"passthrough_{self.get_next_seq()}", **attrs,
+                                             label=f"{og_label}_passthrough", hwnode=f"{HWNodeType.PassThrough}")
+                   self.graph.add_node(pass_through)
+                   self.shared_glb[f'{tensor}_{mode}_fiberlookup'] = (glb_write, pass_through)
+
+                   glb_to_pass_through = pydot.Edge(src=glb_write, dst=pass_through,
+                                                    label=f"glb_to_pass_through_{self.get_next_seq()}", style="bold")
+                   self.graph.add_edge(glb_to_pass_through)
            if self.local_mems is False:
                memory = pydot.Node(f"memory_{self.get_next_seq()}", **attrs,
                                    label=f"{og_label}_SRAM", hwnode=f"{HWNodeType.Memory}")
@@ -977,9 +1002,12 @@ def rewrite_lookup(self):
            # Glb to WR
            # Dense scanner doesn't need data from the GLB, hence no connection to the GLB
            if not is_dense or not self.opal_workaround:
-               glb_to_wr = pydot.Edge(src=glb_write, dst=wr_scan, label=f"glb_to_wr_{self.get_next_seq()}",
-                                      style="bold")
-               self.graph.add_edge(glb_to_wr)
+               # glb_to_wr = pydot.Edge(src=glb_write, dst=wr_scan, label=f"glb_to_wr_{self.get_next_seq()}",
+               #                        style="bold")
+               # self.graph.add_edge(glb_to_wr)
+               pass_through_to_wr = pydot.Edge(src=pass_through, dst=wr_scan,
+                                               label=f"pass_through_to_wr_{self.get_next_seq()}", style="bold")
+               self.graph.add_edge(pass_through_to_wr)
            # write + read to buffet
            wr_to_buff = pydot.Edge(src=wr_scan, dst=buffet, label=f'wr_to_buff_{self.get_next_seq()}')
            self.graph.add_edge(wr_to_buff)
@@ -1029,6 +1057,16 @@ def rewrite_lookup(self):
             node.create_attribute_methods(attrs)
             og_label = attrs['label']
             del attrs['label']
+            og_label = og_label.split('_')
+            if len(og_label) > 1:
+                dup_id = int(og_label[-1])
+            else:
+                dup_id = 0
+            og_label = og_label[0]
+            attrs['glb_addr_base'] = dup_id * self.mem_block_size
+            attrs['glb_addr_stride'] = unroll * self.mem_block_size
+            attrs['stream_id'] = dup_id
+
             rd_scan = pydot.Node(f"rd_scan_{self.get_next_seq()}", **attrs,
                                  label=f"{og_label}_rd_scan", hwnode=f"{HWNodeType.ReadScanner}",
                                  fa_color=self.fa_color)
@@ -1041,8 +1079,17 @@ def rewrite_lookup(self):

             self.fa_color += 1

-            glb_read = pydot.Node(f"glb_read_{self.get_next_seq()}", **attrs,
-                                  label=f"{og_label}_glb_read", hwnode=f"{HWNodeType.GLB}")
+            tensor = attrs['tensor'].strip('"') + attrs['mode'].strip('"')
+            if f'{tensor}_fiber' in self.shared_glb:
+                glb_read = self.shared_glb[f'{tensor}_fiber']
+            else:
+                glb_read = pydot.Node(f"glb_read_{self.get_next_seq()}", **attrs,
+                                      label=f"{og_label}_glb_read", hwnode=f"{HWNodeType.GLB}")
+                self.shared_glb[f'{tensor}_fiber'] = glb_read
+                self.graph.add_node(glb_read)
+
+            # glb_read = pydot.Node(f"glb_read_{self.get_next_seq()}", **attrs,
+            #                       label=f"{og_label}_glb_read", hwnode=f"{HWNodeType.GLB}")
             if self.local_mems is False:
                 memory = pydot.Node(f"memory_{self.get_next_seq()}", **attrs,
                                     label=f"{og_label}_SRAM", hwnode=f"{HWNodeType.Memory}")
@@ -1054,18 +1101,49 @@ def rewrite_lookup(self):
             else:
                 in_edge = [edge for edge in self.graph.get_edges() if edge.get_destination() ==
                            node.get_name() and "crd" in edge.get_label()][0]
+            if unroll > 1:
+                # create shared stream arbiter
+                stream_arb_mode = attrs['mode'].strip('"')
+                stream_arb_label = f"stream_arb_{stream_arb_mode}"
+                if stream_arb_label in self.shared_stream_arb:
+                    stream_arb = self.shared_stream_arb[stream_arb_label]
+                else:
+                    stream_arb_attr = dict()
+                    if stream_arb_mode == 'vals':
+                        stream_arb_attr['seg_mode'] = 0
+                    else:
+                        stream_arb_attr['seg_mode'] = 1
+                    stream_arb = pydot.Node(f"stream_arb_{self.get_next_seq()}", **stream_arb_attr, label=stream_arb_label,
+                                            comment=f"type=stream_arbiter,mode={stream_arb_mode}", type="stream_arbiter",
+                                            hwnode=f"{HWNodeType.StreamArbiter}")
+                    self.shared_stream_arb[stream_arb_label] = stream_arb
+                    self.graph.add_node(stream_arb)
+
             # Now add the nodes and move the edges...
             self.graph.add_node(rd_scan)
             self.graph.add_node(wr_scan)
             self.graph.add_node(buffet)
-            self.graph.add_node(glb_read)
+            # self.graph.add_node(glb_read)
             if self.local_mems is False:
                 self.graph.add_node(memory)
-            # RD to GLB
-            rd_to_glb = pydot.Edge(src=rd_scan, dst=glb_read, label=f"glb_to_wr_{self.get_next_seq()}",
-                                   style="bold")
-            self.graph.add_edge(rd_to_glb)
+
+            if unroll > 1:
+                # RD to Stream Arb
+                rd_to_stream_arb = pydot.Edge(src=rd_scan, dst=stream_arb,
+                                              label=f"rd_to_stream_arb_{self.get_next_seq()}", style="bold")
+                self.graph.add_edge(rd_to_stream_arb)
+
+                if (stream_arb, glb_read) not in self.shared_stream_arb_glb_edge:
+                    # Stream Arb to GLB
+                    stream_arb_to_glb = pydot.Edge(src=stream_arb, dst=glb_read,
+                                                   label=f"stream_arb_to_glb_{self.get_next_seq()}", style="bold")
+                    self.graph.add_edge(stream_arb_to_glb)
+                    self.shared_stream_arb_glb_edge.append((stream_arb, glb_read))
+            else:
+                # RD to GLB
+                rd_to_glb = pydot.Edge(src=rd_scan, dst=glb_read, label=f"glb_to_wr_{self.get_next_seq()}",
+                                       style="bold")
+                self.graph.add_edge(rd_to_glb)
             # write + read to buffet
             wr_to_buff = pydot.Edge(src=wr_scan, dst=buffet, label=f'wr_to_buff_{self.get_next_seq()}')
             self.graph.add_edge(wr_to_buff)
@@ -1097,7 +1175,20 @@ def rewrite_arrays(self):
             # Rewrite this node to a read
             attrs = node.get_attributes()
             og_label = attrs['label']
+            og_label = og_label.split('_')
+
+            # TODO better solution for this?
+            if len(og_label) > 1 and og_label[-1] != 'lut':
+                print(attrs)
+                print(og_label)
+                # TODO better solution
+                dup_id = int(og_label[-1])
+            else:
+                dup_id = 0
+            og_label = og_label[0]
             del attrs['label']
+            attrs['stream_id'] = dup_id
+
             rd_scan = pydot.Node(f"rd_scan_{self.get_next_seq()}", **attrs,
                                  label=f"{og_label}_rd_scan", hwnode=f"{HWNodeType.ReadScanner}",
                                  fa_color=self.fa_color)
@@ -1112,14 +1203,20 @@ def rewrite_arrays(self):

             # Only instantiate the glb_write if it doesn't exist
             tensor = attrs['tensor'].strip('"')
-            if f'{tensor}_arrayvals' in self.shared_writes and self.shared_writes[f'{tensor}_arrayvals'][1] is not None:
-                glb_write = self.shared_writes[f'{tensor}_arrayvals'][1]
+            if f'{tensor}_arrayvals' in self.shared_glb:
+                (glb_write, pass_through) = self.shared_glb[f'{tensor}_arrayvals']
             else:
                 glb_write = pydot.Node(f"glb_write_{self.get_next_seq()}",
                                        **attrs, label=f"{og_label}_glb_write", hwnode=f"{HWNodeType.GLB}")
-                if f'{tensor}_arrayvals' in self.shared_writes:
-                    self.shared_writes[f'{tensor}_arrayvals'][1] = glb_write
                 self.graph.add_node(glb_write)
+                pass_through = pydot.Node(f"passthrough_{self.get_next_seq()}", **attrs,
+                                          label=f"{og_label}_passthrough", hwnode=f"{HWNodeType.PassThrough}")
+                self.graph.add_node(pass_through)
+                self.shared_glb[f'{tensor}_arrayvals'] = (glb_write, pass_through)
+
+                glb_to_pass_through = pydot.Edge(src=glb_write, dst=pass_through,
+                                                 label=f"glb_to_pass_through_{self.get_next_seq()}", style="bold")
+                self.graph.add_edge(glb_to_pass_through)

             # glb_write = pydot.Node(f"glb_write_{self.get_next_seq()}",
             #                        **attrs, label=f"{og_label}_glb_write", hwnode=f"{HWNodeType.GLB}")
@@ -1139,8 +1236,11 @@ def rewrite_arrays(self):
             if self.local_mems is False:
                 self.graph.add_node(memory)
             # Glb to WR
-            glb_to_wr = pydot.Edge(src=glb_write, dst=wr_scan, label=f"glb_to_wr_{self.get_next_seq()}", style="bold")
-            self.graph.add_edge(glb_to_wr)
+            # glb_to_wr = pydot.Edge(src=glb_write, dst=wr_scan, label=f"glb_to_wr_{self.get_next_seq()}", style="bold")
+            # self.graph.add_edge(glb_to_wr)
+            pass_through_to_wr = pydot.Edge(src=pass_through, dst=wr_scan,
+                                            label=f"pass_through_to_wr_{self.get_next_seq()}", style="bold")
+            self.graph.add_edge(pass_through_to_wr)
             # write + read to buffet
             wr_to_buff = pydot.Edge(src=wr_scan, dst=buffet, label=f'wr_to_buff_{self.get_next_seq()}')
             self.graph.add_edge(wr_to_buff)
@@ -1170,84 +1270,54 @@ def rewrite_arrays(self):
     def get_graph(self):
         return self.graph

-    def unroll_graph(self, tensor, unroll_factor):
+    def duplicate_graph(self, unroll_factor):
+        if unroll_factor == 1:
+            return
+        dupe_map = {}
+        orig_nodes_list = self.graph.get_nodes().copy()  # shallow copy is sufficient
+        node_count = len(orig_nodes_list)
         # Duplicate every node that isn't the tensor of interest
-        for node in self.graph.get_nodes():
+        for node in orig_nodes_list:
             node_attrs = node.get_attributes()
-            og_label = node_attrs['label'].strip('"')
-            node_type = node_attrs['type'].strip('"')
-            # del node_attrs['label']
-            attrs_copy = node_attrs.copy()
-            del attrs_copy['label']
-            if node_type == "fiberlookup" or node_type == "arrayvals":
-                node_tensor = node_attrs['tensor'].strip('"')
-                if node_tensor == tensor:
-                    continue
-            node_name = node.get_name().strip('"')
-            new_node = pydot.Node(f"{node_name}_dup", **attrs_copy, label=f"{og_label}_dup")
-            dupe_map[node_name] = new_node.get_name().strip('"')
-            self.graph.add_node(new_node)
-        # Duplicate every edge and map it to the duped versions
-        for edge in self.graph.get_edges():
-            src = edge.get_source()
-            dst = edge.get_destination()
-            if src not in dupe_map and dst not in dupe_map:
-                continue
-            rmp_src = src if src not in dupe_map else dupe_map[src]
-            rmp_dst = dst if dst not in dupe_map else dupe_map[dst]
-            new_edge = pydot.Edge(src=rmp_src, dst=rmp_dst, **edge.get_attributes())
-            self.graph.add_edge(new_edge)
-        print(self.graph)
+            if 'broadcast' in node_attrs['type']:
+                node_name = node.get_name().strip('"')
+                dupe_map[node_name] = []
+                for i in range(1, unroll_factor):
+                    new_node_name = node_count * i + int(node_name)
+                    new_node = pydot.Node(new_node_name, **node_attrs)
+                    dupe_map[node_name].append(new_node.get_name().strip('"'))
+                    self.graph.add_node(new_node)

-    def duplicate_graph(self, tensor, factor, output='x'):
-        original_nodes = self.graph.get_nodes()
-        # Do it over the whole graph multiple times
-        for fac_ in range(factor - 1):
-            dupe_map = {}
-            # Duplicate every node that isn't the tensor of interest
-            for node in original_nodes:
-                node_attrs = node.get_attributes()
-                print(node_attrs)
-                if 'label' not in node_attrs:
-                    node_attrs['label'] = 'bcast'
+            else:
                 og_label = node_attrs['label'].strip('"')
                 node_type = node_attrs['type'].strip('"')
+                node_attrs['label'] = f"{og_label}_0"
                 # del node_attrs['label']
                 attrs_copy = node_attrs.copy()
                 del attrs_copy['label']
                 node_name = node.get_name().strip('"')
-                new_node = pydot.Node(f"{node_name}_dup_{fac_}", **attrs_copy, label=f"{og_label}_dup_{fac_}")
-
-                if node_type == "fiberlookup" or node_type == "arrayvals":
-                    node_tensor = node_attrs['tensor'].strip('"')
-                    mode = None
-                    if node_type == "fiberlookup":
-                        mode = node_attrs['mode'].strip('"')
-                    if node_tensor == tensor:
-                        # continue
-                        # Mark this as a shared
-                        # mode_ = attrs_copy['mode'].strip('"')
-                        # self.shared_writes[f'{node_tensor}_{node_type}'] = [[node, new_node], None]
-                        name_str = f'{node_tensor}_{node_type}' if mode is None else f'{node_tensor}_{mode}_{node_type}'
-                        self.shared_writes[name_str] = [[node, new_node], None]
-                dupe_map[node_name] = new_node.get_name().strip('"')
-                self.graph.add_node(new_node)
-            # Duplicate every edge and map it to the duped versions
-            for edge in self.graph.get_edges():
-                src = edge.get_source()
-                dst = edge.get_destination()
-                if src not in dupe_map and dst not in dupe_map:
-                    continue
-                rmp_src = src if src not in dupe_map else dupe_map[src]
-                rmp_dst = dst if dst not in dupe_map else dupe_map[dst]
+                dupe_map[node_name] = []
+                for i in range(1, unroll_factor):
+                    new_node_name = node_count * i + int(node_name)
+                    new_node = pydot.Node(new_node_name, **attrs_copy, label=f"{og_label}_{i}")
+                    dupe_map[node_name].append(new_node.get_name().strip('"'))
+                    self.graph.add_node(new_node)
+        # Duplicate every edge and map it to the duped versions
+        orig_edge_list = self.graph.get_edges().copy()  # shallow copy is sufficient
+        for edge in orig_edge_list:
+            src = edge.get_source()
+            dst = edge.get_destination()
+            assert src in dupe_map and dst in dupe_map, f"src: {src} dst: {dst} failed to duplicate"
+            for i in range(1, unroll_factor):
+                rmp_src = dupe_map[src][i - 1]
+                rmp_dst = dupe_map[dst][i - 1]
                 new_edge = pydot.Edge(src=rmp_src, dst=rmp_dst, **edge.get_attributes())
                 self.graph.add_edge(new_edge)
         print(self.graph)
-        print(self.shared_writes)

     def annotate_IO_nodes(self):
         original_nodes = self.graph.get_nodes()
diff --git a/sam/sim/src/rd_scanner.py b/sam/sim/src/rd_scanner.py
index 5e51553a..66332a5b 100644
--- a/sam/sim/src/rd_scanner.py
+++ b/sam/sim/src/rd_scanner.py
@@ -144,6 +144,11 @@ def update(self):
                 self.curr_crd = stkn
                 self.curr_ref = stkn
                 return
+            elif is_0tkn(self.curr_in_ref):
+                self.curr_crd = ''
+                self.curr_ref = ''
+                self.emit_tkn = True
+                return
             else:
                 self.curr_crd = 0
                 self.curr_ref = self.curr_crd + (self.curr_in_ref * self.meta_dim)
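
For reference, a minimal sketch of the GLB addressing implied by rewrite_lookup above, assuming each duplicated read scanner steps through GLB in blocks of mem_block_size words, starting at glb_addr_base and advancing by glb_addr_stride per block; the helper name and the num_blocks argument are illustrative only and not part of the patch.

    # illustrative only: mirrors attrs['glb_addr_base'] / attrs['glb_addr_stride']
    # as set in rewrite_lookup (dup_id * mem_block_size and unroll * mem_block_size)
    def glb_block_offsets(dup_id, unroll, num_blocks, mem_block_size=1000):
        base = dup_id * mem_block_size       # glb_addr_base for this duplicated scanner
        stride = unroll * mem_block_size     # glb_addr_stride shared by all copies
        return [base + i * stride for i in range(num_blocks)]

    # With unroll=2, copy 0 would read offsets 0, 2000, 4000 and copy 1 offsets 1000, 3000, 5000:
    print(glb_block_offsets(0, 2, 3))  # [0, 2000, 4000]
    print(glb_block_offsets(1, 2, 3))  # [1000, 3000, 5000]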
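Similarly, a hedged usage sketch of the new unroll path through SAMDotGraph; the dot-file name is a placeholder, and the keyword values simply echo the defaults introduced in this patch.

    from sam.onyx.parse_dot import SAMDotGraph

    # Duplicate the whole dataflow graph unroll-fold. Duplicated read scanners for the same
    # tensor/mode funnel into a shared per-mode stream arbiter feeding one glb_read, while
    # duplicated write scanners are fed from one shared glb_write through a pass-through node.
    sdg = SAMDotGraph(filename="example.gv", use_fa=True, unroll=2, mem_block_size=1000)
    graph = sdg.get_graph()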