From 8c42d9aaac614c285bcc011c9f07f7d3f9520e97 Mon Sep 17 00:00:00 2001 From: root Date: Mon, 30 Oct 2023 14:47:19 -0700 Subject: [PATCH 1/3] matmul + relu --- sam/onyx/generate_matrices.py | 2 +- sam/onyx/hw_nodes/compute_node.py | 6 +++++- sam/onyx/hw_nodes/reduce_node.py | 5 ++++- sam/onyx/parse_dot.py | 2 +- 4 files changed, 11 insertions(+), 4 deletions(-) diff --git a/sam/onyx/generate_matrices.py b/sam/onyx/generate_matrices.py index 241b46c8..0600317d 100644 --- a/sam/onyx/generate_matrices.py +++ b/sam/onyx/generate_matrices.py @@ -54,7 +54,7 @@ def _create_matrix(self, value_cap=int(math.pow(2, 8)) - 1): ''' Routine to create the actual matrix from the dimension/shape ''' - self.array = numpy.random.randint(low=0, high=value_cap, size=self.shape) + self.array = numpy.random.randint(low=-1*value_cap/2, high=value_cap/2, size=self.shape) for idx, x in numpy.ndenumerate(self.array): if random.random() < self.sparsity: self.array[idx] = 0 diff --git a/sam/onyx/hw_nodes/compute_node.py b/sam/onyx/hw_nodes/compute_node.py index 2fb7620f..ab4e9b5a 100644 --- a/sam/onyx/hw_nodes/compute_node.py +++ b/sam/onyx/hw_nodes/compute_node.py @@ -2,13 +2,15 @@ class ComputeNode(HWNode): - def __init__(self, name=None) -> None: + def __init__(self, name=None, op=None) -> None: super().__init__(name=name) self.num_inputs = 2 self.num_outputs = 1 self.num_inputs_connected = 0 self.num_outputs_connected = 0 + self.op = op + def connect(self, other, edge, kwargs=None): from sam.onyx.hw_nodes.glb_node import GLBNode @@ -157,6 +159,8 @@ def configure(self, attributes): op_code = 0 elif c_op == 'add' and 'sub=1' in comment: op_code = 2 + elif c_op == 'max': + op_code = 4 cfg_kwargs = { 'op': op_code } diff --git a/sam/onyx/hw_nodes/reduce_node.py b/sam/onyx/hw_nodes/reduce_node.py index 5caac658..a046f5fc 100644 --- a/sam/onyx/hw_nodes/reduce_node.py +++ b/sam/onyx/hw_nodes/reduce_node.py @@ -68,7 +68,10 @@ def connect(self, other, edge, kwargs=None): raise NotImplementedError(f'Cannot connect ReduceNode to {other_type}') elif other_type == ComputeNode: pe = other.get_name() - other_conn = other.get_num_inputs() + if 'Max' in other.op: + other_conn = 1 + else: + other_conn = other.get_num_inputs() new_conns = { f'reduce_to_pe_{other_conn}': [ # send output to rd scanner diff --git a/sam/onyx/parse_dot.py b/sam/onyx/parse_dot.py index 7fae6744..1475f6ec 100644 --- a/sam/onyx/parse_dot.py +++ b/sam/onyx/parse_dot.py @@ -99,7 +99,7 @@ def map_nodes(self): hw_nt = f"HWNodeType.RepSigGen" elif n_type == "repeat": hw_nt = f"HWNodeType.Repeat" - elif n_type == "mul" or n_type == "add": + elif n_type == "mul" or n_type == "add" or n_type == "max": hw_nt = f"HWNodeType.Compute" elif n_type == "reduce": hw_nt = f"HWNodeType.Reduce" From e63d35247561a02103eae24de043d8eb4becc395 Mon Sep 17 00:00:00 2001 From: root Date: Thu, 2 Nov 2023 13:33:00 -0700 Subject: [PATCH 2/3] fix style --- sam/onyx/generate_matrices.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sam/onyx/generate_matrices.py b/sam/onyx/generate_matrices.py index 0600317d..dbba52a1 100644 --- a/sam/onyx/generate_matrices.py +++ b/sam/onyx/generate_matrices.py @@ -54,7 +54,7 @@ def _create_matrix(self, value_cap=int(math.pow(2, 8)) - 1): ''' Routine to create the actual matrix from the dimension/shape ''' - self.array = numpy.random.randint(low=-1*value_cap/2, high=value_cap/2, size=self.shape) + self.array = numpy.random.randint(low=-1 * value_cap / 2, high=value_cap / 2, size=self.shape) for idx, x in numpy.ndenumerate(self.array): if random.random() < self.sparsity: self.array[idx] = 0 From 2d9960baa2a0fb2721bcf93eb5cfde59509b5a38 Mon Sep 17 00:00:00 2001 From: Bo Wun Cheng Date: Fri, 24 Nov 2023 16:26:57 -0800 Subject: [PATCH 3/3] added code to support routing from ComputeNode to Max, added graph for matmul_ijk_crddrop_relu, mat_elemadd_relu, spmm_ijk_crddrop_relu --- .../sam-outputs/onyx-dot/mat_elemadd_relu.gv | 42 +++++++++++++ .../onyx-dot/matmul_ijk_crddrop_relu.gv | 60 +++++++++++++++++++ .../onyx-dot/spmm_ijk_crddrop_relu.gv | 60 +++++++++++++++++++ sam/onyx/hw_nodes/compute_node.py | 5 ++ 4 files changed, 167 insertions(+) create mode 100644 compiler/sam-outputs/onyx-dot/mat_elemadd_relu.gv create mode 100644 compiler/sam-outputs/onyx-dot/matmul_ijk_crddrop_relu.gv create mode 100644 compiler/sam-outputs/onyx-dot/spmm_ijk_crddrop_relu.gv diff --git a/compiler/sam-outputs/onyx-dot/mat_elemadd_relu.gv b/compiler/sam-outputs/onyx-dot/mat_elemadd_relu.gv new file mode 100644 index 00000000..6e5a47ae --- /dev/null +++ b/compiler/sam-outputs/onyx-dot/mat_elemadd_relu.gv @@ -0,0 +1,42 @@ +digraph SAM { + comment="X=ss01,B=ss01,C=ss01" + 10 [comment="type=fiberlookup,index=i,tensor=B,mode=0,format=compressed,src=true,root=true" label="FiberLookup i: B0\ncompressed" color=green4 shape=box style=filled type="fiberlookup" index="i" tensor="B" mode="0" format="compressed" src="true" root="true"] + 9 [comment="type=union,index=i" label="union i" color=purple shape=box style=filled type="union" index="i"] + 7 [comment="type=fiberlookup,index=j,tensor=B,mode=1,format=compressed,src=true,root=false" label="FiberLookup j: B1\ncompressed" color=green4 shape=box style=filled type="fiberlookup" index="j" tensor="B" mode="1" format="compressed" src="true" root="false"] + 6 [comment="type=union,index=j" label="union j" color=purple shape=box style=filled type="union" index="j"] + 4 [comment="type=arrayvals,tensor=B" label="Array Vals: B" color=green2 shape=box style=filled type="arrayvals" tensor="B"] + 3 [comment="type=add,sub=0" label="Add" color=brown shape=box style=filled type="add" sub="0"] + 5 [comment="type=arrayvals,tensor=C" label="Array Vals: C" color=green2 shape=box style=filled type="arrayvals" tensor="C"] + 8 [comment="type=fiberlookup,index=j,tensor=C,mode=1,format=compressed,src=true,root=false" label="FiberLookup j: C1\ncompressed" color=green4 shape=box style=filled type="fiberlookup" index="j" tensor="C" mode="1" format="compressed" src="true" root="false"] + 11 [comment="type=fiberlookup,index=i,tensor=C,mode=0,format=compressed,src=true,root=true" label="FiberLookup i: C0\ncompressed" color=green4 shape=box style=filled type="fiberlookup" index="i" tensor="C" mode="0" format="compressed" src="true" root="true"] + 12 [comment="type=max" label="Max 0" color=brown shape=box style=filled type="max"] + 13 [comment="type=crddrop,outer=j,inner=val,mode=0" label="CrdDrop Compression j, val" color=orange style=filled type="crddrop" outer="j" inner="val" mode="0"] + 0 [comment="type=fiberwrite,mode=vals,tensor=X,size=1*B0_dim*B1_dim,sink=true" label="FiberWrite Vals: X" color=green3 shape=box style=filled type="fiberwrite" tensor="X" mode="vals" size="2*B0_dim*B1_dim" sink="true"] + 14 [comment="type=crddrop,outer=i,inner=j" label="CrdDrop i,j" color=orange shape=box style=filled type="crddrop" outer="i" inner="j"] + 2 [comment="type=fiberwrite,index=i,tensor=X,mode=0,format=compressed,segsize=2,crdsize=B0_dim,sink=true" label="FiberWrite i: X0\ncompressed" color=green3 shape=box style=filled type="fiberwrite" index="i" tensor="X" mode="0" format="compressed" segsize="2" crdsize="B0_dim" sink="true"] + 1 [comment="type=fiberwrite,index=j,tensor=X,mode=1,format=compressed,segsize=B0_dim+1,crdsize=B0_dim*B1_dim,sink=true" label="FiberWrite j: X1\ncompressed" color=green3 shape=box style=filled type="fiberwrite" index="j" tensor="X" mode="1" format="compressed" segsize="B0_dim+1" crdsize="B0_dim*B1_dim" sink="true"] + 10 -> 9 [label="crd_in-B" style=dashed type="crd" comment="in-B"] + 9 -> 7 [label="ref_out-B" style=bold type="ref" comment="out-B"] + 7 -> 6 [label="crd_in-B" style=dashed type="crd" comment="in-B"] + 6 -> 4 [label="ref_out-B" style=bold type="ref" comment="out-B"] + 4 -> 3 [label="val" type="val"] + 6 -> 5 [label="ref_out-C" style=bold type="ref" comment="out-C"] + 5 -> 3 [label="val" type="val"] + 7 -> 6 [label="ref_in-B" style=bold type="ref" comment="in-B"] + 9 -> 8 [label="ref_out-C" style=bold type="ref" comment="out-C"] + 8 -> 6 [label="crd_in-C" style=dashed type="crd" comment="in-C"] + 8 -> 6 [label="ref_in-C" style=bold type="ref" comment="in-C"] + 10 -> 9 [label="ref_in-B" style=bold type="ref" comment="in-B"] + 11 -> 9 [label="crd_in-C" style=dashed type="crd" comment="in-C"] + 11 -> 9 [label="ref_in-C" style=bold type="ref" comment="in-C"] + + 3 -> 12 [label="val" type="val" comment="val"] + 12 -> 13 [label="val" type="val" comment="inner-val"] + 6 -> 13 [label="crd_outer-j" style=dashed type="crd" comment="outer-j"] + 13 -> 0 [label="val" type="val", comment="val"] + 13 -> 14 [label="crd_inner-j" style=dashed type="crd" comment="inner-j"] + 9 -> 14 [label="crd_outer-i" style=dashed type="crd" comment="outer-i"] + 14 -> 2 [label="crd_outer-i" style=dashed type="crd" comment="outer-i"] + 14 -> 1 [label="crd_inner-j" style=dashed type="crd" comment="inner-j"] + +} diff --git a/compiler/sam-outputs/onyx-dot/matmul_ijk_crddrop_relu.gv b/compiler/sam-outputs/onyx-dot/matmul_ijk_crddrop_relu.gv new file mode 100644 index 00000000..5625ca41 --- /dev/null +++ b/compiler/sam-outputs/onyx-dot/matmul_ijk_crddrop_relu.gv @@ -0,0 +1,60 @@ +digraph SAM { + comment="X=ss01,B=ss01,C=ss10" + 17 [comment="type=fiberlookup,index=i,tensor=B,mode=0,format=compressed,src=true,root=true" label="FiberLookup i: B0\ncompressed" color=green4 shape=box style=filled type="fiberlookup" index="i" tensor="B" mode="0" format="compressed" src="true" root="true"] + 16 [comment="type=broadcast" shape=point style=invis type="broadcast"] + 2 [comment="type=fiberwrite,index=i,tensor=X,mode=0,format=compressed,segsize=2,crdsize=B0_dim,sink=true" label="FiberWrite i: X0\ncompressed" color=green3 shape=box style=filled type="fiberwrite" index="i" tensor="X" mode="0" format="compressed" segsize="2" crdsize="B0_dim" sink="true"] + 15 [comment="type=repsiggen,index=i" label="RepeatSignalGenerator i" color=cyan3 shape=box style=filled type="repsiggen" index="i"] + 14 [comment="type=repeat,index=i,tensor=C,root=true" label="Repeat i: C" color=cyan2 shape=box style=filled type="repeat" index="i" tensor="C" root="true"] + 13 [comment="type=fiberlookup,index=j,tensor=C,mode=1,format=compressed,src=true,root=false" label="FiberLookup j: C1\ncompressed" color=green4 shape=box style=filled type="fiberlookup" index="j" tensor="C" mode="1" format="compressed" src="true" root="false"] + 12 [comment="type=broadcast" shape=point style=invis type="broadcast"] + 1 [comment="type=fiberwrite,index=j,tensor=X,mode=1,format=compressed,segsize=B0_dim+1,crdsize=B0_dim*C1_dim,sink=true" label="FiberWrite j: X1\ncompressed" color=green3 shape=box style=filled type="fiberwrite" index="j" tensor="X" mode="1" format="compressed" segsize="B0_dim+1" crdsize="B0_dim*C1_dim" sink="true"] + 11 [comment="type=repsiggen,index=j" label="RepeatSignalGenerator j" color=cyan3 shape=box style=filled type="repsiggen" index="j"] + 10 [comment="type=repeat,index=j,tensor=B,root=false" label="Repeat j: B" color=cyan2 shape=box style=filled type="repeat" index="j" tensor="B" root="false"] + 8 [comment="type=fiberlookup,index=k,tensor=B,mode=1,format=compressed,src=true,root=false" label="FiberLookup k: B1\ncompressed" color=green4 shape=box style=filled type="fiberlookup" index="k" tensor="B" mode="1" format="compressed" src="true" root="false"] + 7 [comment="type=intersect,index=k" label="intersect k" color=purple shape=box style=filled type="intersect" index="k"] + 5 [comment="type=arrayvals,tensor=B" label="Array Vals: B" color=green2 shape=box style=filled type="arrayvals" tensor="B"] + 4 [comment="type=mul" label="Mul" color=brown shape=box style=filled type="mul"] + 3 [comment="type=reduce" label="Reduce" color=brown shape=box style=filled type="reduce"] + 0 [comment="type=fiberwrite,mode=vals,tensor=X,size=1*B0_dim*C1_dim,sink=true" label="FiberWrite Vals: X" color=green3 shape=box style=filled type="fiberwrite" tensor="X" mode="vals" size="1*B0_dim*C1_dim" sink="true"] + 6 [comment="type=arrayvals,tensor=C" label="Array Vals: C" color=green2 shape=box style=filled type="arrayvals" tensor="C"] + 9 [comment="type=fiberlookup,index=k,tensor=C,mode=0,format=compressed,src=true,root=false" label="FiberLookup k: C0\ncompressed" color=green4 shape=box style=filled type="fiberlookup" index="k" tensor="C" mode="0" format="compressed" src="true" root="false"] + 18 [comment="type=crddrop,outer=i,inner=j" label="CrdDrop i,j" color=orange shape=box style=filled type="crddrop" outer="i" inner="j"] + 19 [comment="type=crddrop,outer=j,inner=k" label="CrdDrop j,k" color=orange shape=box style=filled type="crddrop" outer="j" inner="k"] + 20 [comment="type=max" label="Max 0" color=brown shape=box style=filled type="max"] + 21 [comment="type=crddrop,outer=j,inner=val,mode=0" label="CrdDrop Compression j, val" color=orange style=filled type="crddrop" outer="j" inner="val" mode="0"] + 22 [comment="type=crddrop,outer=i,inner=j" label="CrdDrop i,j" color=orange shape=box style=filled type="crddrop" outer="i" inner="j"] + 17 -> 16 [label="crd" style=dashed type="crd" comment=""] + 16 -> 15 [label="crd" style=dashed type="crd"] + 15 -> 14 [label="repsig" style=dotted type="repsig"] + 14 -> 13 [label="ref" style=bold type="ref"] + 13 -> 12 [label="crd" style=dashed type="crd" comment=""] + 12 -> 11 [label="crd" style=dashed type="crd"] + 11 -> 10 [label="repsig" style=dotted type="repsig"] + 10 -> 8 [label="ref" style=bold type="ref"] + 8 -> 7 [label="crd_in-B" style=dashed type="crd" comment="in-B"] + 7 -> 5 [label="ref_out-B" style=bold type="ref" comment="out-B"] + 5 -> 4 [label="val" type="val"] + 7 -> 6 [label="ref_out-C" style=bold type="ref" comment="out-C"] + 6 -> 4 [label="val" type="val"] + 8 -> 7 [label="ref_in-B" style=bold type="ref" comment="in-B"] + 13 -> 9 [label="ref" style=bold type="ref" comment=""] + 9 -> 7 [label="crd_in-C" style=dashed type="crd" comment="in-C"] + 9 -> 7 [label="ref_in-C" style=bold type="ref" comment="in-C"] + 17 -> 10 [label="ref" style=bold type="ref" comment=""] + + 4 -> 19 [label="val_inner-k" type="val" comment="inner-k"] + 12 -> 19 [label="crd_outer-j" style=dashed type="crd" comment="outer-j"] + 19 -> 3 [label="val_inner-k" type="val" comment="inner-k"] + 3 -> 20 [label="val" type="val" comment="val"] + 20 -> 21 [label="val" type="val" comment="inner-val"] + 18 -> 21 [label="crd_inner-j" style=dashed type="crd" comment="outer-j"] + 21 -> 22 [label="crd_inner-j" style=dashed type="crd" comment="outer-j"] + 18 -> 22 [label="crd_outer-i" style=dashed type="crd" comment="outer-i"] + + 19 -> 18 [label="crd_inner-j" style=dashed type="crd" comment="inner-j"] + 16 -> 18 [label="crd_outer-i" style=dashed type="crd" comment="outer-i"] + 21 -> 0 [label="val" type="val" comment="inner-val"] + + 22 -> 2 [label="crd_outer-i" style=dashed type="crd" comment="outer-i"] + 22 -> 1 [label="crd_inner-j" style=dashed type="crd" comment="inner-j"] +} diff --git a/compiler/sam-outputs/onyx-dot/spmm_ijk_crddrop_relu.gv b/compiler/sam-outputs/onyx-dot/spmm_ijk_crddrop_relu.gv new file mode 100644 index 00000000..65e5e1ad --- /dev/null +++ b/compiler/sam-outputs/onyx-dot/spmm_ijk_crddrop_relu.gv @@ -0,0 +1,60 @@ +digraph SAM { + comment="X=ss01,B=dd01,C=ss10" + 17 [comment="type=fiberlookup,index=i,tensor=B,mode=0,format=dense,src=true,root=true" label="FiberLookup i: B0\ndense" color=green4 shape=box style=filled type="fiberlookup" index="i" tensor="B" mode="0" format="dense" src="true" root="true"] + 16 [comment="type=broadcast" shape=point style=invis type="broadcast"] + 2 [comment="type=fiberwrite,index=i,tensor=X,mode=0,format=compressed,segsize=2,crdsize=B0_dim,sink=true" label="FiberWrite i: X0\ncompressed" color=green3 shape=box style=filled type="fiberwrite" index="i" tensor="X" mode="0" format="compressed" segsize="2" crdsize="B0_dim" sink="true"] + 15 [comment="type=repsiggen,index=i" label="RepeatSignalGenerator i" color=cyan3 shape=box style=filled type="repsiggen" index="i"] + 14 [comment="type=repeat,index=i,tensor=C,root=true" label="Repeat i: C" color=cyan2 shape=box style=filled type="repeat" index="i" tensor="C" root="true"] + 13 [comment="type=fiberlookup,index=j,tensor=C,mode=1,format=compressed,src=true,root=false" label="FiberLookup j: C1\ncompressed" color=green4 shape=box style=filled type="fiberlookup" index="j" tensor="C" mode="1" format="compressed" src="true" root="false"] + 12 [comment="type=broadcast" shape=point style=invis type="broadcast"] + 1 [comment="type=fiberwrite,index=j,tensor=X,mode=1,format=compressed,segsize=B0_dim+1,crdsize=B0_dim*C1_dim,sink=true" label="FiberWrite j: X1\ncompressed" color=green3 shape=box style=filled type="fiberwrite" index="j" tensor="X" mode="1" format="compressed" segsize="B0_dim+1" crdsize="B0_dim*C1_dim" sink="true"] + 11 [comment="type=repsiggen,index=j" label="RepeatSignalGenerator j" color=cyan3 shape=box style=filled type="repsiggen" index="j"] + 10 [comment="type=repeat,index=j,tensor=B,root=false" label="Repeat j: B" color=cyan2 shape=box style=filled type="repeat" index="j" tensor="B" root="false"] + 8 [comment="type=fiberlookup,index=k,tensor=B,mode=1,format=dense,src=true,root=false" label="FiberLookup k: B1\ndense" color=green4 shape=box style=filled type="fiberlookup" index="k" tensor="B" mode="1" format="dense" src="true" root="false"] + 7 [comment="type=intersect,index=k" label="intersect k" color=purple shape=box style=filled type="intersect" index="k"] + 5 [comment="type=arrayvals,tensor=B" label="Array Vals: B" color=green2 shape=box style=filled type="arrayvals" tensor="B"] + 4 [comment="type=mul" label="Mul" color=brown shape=box style=filled type="mul"] + 3 [comment="type=reduce" label="Reduce" color=brown shape=box style=filled type="reduce"] + 0 [comment="type=fiberwrite,mode=vals,tensor=X,size=1*B0_dim*C1_dim,sink=true" label="FiberWrite Vals: X" color=green3 shape=box style=filled type="fiberwrite" tensor="X" mode="vals" size="1*B0_dim*C1_dim" sink="true"] + 6 [comment="type=arrayvals,tensor=C" label="Array Vals: C" color=green2 shape=box style=filled type="arrayvals" tensor="C"] + 9 [comment="type=fiberlookup,index=k,tensor=C,mode=0,format=compressed,src=true,root=false" label="FiberLookup k: C0\ncompressed" color=green4 shape=box style=filled type="fiberlookup" index="k" tensor="C" mode="0" format="compressed" src="true" root="false"] + 18 [comment="type=crddrop,outer=i,inner=j" label="CrdDrop i,j" color=orange shape=box style=filled type="crddrop" outer="i" inner="j"] + 19 [comment="type=crddrop,outer=j,inner=k" label="CrdDrop j,k" color=orange shape=box style=filled type="crddrop" outer="j" inner="k"] + 20 [comment="type=max" label="Max 0" color=brown shape=box style=filled type="max"] + 21 [comment="type=crddrop,outer=j,inner=val,mode=0" label="CrdDrop Compression j, val" color=orange style=filled type="crddrop" outer="j" inner="val" mode="0"] + 22 [comment="type=crddrop,outer=i,inner=j" label="CrdDrop i,j" color=orange shape=box style=filled type="crddrop" outer="i" inner="j"] + 17 -> 16 [label="crd" style=dashed type="crd" comment=""] + 16 -> 15 [label="crd" style=dashed type="crd"] + 15 -> 14 [label="repsig" style=dotted type="repsig"] + 14 -> 13 [label="ref" style=bold type="ref"] + 13 -> 12 [label="crd" style=dashed type="crd" comment=""] + 12 -> 11 [label="crd" style=dashed type="crd"] + 11 -> 10 [label="repsig" style=dotted type="repsig"] + 10 -> 8 [label="ref" style=bold type="ref"] + 8 -> 7 [label="crd_in-B" style=dashed type="crd" comment="in-B"] + 7 -> 5 [label="ref_out-B" style=bold type="ref" comment="out-B"] + 5 -> 4 [label="val" type="val"] + 7 -> 6 [label="ref_out-C" style=bold type="ref" comment="out-C"] + 6 -> 4 [label="val" type="val"] + 8 -> 7 [label="ref_in-B" style=bold type="ref" comment="in-B"] + 13 -> 9 [label="ref" style=bold type="ref" comment=""] + 9 -> 7 [label="crd_in-C" style=dashed type="crd" comment="in-C"] + 9 -> 7 [label="ref_in-C" style=bold type="ref" comment="in-C"] + 17 -> 10 [label="ref" style=bold type="ref" comment=""] + + 4 -> 19 [label="val_inner-k" type="val" comment="inner-k"] + 12 -> 19 [label="crd_outer-j" style=dashed type="crd" comment="outer-j"] + 19 -> 3 [label="val_inner-k" type="val" comment="inner-k"] + 3 -> 20 [label="val" type="val" comment="val"] + 20 -> 21 [label="val" type="val" comment="inner-val"] + 18 -> 21 [label="crd_inner-j" style=dashed type="crd" comment="outer-j"] + 21 -> 22 [label="crd_inner-j" style=dashed type="crd" comment="outer-j"] + 18 -> 22 [label="crd_outer-i" style=dashed type="crd" comment="outer-i"] + + 19 -> 18 [label="crd_inner-j" style=dashed type="crd" comment="inner-j"] + 16 -> 18 [label="crd_outer-i" style=dashed type="crd" comment="outer-i"] + 21 -> 0 [label="val" type="val" comment="inner-val"] + + 22 -> 2 [label="crd_outer-i" style=dashed type="crd" comment="outer-i"] + 22 -> 1 [label="crd_inner-j" style=dashed type="crd" comment="inner-j"] +} diff --git a/sam/onyx/hw_nodes/compute_node.py b/sam/onyx/hw_nodes/compute_node.py index 61a049dc..96e492bb 100644 --- a/sam/onyx/hw_nodes/compute_node.py +++ b/sam/onyx/hw_nodes/compute_node.py @@ -118,6 +118,11 @@ def connect(self, other, edge, kwargs=None): other_pe = other.get_name() other_conn = other.get_num_inputs() pe = self.get_name() + # TODO: remove hack eventually + if 'Max' in other.op: + other_conn = 1 + else: + other_conn = other.get_num_inputs() new_conns = { f'pe_to_pe_{other_conn}': [ ([(pe, "res"), (other_pe, f"data{other_conn}")], 17),