Skip to content

Commit

Permalink
VR connection rules and SAM graph changes
Browse files Browse the repository at this point in the history
  • Loading branch information
mcoduoza committed Nov 16, 2023
1 parent f8d2cd3 commit f181dd0
Show file tree
Hide file tree
Showing 7 changed files with 61 additions and 41 deletions.
4 changes: 2 additions & 2 deletions compiler/sam-outputs/dot/matmul_ikj_hand_BLACKBOX.gv
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
digraph SAM {
comment="X=ss01,B=ss01,C=ss01"
20 [comment="type=vectorreducer,index=j" label="VectorReducer j" color=brown shape=box style=filled type="vectorreducer" index="j"]
20 [comment="type=vectorreducer,index=j" label="VectorReducer j" color=brown shape=box style=filled type="vectorreducer" accum_index="j"]
0 [comment="type=fiberwrite,mode=vals,tensor=X,size=1*B0_dim*C1_dim,sink=true" label="FiberWrite Vals: X" color=green3 shape=box style=filled type="fiberwrite" tensor="X" mode="vals" size="1*B0_dim*C1_dim" sink="true"]
1 [comment="type=fiberwrite,index=j,tensor=X,mode=1,format=compressed,segsize=B0_dim+1,crdsize=B0_dim*C1_dim,sink=true" label="FiberWrite j: X1\ncompressed" color=green3 shape=box style=filled type="fiberwrite" index="j" tensor="X" mode="1" format="compressed" segsize="B0_dim+1" crdsize="B0_dim*C1_dim" sink="true"]
19 [comment="type=fiberlookup,index=i,tensor=B,mode=0,format=compressed,src=true,root=true" label="FiberLookup i: B0\ncompressed" color=green4 shape=box style=filled type="fiberlookup" index="i" tensor="B" mode="0" format="compressed" src="true" root="true"]
Expand Down Expand Up @@ -30,7 +30,7 @@ digraph SAM {
7 -> 6 [label="val" type="val"]
13 -> 12 [label="ref_out-C" style=bold type="ref" comment="out-C"]
12 -> 11 [label="crd" style=dashed type="crd" comment=""]
16 -> 4 [label="crd_i" style=dashed type="crd" comment="i"]
19 -> 4 [label="crd_i" style=dashed type="crd" comment="i"]
11 -> 4 [label="crd_j" style=dashed type="crd" comment="j"]
11 -> 10 [label="crd" style=dashed type="crd" comment=""]
10 -> 9 [label="repsig" style=dotted type="repsig"]
Expand Down
9 changes: 3 additions & 6 deletions compiler/sam-outputs/onyx-dot/matmul_ikj_hand_BLACKBOX.gv
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
digraph SAM {
comment="X=ss01,B=ss01,C=ss01"
20 [comment="type=vectorreducer,index=j" label="VectorReducer j" color=brown shape=box style=filled type="vectorreducer" index="j"]
20 [comment="type=vectorreducer,index=j" label="VectorReducer j" color=brown shape=box style=filled type="vectorreducer" accum_index="j"]
0 [comment="type=fiberwrite,mode=vals,tensor=X,size=1*B0_dim*C1_dim,sink=true" label="FiberWrite Vals: X" color=green3 shape=box style=filled type="fiberwrite" tensor="X" mode="vals" size="1*B0_dim*C1_dim" sink="true"]
1 [comment="type=fiberwrite,index=j,tensor=X,mode=1,format=compressed,segsize=B0_dim+1,crdsize=B0_dim*C1_dim,sink=true" label="FiberWrite j: X1\ncompressed" color=green3 shape=box style=filled type="fiberwrite" index="j" tensor="X" mode="1" format="compressed" segsize="B0_dim+1" crdsize="B0_dim*C1_dim" sink="true"]
19 [comment="type=fiberlookup,index=i,tensor=B,mode=0,format=compressed,src=true,root=true" label="FiberLookup i: B0\ncompressed" color=green4 shape=box style=filled type="fiberlookup" index="i" tensor="B" mode="0" format="compressed" src="true" root="true"]
18 [comment="type=broadcast" shape=point style=invis type="broadcast"]
4 [comment="type=crddrop,outer=i,inner=j" label="CrdDrop i,j" color=orange shape=box style=filled type="crddrop" outer="i" inner="j"]
2 [comment="type=fiberwrite,index=i,tensor=X,mode=0,format=compressed,segsize=2,crdsize=B0_dim,sink=true" label="FiberWrite i: X0\ncompressed" color=green3 shape=box style=filled type="fiberwrite" index="i" tensor="X" mode="0" format="compressed" segsize="2" crdsize="B0_dim" sink="true"]
17 [comment="type=repsiggen,index=i" label="RepeatSignalGenerator i" color=cyan3 shape=box style=filled type="repsiggen" index="i"]
16 [comment="type=repeat,index=i,tensor=C,root=true" label="Repeat i: C" color=cyan2 shape=box style=filled type="repeat" index="i" tensor="C" root="true"]
Expand All @@ -20,7 +19,6 @@ digraph SAM {
8 [comment="type=arrayvals,tensor=C" label="Array Vals: C" color=green2 shape=box style=filled type="arrayvals" tensor="C"]
14 [comment="type=fiberlookup,index=k,tensor=B,mode=1,format=compressed,src=true,root=false" label="FiberLookup k: B1\ncompressed" color=green4 shape=box style=filled type="fiberlookup" index="k" tensor="B" mode="1" format="compressed" src="true" root="false"]
19 -> 18 [label="crd" style=dashed type="crd" comment=""]
4 -> 2 [label="crd_out-i" style=dashed type="crd" comment="out-i"]
18 -> 17 [label="crd" style=dashed type="crd" comment=""]
17 -> 16 [label="repsig" style=dotted type="repsig"]
16 -> 15 [label="ref" style=bold type="ref"]
Expand All @@ -30,8 +28,8 @@ digraph SAM {
7 -> 6 [label="val" type="val"]
13 -> 12 [label="ref_out-C" style=bold type="ref" comment="out-C"]
12 -> 11 [label="crd" style=dashed type="crd" comment=""]
16 -> 4 [label="crd_i" style=dashed type="crd" comment="i"]
11 -> 4 [label="crd_j" style=dashed type="crd" comment="j"]
19 -> 2 [label="crd_i" style=dashed type="crd" comment="i"]
11 -> 20 [label="crd_j" style=dashed type="crd" comment="j" special="true"]
11 -> 10 [label="crd" style=dashed type="crd" comment=""]
10 -> 9 [label="repsig" style=dotted type="repsig"]
12 -> 8 [label="ref" style=bold type="ref" comment=""]
Expand All @@ -40,7 +38,6 @@ digraph SAM {
19 -> 14 [label="ref" style=bold type="ref" comment=""]
14 -> 13 [label="crd_in-B" style=dashed type="crd" comment="in-B"]
14 -> 13 [label="ref_in-B" style=bold type="ref" comment="in-B"]
4 -> 20 [label="crddrp-crd-j-out" style=dashed type="crd" comment="crddrp-crd-j-out"]
6 -> 20 [label="mul_val_out" type="val"]
20 -> 0 [label="final_vals" type="val"]
20 -> 1 [label="crd_out-j" style=dashed type="crd" comment="out-j"]
Expand Down
5 changes: 3 additions & 2 deletions sam/onyx/hw_nodes/compute_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,9 @@ def connect(self, other, edge, kwargs=None):
pe = self.get_name()
# isect_conn = other.get_num_inputs()

if 'vector_reduce_mode' in edge.get_comment():
isect_conn = 0
if 'vector_reduce_mode' in edge.get_attributes():
if edge.get_attributes()['vector_reduce_mode'] == True:
isect_conn = 0
else:
if 'tensor' not in edge.get_attributes():
# Taking some liberties here - but technically this is the combo val
Expand Down
3 changes: 3 additions & 0 deletions sam/onyx/hw_nodes/fiberaccess_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,9 @@ def configure(self, attributes, flavor):

cfg_tuple, cfg_kwargs = self.get_flavor(flavor=flavor).configure(attributes)
cfg_kwargs['flavor'] = flavor
print("THESE ARE MY CONFIG KWARGS")
print(cfg_kwargs)
#breakpoint()

#vr_mode = 0
#cfg_tuple += (vr_mode,)
Expand Down
1 change: 1 addition & 0 deletions sam/onyx/hw_nodes/intersect_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@ def connect(self, other, edge, kwargs=None):
print(edge.get_attributes())
edge_comment = edge.get_attributes()['comment'].strip('"')
tensor = edge_comment.split('-')[1]
print(self.tensor_to_conn)
out_conn = self.tensor_to_conn[tensor]
compute_conn = compute.get_num_inputs()
new_conns = {
Expand Down
12 changes: 8 additions & 4 deletions sam/onyx/hw_nodes/read_scanner_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,9 @@ def connect(self, other, edge, kwargs=None):
edge_attr = edge.get_attributes()
if 'use_alt_out_port' in edge_attr:
out_conn = 'block_rd_out'
elif (edge.get_comment() is not None) and ('vector_reduce_mode' in edge.get_comment()):
out_conn = 'pos_out'
elif ('vector_reduce_mode' in edge_attr):
if (edge_attr['vector_reduce_mode'] == True):
out_conn = 'pos_out'
else:
out_conn = 'coord_out'

Expand All @@ -104,8 +105,11 @@ def connect(self, other, edge, kwargs=None):
elif other_type == IntersectNode:
# Send both....
isect = other.get_name()
if 'vector_reduce_mode' in edge.get_comment():
isect_conn = 1
if 'vector_reduce_mode' in edge.get_attributes():
if edge.get_attributes()['vector_reduce_mode'] == True:
isect_conn = 1
elif 'special' in edge.get_attributes():
isect_conn = 0
else:
isect_conn = other.get_connection_from_tensor(self.get_tensor())

Expand Down
68 changes: 41 additions & 27 deletions sam/onyx/parse_dot.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def rewrite_VectorReducer(self):
del attrs['label']

# TODO: Get redux crd
output_crd = attrs['index'].strip('"')
output_crd = attrs['accum_index'].strip('"')
#input_crd = None

incoming_edges = [edge for edge in self.graph.get_edges() if edge.get_destination() == vr_node.get_name()]
Expand Down Expand Up @@ -198,26 +198,24 @@ def rewrite_VectorReducer(self):
# **attrs, label=f"{og_label}_repeat", hwnode=f"{HWNodeType.Repeat}",
# root="true", type=og_type, spacc="true")

union = pydot.Node(f"vr_union_{self.get_next_seq()}",
**attrs, label=f"{og_label}_union", hwnode=f"{HWNodeType.Intersect}",
type="union", vector_reduce_mode="true")
union = pydot.Node(f"vr_union_{self.get_next_seq()}", label=f"{og_label}_union", hwnode=f"{HWNodeType.Intersect}",
type="union", vector_reduce_mode="true", comment=f"type=union,index={output_crd}", index=output_crd)

add = pydot.Node(f"vr_add_{self.get_next_seq()}",
**attrs, label=f"{og_label}_add", hwnode=f"{HWNodeType.Compute}",
type=og_type)
add = pydot.Node(f"vr_add_{self.get_next_seq()}", label=f"{og_label}_Add", hwnode=f"{HWNodeType.Compute}",
type="add", sub="0", comment="type=add,sub=0")

crd_buffet = pydot.Node(f"vr_crd_buffet_{self.get_next_seq()}",
**attrs, label=f"{og_label}_crd_buffet", hwnode=f"{HWNodeType.Buffet}",
type=og_type, vector_reduce_mode="true", fa_color=self.fa_color)
label=f"{og_label}_crd_buffet", hwnode=f"{HWNodeType.Buffet}",
type="buffet", vector_reduce_mode="true", fa_color=self.fa_color, comment="crd_buffet")

crd_rd_scanner = pydot.Node(f"vr_crd_rd_scanner_{self.get_next_seq()}",
**attrs, label=f"{og_label}_crd_rd_scanner", hwnode=f"{HWNodeType.ReadScanner}",
tensor="x", type=og_type, root="false", format="compressed",
mode="0", index=f"{output_crd}", vector_reduce_mode="true", fa_color=self.fa_color)
label=f"{og_label}_crd_rd_scanner", hwnode=f"{HWNodeType.ReadScanner}",
tensor="X", type="fiberlookup", root="false", format="compressed",
mode="0", index=f"{output_crd}", vector_reduce_mode="true", fa_color=self.fa_color, comment="crd_rd_scanner")

crd_wr_scanner = pydot.Node(f"vr_crd_wr_scanner_{self.get_next_seq()}",
**attrs, label=f"{og_label}_crd_wr_scanner", hwnode=f"{HWNodeType.WriteScanner}",
type=og_type, mode="0", format="compressed", vector_reduce_mode="true", fa_color=self.fa_color)
label=f"{og_label}_crd_wr_scanner", hwnode=f"{HWNodeType.WriteScanner}",
type="fiberwrite", mode="0", format="compressed", vector_reduce_mode="true", fa_color=self.fa_color, comment="crd_wr_scanner")

self.fa_color += 1

Expand All @@ -226,17 +224,28 @@ def rewrite_VectorReducer(self):
# tensor="x", mode="0", format="compressed", type=og_type)

vals_buffet = pydot.Node(f"vr_vals_buffet_{self.get_next_seq()}",
**attrs, label=f"{og_label}_vals_buffet", hwnode=f"{HWNodeType.Buffet}",
type=og_type, vector_reduce_mode="true", fa_color=self.fa_color)
label=f"{og_label}_vals_buffet", hwnode=f"{HWNodeType.Buffet}",
type="buffet", vector_reduce_mode="true", fa_color=self.fa_color, comment="vals_buffet")

#vals_rd_scanner = pydot.Node(f"vr_vals_rd_scanner_{self.get_next_seq()}",
# label=f"{og_label}_vals_rd_scanner", hwnode=f"{HWNodeType.ReadScanner}",
# tensor="X", type="arrayvals", root="false", format="vals",
# mode="vals", vector_reduce_mode="true", fa_color=self.fa_color, comment="vals_rd_scanner")

vals_rd_scanner = pydot.Node(f"vr_vals_rd_scanner_{self.get_next_seq()}",
**attrs, label=f"{og_label}_vals_rd_scanner", hwnode=f"{HWNodeType.ReadScanner}",
tensor="x", type=og_type, root="false", format="vals",
mode="vals", vector_reduce_mode="true", fa_color=self.fa_color)
label=f"{og_label}_vals_rd_scanner", hwnode=f"{HWNodeType.ReadScanner}",
tensor="X", type="fiberlookup", root="false", format="compressed",
mode="1", vector_reduce_mode="true", fa_color=self.fa_color, comment="vals_rd_scanner")

#vals_wr_scanner = pydot.Node(f"vr_vals_wr_scanner_{self.get_next_seq()}",
# label=f"{og_label}_vals_wr_scanner", hwnode=f"{HWNodeType.WriteScanner}",
# type="fiberwrite", mode="vals", vector_reduce_mode="true", fa_color=self.fa_color, comment="vals_wr_scanner")


vals_wr_scanner = pydot.Node(f"vr_vals_wr_scanner_{self.get_next_seq()}",
**attrs, label=f"{og_label}_vals_wr_scanner", hwnode=f"{HWNodeType.WriteScanner}",
type=og_type, mode="vals", format="compressed", vector_reduce_mode="true", fa_color=self.fa_color)
label=f"{og_label}_vals_wr_scanner", hwnode=f"{HWNodeType.WriteScanner}",
type="fiberwrite", mode="1", format="compressed", vector_reduce_mode="true", fa_color=self.fa_color, comment="vals_wr_scanner")


# glb_vals = pydot.Node(f"vr_crd_vals_{self.get_next_seq()}", **attrs,
# label=f"{og_label}_glb_vals_read", hwnode=f"{HWNodeType.GLB}",
Expand All @@ -260,19 +269,20 @@ def rewrite_VectorReducer(self):

del in_edge_attrs[in_crd_node]['comment']
del in_edge_attrs[in_val_node]['type']
del in_edge_attrs[in_crd_node]['type']

# Edges
#input_to_rsg_edge = pydot.Edge(src=in_input_node, dst=rsg, **in_edge_attrs[in_input_node])
#rsg_to_repeat = pydot.Edge(src=rsg, dst=repeat)
#repeat_to_crd_rd_scan = pydot.Edge(src=repeat, dst=crd_rd_scanner)
#crd_rd_scan_to_val_rd_scan = pydot.Edge(src=crd_rd_scanner, dst=vals_rd_scanner)
in_crd_to_union = pydot.Edge(src=in_crd_node, dst=union,
**in_edge_attrs[in_crd_node], type="crd", comment=f"in-crd-B")
**in_edge_attrs[in_crd_node], type="crd", comment=f"in-B")
in_val_to_union = pydot.Edge(src=in_val_node, dst=union, **in_edge_attrs[in_val_node],
type="ref", comment=f"in-val-B vector_reduce_mode=true", val="true")
type="ref", comment=f"in-B", val="true", vector_reduce_mode=True)
# type="ref", comment=f"in-C", val="true")
crd_rd_scan_to_union = pydot.Edge(src=crd_rd_scanner, dst=union, type="crd", comment="in-x vector_reduce_mode=true")
val_rd_scan_to_union = pydot.Edge(src=vals_rd_scanner, dst=union, type="ref", comment="in-x vector_reduce_mode=true", val="true")
crd_rd_scan_to_union = pydot.Edge(src=crd_rd_scanner, dst=union, type="crd", comment="in-x", vector_reduce_mode=True)
val_rd_scan_to_union = pydot.Edge(src=vals_rd_scanner, dst=union, type="ref", comment="in-x", val="true", vector_reduce_mode=True)
union_crd_to_crd_wr_scan = pydot.Edge(src=union, dst=crd_wr_scanner, type="crd")
union_val0_to_alu = pydot.Edge(src=union, dst=add, comment='out-B')
# union_val0_to_alu = pydot.Edge(src=union, dst=add, comment='out-C')
Expand Down Expand Up @@ -301,11 +311,15 @@ def rewrite_VectorReducer(self):
self.graph.del_edge(crd_edge.get_source(), crd_edge.get_destination())
self.graph.del_edge(val_edge.get_source(), val_edge.get_destination())

print(crd_edge_attr)
print(val_edge_attr)
del crd_edge_attr['comment']

#crd_rd_scan_to_glb = pydot.Edge(src=crd_rd_scanner, dst=dst_crd, **crd_edge_attr, use_alt_out_port="1")
#val_rd_scan_to_glb = pydot.Edge(src=vals_rd_scanner, dst=dst_vals, **val_edge_attr, use_alt_out_port="1")

crd_rd_scan_to_ds = pydot.Edge(src=crd_rd_scanner, dst=dst_crd, **crd_edge_attr, comment="final-crd vector_reduce_mode=true")
val_rd_scan_to_ds = pydot.Edge(src=vals_rd_scanner, dst=dst_vals, **val_edge_attr, comment="final-val vector_reduce_mode=true")
crd_rd_scan_to_ds = pydot.Edge(src=crd_rd_scanner, dst=dst_crd, **crd_edge_attr, comment="final-crd", vector_reduce_mode=True)
val_rd_scan_to_ds = pydot.Edge(src=vals_rd_scanner, dst=dst_vals, **val_edge_attr, comment="final-val", vector_reduce_mode=True)

#self.graph.add_edge(input_to_rsg_edge)
#self.graph.add_edge(rsg_to_repeat)
Expand Down

0 comments on commit f181dd0

Please sign in to comment.