diff --git a/compiler/sam-outputs/onyx-dot/mat_elemadd_leakyrelu_exp.gv b/compiler/sam-outputs/onyx-dot/mat_elemadd_leakyrelu_exp.gv
index aaad0587..e360afc4 100644
--- a/compiler/sam-outputs/onyx-dot/mat_elemadd_leakyrelu_exp.gv
+++ b/compiler/sam-outputs/onyx-dot/mat_elemadd_leakyrelu_exp.gv
@@ -9,15 +9,9 @@ digraph SAM {
     4 [comment="type=arrayvals,tensor=B" label="Array Vals: B" color=green2 shape=box style=filled type="arrayvals" tensor="B"]
     3 [comment="type=fp_add" label="FP_Add" color=brown shape=box style=filled type="fp_add"]
     12 [comment="broadcast" shape=point style=invis type="broadcast"]
-    13 [comment="type=fp_mul,rb_const=0.2" label="FP_Mul * 0.2" color=brown shape=box style=filled type="fp_mul" rb_const="0.2"]
+    13 [comment="type=fp_mul,const0=0.2" label="FP_Mul * 0.2" color=brown shape=box style=filled type="fp_mul" const0="0.2"]
     14 [comment="type=fp_max" label="FP_Max" color=brown shape=box style=filled type="fp_max"]
-    15 [comment="type=fp_mul,rb_const=1.44269504089" label="FP_Mul * 1.44269504089" color=brown shape=box style=filled type="fp_mul" rb_const="1.44269504089"]
-    16 [comment="type=broadcast" shape=point style=invis type="broadcast"]
-    17 [comment="type=fgetfint" label="Fgetfint" color=brown shape=box style=filled type="fgetfint"]
-    18 [comment="type=fgetffrac" label="Fgetffrac" color=brown shape=box style=filled type="fgetffrac"]
-    19 [comment="type=and,rb_const=255" label="And 0x00FF" color=brown shape=box style=filled type="and" rb_const="255"]
-    20 [comment="type=faddiexp" label="Faddiexp" color=brown shape=box style=filled type="faddiexp"]
-    21 [comment="type=arrayvals,tensor=exp" label="Array Vals: exp" color=green2 shape=box style=filled type="arrayvals" tensor="exp"]
+    15 [comment="type=exp" label="Exp" color=brown shape=box style=filled type="exp"]
     0 [comment="type=fiberwrite,mode=vals,tensor=X,size=1*B0_dim*B1_dim,sink=true" label="FiberWrite Vals: X" color=green3 shape=box style=filled type="fiberwrite" tensor="X" mode="vals" size="1*B0_dim*B1_dim" sink="true"]
     5 [comment="type=arrayvals,tensor=C" label="Array Vals: C" color=green2 shape=box style=filled type="arrayvals" tensor="C"]
     8 [comment="type=fiberlookup,index=j,tensor=C,mode=1,format=compressed,src=true,root=false" label="FiberLookup j: C1\ncompressed" color=green4 shape=box style=filled type="fiberlookup" index="j" tensor="C" mode="1" format="compressed" src="true" root="false"]
@@ -34,14 +28,7 @@ digraph SAM {
     12 -> 14 [label="val" type="val"]
     13 -> 14 [label="val" type="val"]
     14 -> 15 [label="val" type="val"]
-    15 -> 16 [label="val" type="val"]
-    16 -> 17 [label="val" type="val"]
-    16 -> 18 [label="val" type="val"]
-    18 -> 19 [label="val" type="val"]
-    19 -> 21 [label="ref" style=bold type="ref"]
-    21 -> 20 [label="val" type="val" comment="fp"]
-    17 -> 20 [label="val" type="val" comment="exp"]
-    20 -> 0 [label="val" type="val"]
+    15 -> 0 [label="val" type="val"]
     6 -> 5 [label="ref_out-C" style=bold type="ref" comment="out-C"]
     5 -> 3 [label="val" type="val"]
     7 -> 6 [label="ref_in-B" style=bold type="ref" comment="in-B"]
diff --git a/compiler/sam-outputs/onyx-dot/mat_elemadd_relu.gv b/compiler/sam-outputs/onyx-dot/mat_elemadd_relu.gv
index 6e5a47ae..b0364766 100644
--- a/compiler/sam-outputs/onyx-dot/mat_elemadd_relu.gv
+++ b/compiler/sam-outputs/onyx-dot/mat_elemadd_relu.gv
@@ -9,7 +9,7 @@ digraph SAM {
     5 [comment="type=arrayvals,tensor=C" label="Array Vals: C" color=green2 shape=box style=filled type="arrayvals" tensor="C"]
    8 [comment="type=fiberlookup,index=j,tensor=C,mode=1,format=compressed,src=true,root=false" label="FiberLookup j: C1\ncompressed" color=green4 shape=box style=filled type="fiberlookup" index="j" tensor="C" mode="1" format="compressed" src="true" root="false"]
     11 [comment="type=fiberlookup,index=i,tensor=C,mode=0,format=compressed,src=true,root=true" label="FiberLookup i: C0\ncompressed" color=green4 shape=box style=filled type="fiberlookup" index="i" tensor="C" mode="0" format="compressed" src="true" root="true"]
-    12 [comment="type=max" label="Max 0" color=brown shape=box style=filled type="max"]
+    12 [comment="type=smax,const0=0" label="Max 0" color=brown shape=box style=filled type="smax" const0="0"]
     13 [comment="type=crddrop,outer=j,inner=val,mode=0" label="CrdDrop Compression j, val" color=orange style=filled type="crddrop" outer="j" inner="val" mode="0"]
     0 [comment="type=fiberwrite,mode=vals,tensor=X,size=1*B0_dim*B1_dim,sink=true" label="FiberWrite Vals: X" color=green3 shape=box style=filled type="fiberwrite" tensor="X" mode="vals" size="2*B0_dim*B1_dim" sink="true"]
     14 [comment="type=crddrop,outer=i,inner=j" label="CrdDrop i,j" color=orange shape=box style=filled type="crddrop" outer="i" inner="j"]
diff --git a/compiler/sam-outputs/onyx-dot/matmul_ijk_crddrop_relu.gv b/compiler/sam-outputs/onyx-dot/matmul_ijk_crddrop_relu.gv
index 107d4d2c..b4fb38c9 100644
--- a/compiler/sam-outputs/onyx-dot/matmul_ijk_crddrop_relu.gv
+++ b/compiler/sam-outputs/onyx-dot/matmul_ijk_crddrop_relu.gv
@@ -18,7 +18,7 @@ digraph SAM {
     0 [comment="type=fiberwrite,mode=vals,tensor=X,size=1*B0_dim*C1_dim,sink=true" label="FiberWrite Vals: X" color=green3 shape=box style=filled type="fiberwrite" tensor="X" mode="vals" size="1*B0_dim*C1_dim" sink="true"]
     6 [comment="type=arrayvals,tensor=C" label="Array Vals: C" color=green2 shape=box style=filled type="arrayvals" tensor="C"]
     9 [comment="type=fiberlookup,index=k,tensor=C,mode=0,format=compressed,src=true,root=false" label="FiberLookup k: C0\ncompressed" color=green4 shape=box style=filled type="fiberlookup" index="k" tensor="C" mode="0" format="compressed" src="true" root="false"]
-    20 [comment="type=max" label="Max 0" color=brown shape=box style=filled type="max"]
+    20 [comment="type=smax,const0=0" label="Max 0" color=brown shape=box style=filled type="smax" const0="0"]
     21 [comment="type=crddrop,outer=j,inner=val,mode=0" label="CrdDrop Compression j, val" color=orange style=filled type="crddrop" outer="j" inner="val" mode="0"]
     22 [comment="type=crddrop,outer=i,inner=j" label="CrdDrop i,j" color=orange shape=box style=filled type="crddrop" outer="i" inner="j"]
     17 -> 16 [label="crd" style=dashed type="crd" comment=""]
diff --git a/compiler/sam-outputs/onyx-dot/spmm_ijk_crddrop_relu.gv b/compiler/sam-outputs/onyx-dot/spmm_ijk_crddrop_relu.gv
index 7fdda848..f9414be6 100644
--- a/compiler/sam-outputs/onyx-dot/spmm_ijk_crddrop_relu.gv
+++ b/compiler/sam-outputs/onyx-dot/spmm_ijk_crddrop_relu.gv
@@ -18,7 +18,7 @@ digraph SAM {
     0 [comment="type=fiberwrite,mode=vals,tensor=X,size=1*B0_dim*C1_dim,sink=true" label="FiberWrite Vals: X" color=green3 shape=box style=filled type="fiberwrite" tensor="X" mode="vals" size="1*B0_dim*C1_dim" sink="true"]
     6 [comment="type=arrayvals,tensor=C" label="Array Vals: C" color=green2 shape=box style=filled type="arrayvals" tensor="C"]
    9 [comment="type=fiberlookup,index=k,tensor=C,mode=0,format=compressed,src=true,root=false" label="FiberLookup k: C0\ncompressed" color=green4 shape=box style=filled type="fiberlookup" index="k" tensor="C" mode="0" format="compressed" src="true" root="false"]
-    20 [comment="type=max" label="Max 0" color=brown shape=box style=filled type="max"]
+    20 [comment="type=smax,const0=0" label="Max 0" color=brown shape=box style=filled type="smax" const0="0"]
     21 [comment="type=crddrop,outer=j,inner=val,mode=0" label="CrdDrop Compression j, val" color=orange style=filled type="crddrop" outer="j" inner="val" mode="0"]
     22 [comment="type=crddrop,outer=i,inner=j" label="CrdDrop i,j" color=orange shape=box style=filled type="crddrop" outer="i" inner="j"]
     17 -> 16 [label="crd" style=dashed type="crd" comment=""]
diff --git a/compiler/sam-outputs/onyx-dot/spmv_relu.gv b/compiler/sam-outputs/onyx-dot/spmv_relu.gv
index 55d2bfe1..8bc06419 100644
--- a/compiler/sam-outputs/onyx-dot/spmv_relu.gv
+++ b/compiler/sam-outputs/onyx-dot/spmv_relu.gv
@@ -13,7 +13,7 @@ digraph SAM {
     9 [comment="type=fiberlookup,index=j,tensor=B,mode=1,format=compressed,src=true,root=false" label="FiberLookup j: B1\ncompressed" color=green4 shape=box style=filled type="fiberlookup" index="j" tensor="B" mode="1" format="compressed" src="true" root="false"]
-    20 [comment="type=max" label="Max 0" color=brown shape=box style=filled type="max"]
+    20 [comment="type=smax,const0=0" label="Max 0" color=brown shape=box style=filled type="smax" const0="0"]
     0 [comment="type=fiberwrite,mode=vals,tensor=x,size=1*B0_dim,sink=true" label="FiberWrite Vals: x" color=green3 shape=box style=filled type="fiberwrite" tensor="x" mode="vals" size="1*B0_dim" sink="true"]
     21 [comment="type=crddrop,outer=i,inner=val,mode=0" label="CrdDrop Compression i, val" color=orange style=filled type="crddrop" outer="i" inner="val" mode="0"]
     2 [comment="type=fiberwrite,index=i,tensor=x,mode=0,format=compressed,segsize=2,crdsize=B0_dim,sink=true" label="FiberWrite i: x0\ncompressed" color=green3 shape=box style=filled type="fiberwrite" index="i" tensor="x" mode="0" format="compressed" segsize="2" crdsize="B0_dim" sink="true"]
diff --git a/sam/onyx/hw_nodes/compute_node.py b/sam/onyx/hw_nodes/compute_node.py
index 9c6af1e3..0107ff83 100644
--- a/sam/onyx/hw_nodes/compute_node.py
+++ b/sam/onyx/hw_nodes/compute_node.py
@@ -1,15 +1,24 @@
 from sam.onyx.hw_nodes.hw_node import *
 from lassen.utils import float2bfbin
+import json
 
 
 class ComputeNode(HWNode):
-    def __init__(self, name=None, op=None) -> None:
+    def __init__(self, name=None, op=None, sam_graph_node_id=None,
+                 mapped_coreir_dir=None, is_mapped_from_complex_op=False, original_complex_op_id=None) -> None:
         super().__init__(name=name)
         self.num_inputs = 2
         self.num_outputs = 1
         self.num_inputs_connected = 0
         self.num_outputs_connected = 0
+        self.mapped_input_ports = []
         self.op = op
+        self.opcode = None
+        self.mapped_coreir_dir = mapped_coreir_dir
+        # parse the mapped coreir file to get the input ports and opcode
+        self.parse_mapped_json(self.mapped_coreir_dir + "/alu_coreir_spec_mapped.json",
+                               sam_graph_node_id, is_mapped_from_complex_op, original_complex_op_id)
+        assert self.opcode is not None
 
     def connect(self, other, edge, kwargs=None):
@@ -118,22 +127,23 @@ def connect(self, other, edge, kwargs=None):
             other_pe = other.get_name()
             other_conn = other.get_num_inputs()
             pe = self.get_name()
-            # TODO: remove hack eventually
-            if 'Max 0' in other.op:
-                other_conn = 1
-            elif 'Faddiexp' in other.op:
-                comment = edge.get_attributes()["comment"].strip('"')
-                if 'fp' in comment:
-                    other_conn = 0
-                elif 'exp' in comment:
-                    other_conn = 1
-                else:
-                    assert 0 & "edge connected to faddiexp has to have comment specified to either 'exp' or 'fp'"
-            new_conns = {
-                f'pe_to_pe_{other_conn}': [
-                    ([(pe, "res"), (other_pe, f"data{other_conn}")], 17),
-                ]
-            }
+            edge_attr = edge.get_attributes()
+            # a destination port name has been specified by metamapper
+            if "specified_port" in edge_attr and edge_attr["specified_port"] is not None:
+                other_conn = edge_attr["specified_port"]
+                other.mapped_input_ports.append(other_conn.strip("data"))
+                new_conns = {
+                    f'pe_to_pe_{other_conn}': [
+                        ([(pe, "res"), (other_pe, f"{other_conn}")], 17),
+                    ]
+                }
+            else:
+                other_conn = other.mapped_input_ports[other_conn]
+                new_conns = {
+                    f'pe_to_pe_{other_conn}': [
+                        ([(pe, "res"), (other_pe, f"data{other_conn}")], 17),
+                    ]
+                }
             other.update_input_connections()
             return new_conns
         elif other_type == BroadcastNode:
@@ -161,6 +171,46 @@ def update_input_connections(self):
     def get_num_inputs(self):
         return self.num_inputs_connected
 
+    def parse_mapped_json(self, filename, node_id, is_mapped_from_complex_op, original_complex_op_id):
+        with open(filename, 'r') as alu_mapped_file:
+            alu_mapped = json.load(alu_mapped_file)
+        # parse out the mapped opcode
+        alu_instance_name = None
+        module = None
+        if not is_mapped_from_complex_op:
+            module = alu_mapped["namespaces"]["global"]["modules"]["ALU_" + node_id + "_mapped"]
+            for instance_name, instance in module["instances"].items():
+                if "modref" in instance and instance["modref"] == "global.PE":
+                    alu_instance_name = instance_name
+                    break
+            for connection in alu_mapped["namespaces"]["global"]["modules"]["ALU_" + node_id + "_mapped"]["connections"]:
+                port0, port1 = connection
+                if "self.in" in port0:
+                    # get the port name of the alu
+                    self.mapped_input_ports.append(port1.split(".")[1].strip("data"))
+                elif "self.in" in port1:
+                    self.mapped_input_ports.append(port0.split(".")[1].strip("data"))
+            assert (len(self.mapped_input_ports) > 0)
+        else:
+            assert original_complex_op_id is not None
+            module = alu_mapped["namespaces"]["global"]["modules"]["ALU_" + original_complex_op_id + "_mapped"]
+            # the node name of an alu node remapped from a complex op is of the format <instance_name>_<seq_id>
+            alu_instance_name = '_'.join(node_id.split("_")[0:-1])
+            # no need to find the input and output ports for a remapped op,
+            # as they are already assigned when we remap the complex op and stored in the edge object
+        # look for the constant coreir object that supplies the opcode to the alu in question
+        # the instruction is supplied through the "inst" port of the alu
+        for connection in module["connections"]:
+            port0, port1 = connection
+            if f"{alu_instance_name}.inst" in port0:
+                constant_name = port1.split(".")[0]
+            elif f"{alu_instance_name}.inst" in port1:
+                constant_name = port0.split(".")[0]
+        opcode = module["instances"][constant_name]["modargs"]["value"][1]
+        opcode = "0x" + opcode.split('h')[1]
+
+        self.opcode = int(opcode, 0)
+
     def configure(self, attributes):
         print("PE CONFIGURE")
         print(attributes)
@@ -174,46 +224,18 @@ def configure(self, attributes):
         pe_only = True
         # data I/O should interface with other primitive outside of the cluster
         pe_in_external = 1
-        if c_op == 'mul':
-            op_code = 1
-        elif c_op == 'add' and 'sub=1' not in comment:
-            op_code = 0
-        elif c_op == 'add' and 'sub=1' in comment:
-            op_code = 2
-        elif c_op == 'max':
-            op_code = 4
-        elif c_op == 'and':
-            op_code = 5
-        elif c_op == 'fp_mul':
-            op_code = 6
-        elif c_op == 'fgetfint':
-            op_code = 7
-        elif c_op == 'fgetffrac':
-            op_code = 8
-        elif c_op == 'faddiexp':
-            op_code = 9
-        elif c_op == 'fp_max':
-            op_code = 10
-        elif c_op == 'fp_add':
-            op_code = 11
-
-        rb_const = None
-        if "rb_const" in attributes:
-            # the b operand of the op is a constant
-            rb_const = attributes["rb_const"].strip('"')
-            if "." in rb_const:
-                # constant is a floating point
-                rb_const = float(rb_const)
-                rb_const = int(float2bfbin(rb_const), 2)
-            else:
-                # it is a int
-                rb_const = int(rb_const)
+        # generate the input port config according to the mapped input ports
+        num_sparse_inputs = list("000")
+        for port in self.mapped_input_ports:
+            num_sparse_inputs[2 - int(port)] = '1'
+        print("".join(num_sparse_inputs))
+        num_sparse_inputs = int("".join(num_sparse_inputs), 2)
 
         cfg_kwargs = {
-            'op': op_code,
+            'op': self.opcode,
             'use_dense': use_dense,
             'pe_only': pe_only,
             'pe_in_external': pe_in_external,
-            'rb_const': rb_const
+            'num_sparse_inputs': num_sparse_inputs
         }
-        return (op_code, use_dense, pe_only, pe_in_external, rb_const), cfg_kwargs
+        return (self.opcode, use_dense, pe_only, pe_in_external, num_sparse_inputs), cfg_kwargs
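The port mask built in configure() above packs each mapped input port into a three-bit field, with data0 at the LSB. A minimal sketch of the same logic (illustrative only, assuming ports "0" and "1" were recovered from the mapped coreir JSON):

    mapped_input_ports = ["0", "1"]
    bits = list("000")
    for port in mapped_input_ports:
        bits[2 - int(port)] = '1'                  # port 0 sets the LSB, port 2 the MSB
    num_sparse_inputs = int("".join(bits), 2)      # "011" -> 3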
diff --git a/sam/onyx/hw_nodes/intersect_node.py b/sam/onyx/hw_nodes/intersect_node.py
index f84feeef..c95bb3f1 100644
--- a/sam/onyx/hw_nodes/intersect_node.py
+++ b/sam/onyx/hw_nodes/intersect_node.py
@@ -175,18 +175,14 @@ def connect(self, other, edge, kwargs=None):
             # Could be doing a sparse accum
             compute = other
             compute_name = other.get_name()
-            print("INTERSECT TO COMPUTE EDGE!")
-            print(edge)
-            print(edge.get_attributes())
             edge_comment = edge.get_attributes()['comment'].strip('"')
             tensor = edge_comment.split('-')[1]
-            print(self.tensor_to_conn)
             out_conn = self.tensor_to_conn[tensor]
             compute_conn = compute.get_num_inputs()
             new_conns = {
                 'intersect_to_repeat': [
                     # send output to rd scanner
-                    ([(isect, f"pos_out_{out_conn}"), (compute_name, f"data{compute_conn}")], 17),
+                    ([(isect, f"pos_out_{out_conn}"), (compute_name, f"data{compute.mapped_input_ports[compute_conn]}")], 17),
                 ]
             }
             compute.update_input_connections()
diff --git a/sam/onyx/hw_nodes/read_scanner_node.py b/sam/onyx/hw_nodes/read_scanner_node.py
index 02acf89b..3a952a35 100644
--- a/sam/onyx/hw_nodes/read_scanner_node.py
+++ b/sam/onyx/hw_nodes/read_scanner_node.py
@@ -194,32 +194,27 @@ def connect(self, other, edge, kwargs=None):
             return new_conns
         elif other_type == ComputeNode:
             compute = other.get_name()
-            # compute_conn = 0
-            print("CHECKING READ TENSOR - COMPUTE")
-            print(edge)
-            print(self.get_tensor())
-            # if self.get_tensor() == 'C' or self.get_tensor() == 'c':
-            #     compute_conn = 1
             # Can use dynamic information to assign inputs to compute nodes
             # since add/mul are commutative
             compute_conn = other.get_num_inputs()
-            # TODO: get rid of this hack
-            if 'Faddiexp' in other.op:
-                comment = edge.get_attributes()["comment"].strip('"')
-                if 'fp' in comment:
-                    compute_conn = 0
-                elif 'exp' in comment:
-                    compute_conn = 1
-                else:
-                    assert 0 & "edge connected to faddiexp has to have comment specified to either 'exp' or 'fp'"
-            new_conns = {
-                f'rd_scan_to_compute_{compute_conn}': [
-                    ([(rd_scan, "coord_out"), (compute, f"data{compute_conn}")], 17),
-                ]
-            }
-            # Now update the PE/compute to use the next connection next time
-            other.update_input_connections()
+            edge_attr = edge.get_attributes()
+            if "specified_port" in edge_attr and edge_attr["specified_port"] is not None:
+                compute_conn = edge_attr["specified_port"]
+                other.mapped_input_ports.append(compute_conn.strip("data"))
+                new_conns = {
+                    f'rd_scan_to_compute_{compute_conn}': [
+                        ([(rd_scan, "coord_out"), (compute, f"{compute_conn}")], 17),
+                    ]
+                }
+            else:
+                compute_conn = other.mapped_input_ports[compute_conn]
+                new_conns = {
+                    f'rd_scan_to_compute_{compute_conn}': [
+                        ([(rd_scan, "coord_out"), (compute, f"data{compute_conn}")], 17),
+                    ]
+                }
+            # Now update the PE/compute to use the next connection next time
+            other.update_input_connections()
             return new_conns
diff --git a/sam/onyx/hw_nodes/reduce_node.py b/sam/onyx/hw_nodes/reduce_node.py
index b19cf5e2..0700373e 100644
--- a/sam/onyx/hw_nodes/reduce_node.py
+++ b/sam/onyx/hw_nodes/reduce_node.py
@@ -1,4 +1,8 @@
 from sam.onyx.hw_nodes.hw_node import *
+from peak.assembler import Assembler
+from hwtypes.modifiers import strip_modifiers
+from lassen.sim import PE_fc as lassen_fc
+import lassen.asm as asm
 
 
 class ReduceNode(HWNode):
@@ -68,17 +72,10 @@ def connect(self, other, edge, kwargs=None):
             raise NotImplementedError(f'Cannot connect ReduceNode to {other_type}')
         elif other_type == ComputeNode:
             pe = other.get_name()
-            if 'Max 0' in other.op:
-                other_conn = 1
-            else:
-                other_conn = other.get_num_inputs()
+            other_conn = other.mapped_input_ports[other.get_num_inputs()]
             new_conns = {
                 f'reduce_to_pe_{other_conn}': [
-                    # send output to rd scanner
                     ([(red, "reduce_data_out"), (pe, f"data{other_conn}")], 17),
-                    # ([(red, "eos_out"), (wr_scan, "eos_in_0")], 1),
-                    # ([(wr_scan, "ready_out_0"), (red, "ready_in")], 1),
-                    # ([(red, "valid_out"), (wr_scan, "valid_in_0")], 1),
                 ]
             }
             other.update_input_connections()
@@ -112,7 +109,10 @@ def configure(self, attributes):
         # data I/O to and from the PE should be internal with the reduce
         pe_in_external = 0
         # op is set to integer add for the PE TODO: make this configurable in the sam graph
-        op = 0
+        # TODO: make this use the metamapper
+        instr_type = strip_modifiers(lassen_fc.Py.input_t.field_dict['inst'])
+        asm_ = Assembler(instr_type)
+        op = int(asm_.assemble(asm.add()))
         cfg_kwargs = {
             'stop_lvl': stop_lvl,
             'pe_connected_to_reduce': pe_connected_to_reduce,
diff --git a/sam/onyx/parse_dot.py b/sam/onyx/parse_dot.py
index fa18a557..7e40fa5f 100644
--- a/sam/onyx/parse_dot.py
+++ b/sam/onyx/parse_dot.py
@@ -1,7 +1,13 @@
 import argparse
 from numpy import broadcast
 import pydot
+import coreir
+import os
+import subprocess
+from hwtypes import BitVector
 from sam.onyx.hw_nodes.hw_node import HWNodeType
+from lassen.utils import float2bfbin
+import json
 
 
 class SAMDotGraphLoweringError(Exception):
@@ -12,7 +18,7 @@ def __init__(self, *args: object) -> None:
 class SAMDotGraph():
 
     def __init__(self, filename=None, local_mems=True, use_fork=False,
-                 use_fa=False, unroll=1) -> None:
+                 use_fa=False, unroll=1, collat_dir=None) -> None:
         assert filename is not None, "filename is None"
         self.graphs = pydot.graph_from_dot_file(filename)
         self.graph = self.graphs[0]
@@ -24,7 +30,9 @@ def __init__(self, filename=None, local_mems=True, use_fork=False,
         self.use_fork = use_fork
         self.use_fa = use_fa
         self.fa_color = 0
+        self.collat_dir = collat_dir
 
+        self.alu_nodes = []
         self.shared_writes = {}
 
         if unroll > 1:
@@ -48,6 +56,9 @@ def __init__(self, filename=None, local_mems=True, use_fork=False,
         else:
             self.rewrite_rsg_broadcast()
             self.map_nodes()
+        self.map_alu()
+        if len(self.alu_nodes) > 0:
+            self.rewrite_complex_ops()
 
     def get_mode_map(self):
         sc = self.graph.get_comment().strip('"')
@@ -78,11 +89,126 @@ def get_mode_map(self):
         # return self.mode_map
         return self.remaining
 
+    def generate_coreir_spec(self, context, attributes, name):
+        # Declare I/O of ALU
+        module_typ = context.Record({"in0": context.Array(1, context.Array(16, context.BitIn())),
+                                     "in1": context.Array(1, context.Array(16, context.BitIn())),
+                                     "out": context.Array(16, context.Bit())})
+        module = context.global_namespace.new_module("ALU_" + name, module_typ)
+        assert module.definition is None, "Should not have a definition"
+        module_def = module.new_definition()
+        alu_op = attributes['type'].strip('"')
+        # import the desired operation from coreir, commonlib, float_DW, or the float lib
+        # specify the width of the operations
+        # TODO: parameterize bit width
+        op = None
+        lib_name = None
+        if alu_op in context.get_namespace("coreir").generators:
+            coreir_op = context.get_namespace("coreir").generators[alu_op]
+            op = coreir_op(width=16)
+            lib_name = "coreir"
+        elif alu_op in context.get_namespace("commonlib").generators:
+            commonlib_op = context.get_namespace("commonlib").generators[alu_op]
+            op = commonlib_op(width=16)
+            lib_name = "commonlib"
+        elif alu_op in context.get_namespace("float_DW").generators:
+            float_DW_op = context.get_namespace("float_DW").generators[alu_op]
+            op = float_DW_op(exp_width=8, ieee_compliance=False, sig_width=7)
+            lib_name = "float_DW"
+        elif alu_op in context.get_namespace("float").generators:
+            float_op = context.get_namespace("float").generators[alu_op]
+            op = float_op(exp_bits=8, frac_bits=7)
+            lib_name = "float"
+        else:
+            # if the op is in none of the libs, it may be mapped using the custom rewrite rules
+            # in map_app.py
+            custom_rule_names = {
+                "mult_middle": "commonlib.mult_middle",
+                "abs": "commonlib.abs",
+                "fp_exp": "float.exp",
+                "fp_max": "float.max",
+                "fp_div": "float.div",
+                "fp_mux": "float.mux",
+                "fp_mul": "float_DW.fp_mul",
+                "fp_add": "float_DW.fp_add",
+                "fp_sub": "float.sub",
+            }
+            if alu_op in custom_rule_names:
+                custom_rule_lib = custom_rule_names[alu_op].split(".")[0]
+                custom_rule_op = custom_rule_names[alu_op].split(".")[1]
+                custom_op = context.get_namespace(custom_rule_lib).generators[custom_rule_op]
+                op = custom_op(exp_bits=8, frac_bits=7)
+                lib_name = custom_rule_lib
+            else:
+                raise NotImplementedError(f"failed to map node {alu_op} to compute")
+        # add the operation instance to the module
+        op_inst = module_def.add_module_instance(alu_op, op)
+        # instantiate the constant operand instances, if any
+        const_cnt = 0
+        const_inst = []
+        # TODO: does any floating point op use more than three inputs?
+        for i in range(2):
+            if f"const{i}" in attributes:
+                const_cnt += 1
+                coreir_const = context.get_namespace("coreir").generators["const"]
+                const = coreir_const(width=16)
+                # if the constant string contains a decimal point, it's a floating-point constant
+                if ("." in attributes[f"const{i}"].strip('"')):
+                    assert "fp" in alu_op, "only support floating point constant for fp ops"
+                    const_value = float(attributes[f"const{i}"].strip('"'))
+                    const_value = int(float2bfbin(const_value), 2)
+                else:
+                    const_value = int(attributes[f"const{i}"].strip('"'))
+                const_inst.append(module_def.add_module_instance(f"const{i}",
+                                                                 const,
+                                                                 context.new_values({"value": BitVector[16](const_value)})))
+
+        # connect the input to the op
+        # connect module input to the non-constant alu input ports
+        # note that for ops in commonlib, coreir, and float, the input ports are `in0`, `in1`, `in2`,
+        # and the output port is `out`.
+        # however, the inputs for ops in float_DW are `a`, `b`, `c`, and the output is `z`
+        float_DW_port_mapping = ['a', 'b', 'c']
+        for i in range(2 - const_cnt):
+            _input = module_def.interface.select(f"in{i}").select("0")
+            if lib_name != "float_DW":
+                try:
+                    _alu_in = op_inst.select(f"in{i}")
+                except Exception:
+                    print(f"Cannot select port 'in{i}', falling back to using port 'in'")
+                    # FIXME for now the only op that raises this exception is the single-input
+                    # op fp_exp
+                    _alu_in = op_inst.select("in")
+                    # connect the input and break to exit the loop since there are no more ports
+                    # to connect
+                    module_def.connect(_input, _alu_in)
+                    break
+            else:
+                _alu_in = op_inst.select(float_DW_port_mapping[i])
+            module_def.connect(_input, _alu_in)
+        # connect constant output to alu input ports
+        if const_cnt > 0:
+            for i in range(const_cnt, 2):
+                _const_out = const_inst[i - const_cnt].select("out")
+                if lib_name != "float_DW":
+                    _alu_in = op_inst.select(f"in{i}")
+                else:
+                    _alu_in = op_inst.select(float_DW_port_mapping[i])
+                module_def.connect(_const_out, _alu_in)
+        # connect alu output to module output
+        _output = module_def.interface.select("out")
+        if lib_name != "float_DW":
+            _alu_out = op_inst.select("out")
+        else:
+            _alu_out = op_inst.select("z")
+        module_def.connect(_output, _alu_out)
+        module.definition = module_def
+        assert module.definition is not None, "Should have a definition by now"
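The constant handling above relies on float2bfbin to encode floating-point literals into the 16-bit bfloat16 pattern that drives the coreir const. A small illustration (assumes lassen.utils.float2bfbin returns a 16-character bit string and rounds to nearest):

    from lassen.utils import float2bfbin

    bits = float2bfbin(0.2)         # e.g. "0011111001001101", i.e. 0x3e4d
    const_value = int(bits, 2)      # 15949, the value handed to BitVector[16]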
+
     def map_nodes(self):
         '''
         Iterate through the nodes and map them to the proper HWNodes
         '''
-
         for node in self.graph.get_nodes():
             # Simple write the HWNodeType attribute
             if 'hwnode' not in node.get_attributes():
@@ -99,12 +225,6 @@ def map_nodes(self):
                     hw_nt = f"HWNodeType.RepSigGen"
                 elif n_type == "repeat":
                     hw_nt = f"HWNodeType.Repeat"
-                elif n_type == "mul" or n_type == "add" or n_type == "max" or n_type == "and":
-                    hw_nt = f"HWNodeType.Compute"
-                elif n_type == "fgetfint" or n_type == "fgetffrac" or n_type == "faddiexp":
-                    hw_nt = f"HWNodeType.Compute"
-                elif n_type == "fp_mul" or n_type == "fp_max" or n_type == "fp_add":
-                    hw_nt = f"HWNodeType.Compute"
                 elif n_type == "reduce":
                     hw_nt = f"HWNodeType.Reduce"
                 elif n_type == "intersect" or n_type == "union":
@@ -114,13 +234,39 @@ def map_nodes(self):
                 elif n_type == "crdhold":
                     hw_nt = f"HWNodeType.CrdHold"
                 elif n_type == "vectorreducer":
-                    hw_nt = f"HWNodeType.VectorReducer "
+                    hw_nt = f"HWNodeType.VectorReducer"
                 else:
-                    print(n_type)
-                    raise SAMDotGraphLoweringError(f"Node is of type {n_type}")
-
+                    # if the current node is not any of the primitives, it must be a compute
+                    hw_nt = f"HWNodeType.Compute"
+                    self.alu_nodes.append(node)
                 node.get_attributes()['hwnode'] = hw_nt
 
+    def map_alu(self):
+        if len(self.alu_nodes) > 0:
+            # coreir lib is loaded by default, need to load commonlib for smax
+            # and float_DW for floating point ops
+            c = coreir.Context()
+            c.load_library("commonlib")
+            c.load_library("float_DW")
+            # iterate through all compute nodes and generate their coreir spec
+            for alu_node in self.alu_nodes:
+                self.generate_coreir_spec(c,
+                                          alu_node.get_attributes(),
+                                          alu_node.get_name())
+            c.save_to_file(self.collat_dir + "/alu_coreir_spec.json")
+
+            # use metamapper to map it
+            # set environment variable PIPELINED to zero to disable input buffering in the alu
+            # in order to make sure the output comes out within the same cycle the input is given
+            metamapper_env = os.environ.copy()
+            metamapper_env["PIPELINED"] = "0"
+            # FIXME: disable for now until verification of floating point computation is fixed
+            metamapper_env["PROVE"] = "0"
+            # no need to perform branch delay matching because we have rv interface in sparse
+            metamapper_env["MATCH_BRANCH_DELAY"] = "0"
+            subprocess.run(["python", "/aha/MetaMapper/scripts/map_app.py", self.collat_dir + "/alu_coreir_spec.json"],
+                           env=metamapper_env)
+
     def get_next_seq(self):
         ret = self.seq
         self.seq += 1
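For orientation, both parse_mapped_json (compute_node.py above) and rewrite_complex_ops (next hunk) walk this mapped JSON. A hypothetical fragment showing only the fields the code actually reads; every instance name and the opcode literal here are made up for illustration:

    alu_mapped = {"namespaces": {"global": {"modules": {"ALU_13_mapped": {
        "instances": {
            "PE_inst0": {"modref": "global.PE"},
            # the const that feeds the PE's "inst" port; only value[1] is read,
            # and it is expected to look like "<width>'h<hex digits>"
            "const_0": {"modargs": {"value": ["BitVector", "84'h0000000000000000000c"]}},
        },
        "connections": [
            ["self.in0.0", "PE_inst0.data0"],
            ["const_0.out", "PE_inst0.inst"],
        ],
    }}}}}

    # opcode extraction then mirrors parse_mapped_json:
    opcode = int("0x" + "84'h0000000000000000000c".split('h')[1], 0)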
@@ -132,6 +278,131 @@ def find_node_by_name(self, name):
                 return node
         assert False
 
+    def rewrite_complex_ops(self):
+        # parse the mapped json file, instantiate the compute/memory nodes generated by
+        # metamapper when breaking down the complex op node
+        with open(self.collat_dir + "/alu_coreir_spec_mapped.json", 'r') as alu_mapped_file:
+            alu_mapped = json.load(alu_mapped_file)
+        # iterate through all the modules
+        modules_dict = alu_mapped["namespaces"]["global"]["modules"]
+        for node in self.alu_nodes:
+            module = modules_dict[f"ALU_{node.get_name()}_mapped"]
+            if len(module["instances"]) <= 3:
+                # for a non-complex op, the module contains exactly three instances:
+                # 1. the pe module itself
+                # 2. the coreir.const that supplies the op code
+                # 3. the coreir.const that supplies the clk_en signal
+                continue
+            complex_node_op = node.get_type().strip('"')
+            complex_node_label = node.get_label().strip('"')
+            incoming_edges = [edge for edge in self.graph.get_edges() if edge.get_destination() == node.get_name()]
+            outgoing_edges = [edge for edge in self.graph.get_edges() if edge.get_source() == node.get_name()]
+            # more than three instances means this is a complex op
+            instances_dict = module["instances"]
+            instance_name_node_mapping = {}
+            for instance_name in instances_dict:
+                instance = instances_dict[instance_name]
+                # stamp out PEs and ROMs only, not the constants
+                if "modref" in instance and instance["modref"] == "global.PE":
+                    # skip the bit constant PE that supplies ren data to the rom,
+                    # as the rom in the sparse flow will use fiber access
+                    if instance_name.split("_")[0] == "bit" and instance_name.split("_")[1] == "const":
+                        continue
+                    # the last two tokens of the instance name are the instance id; we only want the op
+                    new_alu_node_op = '_'.join(instance_name.split("_")[0:-2])
+                    new_alu_node = pydot.Node(f"{instance_name}_{self.get_next_seq()}",
+                                              label=f"{complex_node_label}_{new_alu_node_op}",
+                                              hwnode=f"{HWNodeType.Compute}",
+                                              original_complex_op_id=node.get_name(),
+                                              is_mapped_from_complex_op=True,
+                                              type=f"{new_alu_node_op}", comment=f"type={new_alu_node_op}")
+                    new_alu_node.create_attribute_methods(new_alu_node.get_attributes())
+                    self.graph.add_node(new_alu_node)
+                    instance_name_node_mapping[instance_name] = new_alu_node
+                # create rom node using arrayvals
+                elif "genref" in instance and instance["genref"] == "memory.rom2":
+                    attrs = {}
+                    attrs["tensor"] = complex_node_op
+                    rom_arrayvals_node = pydot.Node(f"{complex_node_op}_lut_{self.get_next_seq()}",
+                                                    label=f"{complex_node_label}_lut", tensor=f"{complex_node_op}",
+                                                    type="arrayvals", comment=f"type=arrayvals,tensor={complex_node_op}")
+                    rom_arrayvals_node.create_attribute_methods(rom_arrayvals_node.get_attributes())
+                    self.graph.add_node(rom_arrayvals_node)
+                    instance_name_node_mapping[instance_name] = rom_arrayvals_node
+            # connect the nodes
+            for connection in module["connections"]:
+                for i in range(2):
+                    # a connection endpoint with 'datax', 'raddr', or 'self.out' is a connection to
+                    # a PE, an arrayvals, or an original output of the complex op
+                    if 'data0' in connection[i] or 'data1' in connection[i] or 'data2' in connection[i] \
+                            or 'raddr' in connection[i] or 'self.out' in connection[i]:
+                        edge_attr = {}
+                        specified_port = None
+                        edge_type = None
+                        # internal connection within the complex op
+                        if 'self.out' not in connection[i]:
+                            dest_node_name = connection[i].split(".")[0]
+                            dest_node = instance_name_node_mapping[dest_node_name]
+                            # for an internal connection we need to specify the edge type.
+                            # for a connection to arrayvals, the connection logic in hwnodes will take
+                            # care of the port name, no need to specify a port name here
+                            if dest_node.get_type() == "arrayvals":
+                                edge_type = "ref"
+                                specified_port = None
+                            else:
+                                edge_type = "val"
+                                specified_port = connection[i].split(".")[1]
+                        # FIXME: only support a single-output complex op for now
+                        # if it is an existing outgoing edge, inherit the edge properties
+                        else:
+                            dest_node = outgoing_edges[0].get_destination()
+                            edge_attr = outgoing_edges[0].get_attributes()
+                            self.graph.del_edge(outgoing_edges[0].get_source(), outgoing_edges[0].get_destination())
+
+                        # select the other port as src
+                        if i == 0:
+                            src_port = connection[1]
+                        else:
+                            src_port = connection[0]
+                        src_node_name = src_port.split(".")[0]
+                        # if the src port is an input of the complex op,
+                        # inherit the properties of the original incoming edge
+                        if "self.in" in src_port:
+                            # FIXME: only support a single-input complex op for now
+                            src_node = incoming_edges[0].get_source()
+                            # an edge cannot be both incoming to and outgoing from the complex op simultaneously
+                            assert not edge_attr
+                            edge_attr = incoming_edges[0].get_attributes()
+                            # connecting to a new node, use the port specified by metamapper
+                            self.graph.del_edge(incoming_edges[0].get_source(), incoming_edges[0].get_destination())
+                        # the source node is not a PE we just stamped out, skip the connection
+                        elif src_node_name not in instance_name_node_mapping:
+                            break
+                        else:
+                            src_node = instance_name_node_mapping[src_node_name]
+                        # a new edge
+                        if not edge_attr:
+                            new_edge = pydot.Edge(src=src_node,
+                                                  dst=dest_node,
+                                                  type=edge_type,
+                                                  label=edge_type,
+                                                  specified_port=specified_port,
+                                                  comment=edge_type)
+                        # existing edge, inherit its attributes
+                        else:
+                            new_edge = pydot.Edge(src=src_node,
+                                                  dst=dest_node,
+                                                  **edge_attr,
+                                                  specified_port=specified_port)
+
+                        self.graph.add_edge(new_edge)
+                        # done adding the edge for the connection, don't need to check the other port
+                        break
+            # finally remove the original complex op node
+            self.graph.del_node(node)
+        # turn the lut arrayvals into FA
+        self.rewrite_arrays()
+
     def rewrite_VectorReducer(self):
 
         # Get the vr node and the resulting fiberwrites
@@ -211,6 +482,7 @@ def rewrite_VectorReducer(self):
         add = pydot.Node(f"vr_add_{self.get_next_seq()}",
                          label=f"{og_label}_Add", hwnode=f"{HWNodeType.Compute}",
                          type="add", sub="0", comment="type=add,sub=0")
+        self.alu_nodes.append(add)
 
         crd_buffet = pydot.Node(f"vr_crd_buffet_{self.get_next_seq()}",
                                 label=f"{og_label}_crd_buffet", hwnode=f"{HWNodeType.Buffet}",
@@ -857,7 +1129,11 @@ def rewrite_arrays(self):
         '''
         Rewrites the array nodes to become (lookup, buffet) triples
         '''
-        nodes_to_proc = [node for node in self.graph.get_nodes() if 'arrayvals' in node.get_comment()]
+        # nodes_to_proc = [node for node in self.graph.get_nodes() if 'arrayvals' in node.get_comment()]
+        nodes_to_proc = []
+        for node in self.graph.get_nodes():
+            if 'arrayvals' in node.get_comment() and 'hwnode' not in node.get_attributes():
+                nodes_to_proc.append(node)
         for node in nodes_to_proc:
             # Now we have arrayvals, let's turn it into same stuff
             # Rewrite this node to a read
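Downstream, everything hangs off the updated SAMDotGraph constructor. A hedged usage sketch (the .gv path is one of the graphs touched above; the collateral directory is hypothetical):

    from sam.onyx.parse_dot import SAMDotGraph

    sdg = SAMDotGraph(filename="compiler/sam-outputs/onyx-dot/spmv_relu.gv",
                      local_mems=True, use_fork=False, use_fa=False, unroll=1,
                      collat_dir="/tmp/collat")
    # __init__ runs map_nodes(), then map_alu() (which writes collat_dir/alu_coreir_spec.json
    # and shells out to MetaMapper's map_app.py), and finally rewrite_complex_ops()
    # whenever any node was lowered to HWNodeType.Compute.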