Skip to content

Commit

Permalink
Merge pull request #117 from weiya711/sparse_metamapper
Browse files Browse the repository at this point in the history
Sparse metamapper
  • Loading branch information
kalhankoul96 authored Feb 12, 2024
2 parents ef1598f + 3be2961 commit ce5ba84
Show file tree
Hide file tree
Showing 10 changed files with 400 additions and 124 deletions.
19 changes: 3 additions & 16 deletions compiler/sam-outputs/onyx-dot/mat_elemadd_leakyrelu_exp.gv
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,9 @@ digraph SAM {
4 [comment="type=arrayvals,tensor=B" label="Array Vals: B" color=green2 shape=box style=filled type="arrayvals" tensor="B"]
3 [comment="type=fp_add" label="FP_Add" color=brown shape=box style=filled type="fp_add"]
12 [comment="broadcast" shape=point style=invis type="broadcast"]
13 [comment="type=fp_mul,rb_const=0.2" label="FP_Mul * 0.2" color=brown shape=box style=filled type="fp_mul" rb_const="0.2"]
13 [comment="type=fp_mul,const0=0.2" label="FP_Mul * 0.2" color=brown shape=box style=filled type="fp_mul" const0="0.2"]
14 [comment="type=fp_max" label="FP_Max" color=brown shape=box style=filled type="fp_max"]
15 [comment="type=fp_mul,rb_const=1.44269504089" label="FP_Mul * 1.44269504089" color=brown shape=box style=filled type="fp_mul" rb_const="1.44269504089"]
16 [comment="type=broadcast" shape=point style=invis type="broadcast"]
17 [comment="type=fgetfint" label="Fgetfint" color=brown shape=box style=filled type="fgetfint"]
18 [comment="type=fgetffrac" label="Fgetffrac" color=brown shape=box style=filled type="fgetffrac"]
19 [comment="type=and,rb_const=255" label="And 0x00FF" color=brown shape=box style=filled type="and" rb_const="255"]
20 [comment="type=faddiexp" label="Faddiexp" color=brown shape=box style=filled type="faddiexp"]
21 [comment="type=arrayvals,tensor=exp" label="Array Vals: exp" color=green2 shape=box style=filled type="arrayvals" tensor="exp"]
15 [comment="type=exp" label="Exp" color=brown shape=box style=filled type="exp"]
0 [comment="type=fiberwrite,mode=vals,tensor=X,size=1*B0_dim*B1_dim,sink=true" label="FiberWrite Vals: X" color=green3 shape=box style=filled type="fiberwrite" tensor="X" mode="vals" size="1*B0_dim*B1_dim" sink="true"]
5 [comment="type=arrayvals,tensor=C" label="Array Vals: C" color=green2 shape=box style=filled type="arrayvals" tensor="C"]
8 [comment="type=fiberlookup,index=j,tensor=C,mode=1,format=compressed,src=true,root=false" label="FiberLookup j: C1\ncompressed" color=green4 shape=box style=filled type="fiberlookup" index="j" tensor="C" mode="1" format="compressed" src="true" root="false"]
Expand All @@ -34,14 +28,7 @@ digraph SAM {
12 -> 14 [label="val" type="val"]
13 -> 14 [label="val" type="val"]
14 -> 15 [label="val" type="val"]
15 -> 16 [label="val" type="val"]
16 -> 17 [label="val" type="val"]
16 -> 18 [label="val" type="val"]
18 -> 19 [label="val" type="val"]
19 -> 21 [label="ref" style=bold type="ref"]
21 -> 20 [label="val" type="val" comment="fp"]
17 -> 20 [label="val" type="val" comment="exp"]
20 -> 0 [label="val" type="val"]
15 -> 0 [label="val" type="val"]
6 -> 5 [label="ref_out-C" style=bold type="ref" comment="out-C"]
5 -> 3 [label="val" type="val"]
7 -> 6 [label="ref_in-B" style=bold type="ref" comment="in-B"]
Expand Down
2 changes: 1 addition & 1 deletion compiler/sam-outputs/onyx-dot/mat_elemadd_relu.gv
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ digraph SAM {
5 [comment="type=arrayvals,tensor=C" label="Array Vals: C" color=green2 shape=box style=filled type="arrayvals" tensor="C"]
8 [comment="type=fiberlookup,index=j,tensor=C,mode=1,format=compressed,src=true,root=false" label="FiberLookup j: C1\ncompressed" color=green4 shape=box style=filled type="fiberlookup" index="j" tensor="C" mode="1" format="compressed" src="true" root="false"]
11 [comment="type=fiberlookup,index=i,tensor=C,mode=0,format=compressed,src=true,root=true" label="FiberLookup i: C0\ncompressed" color=green4 shape=box style=filled type="fiberlookup" index="i" tensor="C" mode="0" format="compressed" src="true" root="true"]
12 [comment="type=max" label="Max 0" color=brown shape=box style=filled type="max"]
12 [comment="type=smax,const0=0" label="Max 0" color=brown shape=box style=filled type="smax", const0="0"]
13 [comment="type=crddrop,outer=j,inner=val,mode=0" label="CrdDrop Compression j, val" color=orange style=filled type="crddrop" outer="j" inner="val" mode="0"]
0 [comment="type=fiberwrite,mode=vals,tensor=X,size=1*B0_dim*B1_dim,sink=true" label="FiberWrite Vals: X" color=green3 shape=box style=filled type="fiberwrite" tensor="X" mode="vals" size="2*B0_dim*B1_dim" sink="true"]
14 [comment="type=crddrop,outer=i,inner=j" label="CrdDrop i,j" color=orange shape=box style=filled type="crddrop" outer="i" inner="j"]
Expand Down
2 changes: 1 addition & 1 deletion compiler/sam-outputs/onyx-dot/matmul_ijk_crddrop_relu.gv
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ digraph SAM {
0 [comment="type=fiberwrite,mode=vals,tensor=X,size=1*B0_dim*C1_dim,sink=true" label="FiberWrite Vals: X" color=green3 shape=box style=filled type="fiberwrite" tensor="X" mode="vals" size="1*B0_dim*C1_dim" sink="true"]
6 [comment="type=arrayvals,tensor=C" label="Array Vals: C" color=green2 shape=box style=filled type="arrayvals" tensor="C"]
9 [comment="type=fiberlookup,index=k,tensor=C,mode=0,format=compressed,src=true,root=false" label="FiberLookup k: C0\ncompressed" color=green4 shape=box style=filled type="fiberlookup" index="k" tensor="C" mode="0" format="compressed" src="true" root="false"]
20 [comment="type=max" label="Max 0" color=brown shape=box style=filled type="max"]
20 [comment="type=smax,const0=0" label="Max 0" color=brown shape=box style=filled type="smax", const0="0"]
21 [comment="type=crddrop,outer=j,inner=val,mode=0" label="CrdDrop Compression j, val" color=orange style=filled type="crddrop" outer="j" inner="val" mode="0"]
22 [comment="type=crddrop,outer=i,inner=j" label="CrdDrop i,j" color=orange shape=box style=filled type="crddrop" outer="i" inner="j"]
17 -> 16 [label="crd" style=dashed type="crd" comment=""]
Expand Down
2 changes: 1 addition & 1 deletion compiler/sam-outputs/onyx-dot/spmm_ijk_crddrop_relu.gv
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ digraph SAM {
0 [comment="type=fiberwrite,mode=vals,tensor=X,size=1*B0_dim*C1_dim,sink=true" label="FiberWrite Vals: X" color=green3 shape=box style=filled type="fiberwrite" tensor="X" mode="vals" size="1*B0_dim*C1_dim" sink="true"]
6 [comment="type=arrayvals,tensor=C" label="Array Vals: C" color=green2 shape=box style=filled type="arrayvals" tensor="C"]
9 [comment="type=fiberlookup,index=k,tensor=C,mode=0,format=compressed,src=true,root=false" label="FiberLookup k: C0\ncompressed" color=green4 shape=box style=filled type="fiberlookup" index="k" tensor="C" mode="0" format="compressed" src="true" root="false"]
20 [comment="type=max" label="Max 0" color=brown shape=box style=filled type="max"]
20 [comment="type=smax,const0=0" label="Max 0" color=brown shape=box style=filled type="smax", const0="0"]
21 [comment="type=crddrop,outer=j,inner=val,mode=0" label="CrdDrop Compression j, val" color=orange style=filled type="crddrop" outer="j" inner="val" mode="0"]
22 [comment="type=crddrop,outer=i,inner=j" label="CrdDrop i,j" color=orange shape=box style=filled type="crddrop" outer="i" inner="j"]
17 -> 16 [label="crd" style=dashed type="crd" comment=""]
Expand Down
2 changes: 1 addition & 1 deletion compiler/sam-outputs/onyx-dot/spmv_relu.gv
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ digraph SAM {
9 [comment="type=fiberlookup,index=j,tensor=B,mode=1,format=compressed,src=true,root=false" label="FiberLookup j: B1\ncompressed" color=green4 shape=box style=filled type="fiberlookup" index="j" tensor="B" mode="1" format="compressed" src="true" root="false"]


20 [comment="type=max" label="Max 0" color=brown shape=box style=filled type="max"]
20 [comment="type=smax,const0=0" label="Max 0" color=brown shape=box style=filled type="smax", const0="0"]
0 [comment="type=fiberwrite,mode=vals,tensor=x,size=1*B0_dim,sink=true" label="FiberWrite Vals: x" color=green3 shape=box style=filled type="fiberwrite" tensor="x" mode="vals" size="1*B0_dim" sink="true"]
21 [comment="type=crddrop,outer=i,inner=val,mode=0" label="CrdDrop Compression i, val" color=orange style=filled type="crddrop" outer="i" inner="val" mode="0"]
2 [comment="type=fiberwrite,index=i,tensor=x,mode=0,format=compressed,segsize=2,crdsize=B0_dim,sink=true" label="FiberWrite i: x0\ncompressed" color=green3 shape=box style=filled type="fiberwrite" index="i" tensor="x" mode="0" format="compressed" segsize="2" crdsize="B0_dim" sink="true"]
Expand Down
130 changes: 76 additions & 54 deletions sam/onyx/hw_nodes/compute_node.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,24 @@
from sam.onyx.hw_nodes.hw_node import *
from lassen.utils import float2bfbin
import json


class ComputeNode(HWNode):
def __init__(self, name=None, op=None) -> None:
def __init__(self, name=None, op=None, sam_graph_node_id=None,
mapped_coreir_dir=None, is_mapped_from_complex_op=False, original_complex_op_id=None) -> None:
super().__init__(name=name)
self.num_inputs = 2
self.num_outputs = 1
self.num_inputs_connected = 0
self.num_outputs_connected = 0
self.mapped_input_ports = []
self.op = op
self.opcode = None
self.mapped_coreir_dir = mapped_coreir_dir
# parse the mapped coreir file to get the input ports and opcode
self.parse_mapped_json(self.mapped_coreir_dir + "/alu_coreir_spec_mapped.json",
sam_graph_node_id, is_mapped_from_complex_op, original_complex_op_id)
assert self.opcode is not None

def connect(self, other, edge, kwargs=None):

Expand Down Expand Up @@ -118,22 +127,23 @@ def connect(self, other, edge, kwargs=None):
other_pe = other.get_name()
other_conn = other.get_num_inputs()
pe = self.get_name()
# TODO: remove hack eventually
if 'Max 0' in other.op:
other_conn = 1
elif 'Faddiexp' in other.op:
comment = edge.get_attributes()["comment"].strip('"')
if 'fp' in comment:
other_conn = 0
elif 'exp' in comment:
other_conn = 1
else:
assert 0 & "edge connected to faddiexp has to have comment specified to either 'exp' or 'fp'"
new_conns = {
f'pe_to_pe_{other_conn}': [
([(pe, "res"), (other_pe, f"data{other_conn}")], 17),
]
}
edge_attr = edge.get_attributes()
# a destination port name has been specified by metamapper
if "specified_port" in edge_attr and edge_attr["specified_port"] is not None:
other_conn = edge_attr["specified_port"]
other.mapped_input_ports.append(other_conn.strip("data"))
new_conns = {
f'pe_to_pe_{other_conn}': [
([(pe, "res"), (other_pe, f"{other_conn}")], 17),
]
}
else:
other_conn = other.mapped_input_ports[other_conn]
new_conns = {
f'pe_to_pe_{other_conn}': [
([(pe, "res"), (other_pe, f"data{other_conn}")], 17),
]
}
other.update_input_connections()
return new_conns
elif other_type == BroadcastNode:
Expand Down Expand Up @@ -161,6 +171,46 @@ def update_input_connections(self):
def get_num_inputs(self):
return self.num_inputs_connected

def parse_mapped_json(self, filename, node_id, is_mapped_from_complex_op, original_complex_op_id):
with open(filename, 'r') as alu_mapped_file:
alu_mapped = json.load(alu_mapped_file)
# parse out the mapped opcode
alu_instance_name = None
module = None
if not is_mapped_from_complex_op:
module = alu_mapped["namespaces"]["global"]["modules"]["ALU_" + node_id + "_mapped"]
for instance_name, instance in module["instances"].items():
if "modref" in instance and instance["modref"] == "global.PE":
alu_instance_name = instance_name
break
for connection in alu_mapped["namespaces"]["global"]["modules"]["ALU_" + node_id + "_mapped"]["connections"]:
port0, port1 = connection
if "self.in" in port0:
# get the port name of the alu
self.mapped_input_ports.append(port1.split(".")[1].strip("data"))
elif "self.in" in port1:
self.mapped_input_ports.append(port0.split(".")[1].strip("data"))
assert (len(self.mapped_input_ports) > 0)
else:
assert original_complex_op_id is not None
module = alu_mapped["namespaces"]["global"]["modules"]["ALU_" + original_complex_op_id + "_mapped"]
# node namae of a remapped alu node from a complex op is of the format <instance_name>_<id>
alu_instance_name = '_'.join(node_id.split("_")[0:-1])
# no need to find the input and output port for remapped op
# as it is already assigned when we remap the complex op and stored in the edge object
# look for the constant coreir object that supplies the opcode to the alu at question
# insturction is supplied through the "inst" port of the alu
for connection in module["connections"]:
port0, port1 = connection
if f"{alu_instance_name}.inst" in port0:
constant_name = port1.split(".")[0]
elif f"{alu_instance_name}.inst" in port1:
constant_name = port0.split(".")[0]
opcode = module["instances"][constant_name]["modargs"]["value"][1]
opcode = "0x" + opcode.split('h')[1]

self.opcode = int(opcode, 0)

def configure(self, attributes):
print("PE CONFIGURE")
print(attributes)
Expand All @@ -174,46 +224,18 @@ def configure(self, attributes):
pe_only = True
# data I/O should interface with other primitive outside of the cluster
pe_in_external = 1
if c_op == 'mul':
op_code = 1
elif c_op == 'add' and 'sub=1' not in comment:
op_code = 0
elif c_op == 'add' and 'sub=1' in comment:
op_code = 2
elif c_op == 'max':
op_code = 4
elif c_op == 'and':
op_code = 5
elif c_op == 'fp_mul':
op_code = 6
elif c_op == 'fgetfint':
op_code = 7
elif c_op == 'fgetffrac':
op_code = 8
elif c_op == 'faddiexp':
op_code = 9
elif c_op == 'fp_max':
op_code = 10
elif c_op == 'fp_add':
op_code = 11

rb_const = None
if "rb_const" in attributes:
# the b operand of the op is a constant
rb_const = attributes["rb_const"].strip('"')
if "." in rb_const:
# constant is a floating point
rb_const = float(rb_const)
rb_const = int(float2bfbin(rb_const), 2)
else:
# it is a int
rb_const = int(rb_const)
# according to the mapped input ports generate input port config
num_sparse_inputs = list("000")
for port in self.mapped_input_ports:
num_sparse_inputs[2 - int(port)] = '1'
print("".join(num_sparse_inputs))
num_sparse_inputs = int("".join(num_sparse_inputs), 2)

cfg_kwargs = {
'op': op_code,
'op': self.opcode,
'use_dense': use_dense,
'pe_only': pe_only,
'pe_in_external': pe_in_external,
'rb_const': rb_const
'num_sparse_inputs': num_sparse_inputs
}
return (op_code, use_dense, pe_only, pe_in_external, rb_const), cfg_kwargs
return (op_code, use_dense, pe_only, pe_in_external, num_sparse_inputs), cfg_kwargs
6 changes: 1 addition & 5 deletions sam/onyx/hw_nodes/intersect_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,18 +175,14 @@ def connect(self, other, edge, kwargs=None):
# Could be doing a sparse accum
compute = other
compute_name = other.get_name()
print("INTERSECT TO COMPUTE EDGE!")
print(edge)
print(edge.get_attributes())
edge_comment = edge.get_attributes()['comment'].strip('"')
tensor = edge_comment.split('-')[1]
print(self.tensor_to_conn)
out_conn = self.tensor_to_conn[tensor]
compute_conn = compute.get_num_inputs()
new_conns = {
'intersect_to_repeat': [
# send output to rd scanner
([(isect, f"pos_out_{out_conn}"), (compute_name, f"data{compute_conn}")], 17),
([(isect, f"pos_out_{out_conn}"), (compute_name, f"data{compute.mapped_input_ports[compute_conn]}")], 17),
]
}
compute.update_input_connections()
Expand Down
41 changes: 18 additions & 23 deletions sam/onyx/hw_nodes/read_scanner_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,32 +194,27 @@ def connect(self, other, edge, kwargs=None):
return new_conns
elif other_type == ComputeNode:
compute = other.get_name()
# compute_conn = 0
print("CHECKING READ TENSOR - COMPUTE")
print(edge)
print(self.get_tensor())
# if self.get_tensor() == 'C' or self.get_tensor() == 'c':
# compute_conn = 1

# Can use dynamic information to assign inputs to compute nodes
# since add/mul are commutative
compute_conn = other.get_num_inputs()
# TODO: get rid of this hack
if 'Faddiexp' in other.op:
comment = edge.get_attributes()["comment"].strip('"')
if 'fp' in comment:
compute_conn = 0
elif 'exp' in comment:
compute_conn = 1
else:
assert 0 & "edge connected to faddiexp has to have comment specified to either 'exp' or 'fp'"
new_conns = {
f'rd_scan_to_compute_{compute_conn}': [
([(rd_scan, "coord_out"), (compute, f"data{compute_conn}")], 17),
]
}
# Now update the PE/compute to use the next connection next time
other.update_input_connections()
edge_attr = edge.get_attributes()
if "specified_port" in edge_attr and edge_attr["specified_port"] is not None:
compute_conn = edge_attr["specified_port"]
other.mapped_input_ports.append(compute_conn.strip("data"))
new_conns = {
f'rd_scan_to_compute_{compute_conn}': [
([(rd_scan, "coord_out"), (compute, f"{compute_conn}")], 17),
]
}
else:
compute_conn = other.mapped_input_ports[compute_conn]
new_conns = {
f'rd_scan_to_compute_{compute_conn}': [
([(rd_scan, "coord_out"), (compute, f"data{compute_conn}")], 17),
]
}
# Now update the PE/compute to use the next connection next time
other.update_input_connections()

return new_conns

Expand Down
Loading

0 comments on commit ce5ba84

Please sign in to comment.