Skip to content

Commit

Permalink
Merge branch 'mapping_to_cgra' into regression_refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
bobcheng15 committed Jul 22, 2024
2 parents da04fa3 + a5d2392 commit 115014e
Show file tree
Hide file tree
Showing 6 changed files with 271 additions and 11 deletions.
11 changes: 11 additions & 0 deletions sam/onyx/hw_nodes/fiberaccess_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ def connect(self, other, edge, kwargs=None):
from sam.onyx.hw_nodes.merge_node import MergeNode
from sam.onyx.hw_nodes.crdhold_node import CrdHoldNode
from sam.onyx.hw_nodes.stream_arbiter_node import StreamArbiterNode
from sam.onyx.hw_nodes.pass_through_node import PassThroughNode

new_conns = None
other_type = type(other)
Expand Down Expand Up @@ -230,6 +231,16 @@ def connect(self, other, edge, kwargs=None):
print(init_conns)
final_conns = self.remap_conns(init_conns, kwargs['flavor_this'])
return final_conns
elif other_type == PassThroughNode:
assert kwargs is not None
assert 'flavor_this' in kwargs
this_flavor = self.get_flavor(kwargs['flavor_this'])
print(kwargs)
print("FIBER ACCESS TO Pass Through")
init_conns = this_flavor.connect(other, edge)
print(init_conns)
final_conns = self.remap_conns(init_conns, kwargs['flavor_this'])
return final_conns
else:
raise NotImplementedError(f'Cannot connect FiberAccessNode to {other_type}')

Expand Down
26 changes: 26 additions & 0 deletions sam/onyx/hw_nodes/intersect_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def connect(self, other, edge, kwargs=None):
from sam.onyx.hw_nodes.repsiggen_node import RepSigGenNode
from sam.onyx.hw_nodes.crdhold_node import CrdHoldNode
from sam.onyx.hw_nodes.fiberaccess_node import FiberAccessNode
from sam.onyx.hw_nodes.pass_through_node import PassThroughNode

new_conns = None
isect = self.get_name()
Expand Down Expand Up @@ -221,7 +222,32 @@ def connect(self, other, edge, kwargs=None):
print(init_conns)
final_conns = other.remap_conns(init_conns, kwargs['flavor_that'])
return final_conns
elif other_type == PassThroughNode:
pass_through = other.get_name()
comment = edge.get_attributes()['comment'].strip('"')
try:
tensor = comment.split("-")[1]
except Exception:
try:
tensor = comment.split("_")[1]
except Exception:
tensor = comment
edge_type = edge.get_attributes()['type'].strip('"')

if 'crd' in edge_type:
new_conns = {
f'isect_to_isect': [
([(isect, f"coord_out"), (pass_through, "stream_in")], 17),
]
}
elif 'ref' in edge_type:
isect_conn = self.get_connection_from_tensor(tensor)
new_conns = {
f'isect_to_isect': [
([(isect, f"pos_out_{isect_conn}"), (pass_through, "stream_in")], 17),
]
}
return new_conns
else:
raise NotImplementedError(f'Cannot connect IntersectNode to {other_type}')

Expand Down
20 changes: 20 additions & 0 deletions sam/onyx/hw_nodes/merge_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def connect(self, other, edge, kwargs=None):
from sam.onyx.hw_nodes.repsiggen_node import RepSigGenNode
from sam.onyx.hw_nodes.crdhold_node import CrdHoldNode
from sam.onyx.hw_nodes.fiberaccess_node import FiberAccessNode
from sam.onyx.hw_nodes.pass_through_node import PassThroughNode

new_conns = None
other_type = type(other)
Expand Down Expand Up @@ -106,6 +107,25 @@ def connect(self, other, edge, kwargs=None):
([(merge, f"coord_out_{out_conn}"), (other_merge, f"coord_in_{in_conn}")], 17),
]
}
return new_conns
elif other_type == PassThroughNode:
pass_through = other.get_name()
# Use inner to process outer
comment = edge.get_attributes()['comment'].strip('"')
tensor_lvl = None
if self.get_inner() in comment:
out_conn = 0
tensor_lvl = self.get_inner()
else:
out_conn = 1
tensor_lvl = self.get_outer()

new_conns = {
f'merger_to_merger_{out_conn}_to_pass_through': [
([(merge, f"coord_out_{out_conn}"), (pass_through, "stream_in")], 17),
]
}
return new_conns

elif other_type == RepeatNode:
raise NotImplementedError(f'Cannot connect MergeNode to {other_type}')
Expand Down
121 changes: 112 additions & 9 deletions sam/onyx/hw_nodes/pass_through_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,15 @@


class PassThroughNode(HWNode):
def __init__(self, name=None) -> None:
def __init__(self, name=None, conn_to_tensor=None) -> None:
super().__init__(name=name)

self.conn_to_tensor = conn_to_tensor
self.tensor_to_conn = {}
if conn_to_tensor is not None:
for conn, tensor in self.conn_to_tensor.items():
self.tensor_to_conn[tensor] = conn

def connect(self, other, edge, kwargs=None):

from sam.onyx.hw_nodes.broadcast_node import BroadcastNode
Expand All @@ -27,6 +33,7 @@ def connect(self, other, edge, kwargs=None):
pass_through = self.get_name()

other_type = type(other)
print(other_type)

if other_type == WriteScannerNode:
wr_scan = other.get_name()
Expand All @@ -36,21 +43,117 @@ def connect(self, other, edge, kwargs=None):
]
}
return new_conns
elif other_type == ReadScannerNode:
print("PASSTHORUGH TO REPEAT EDGE!")
rd_scan = other.get_name()
new_conns = {
'pass_through_to_rd_scan': [
([(pass_through, "stream_out"), (rd_scan, f"us_pos_in")], 17),
]
}
return new_conns
elif other_type == RepeatNode:
repeat = other.get_name()
print("PASSTHROUGH TO REPEAT EDGE!")
new_conns = {
'pass_through_to_repeat': [
# send output to rd scanner
([(pass_through, "stream_out"), (repeat, "proc_data_in")], 17),
]
}
return new_conns
elif other_type == IntersectNode:
comment = edge.get_attributes()['comment'].strip('"')
try:
tensor = comment.split("-")[1]
except Exception:
try:
tensor = comment.split("_")[1]
except Exception:
tensor = comment

other_isect = other.get_name()
isect_conn = self.get_connection_from_tensor(tensor)
other_isect_conn = other.get_connection_from_tensor(tensor)

edge_type = edge.get_attributes()['type'].strip('"')

if 'crd' in edge_type:
new_conns = {
f'pass_through_to_isect': [
([(pass_through, "stream_out"), (other_isect, f"coord_in_{other_isect_conn}")], 17),
]
}
elif 'ref' in edge_type:
new_conns = {
f'pass_through_to_isect': [
([(pass_through, "stream_out"), (other_isect, f"pos_in_{other_isect_conn}")], 17),
]
}
return new_conns

elif other_type == MergeNode:
edge_attr = edge.get_attributes()
crddrop = other.get_name()
print("CHECKING READ TENSOR - CRDDROP")
print(edge)
crd_drop_outer = other.get_outer()
comment = edge_attr['comment'].strip('"')
conn = 0
# okay this is dumb, stopgap until we can have super consistent output
try:
mapped_to_conn = comment.split("-")[1]
except Exception:
try:
mapped_to_conn = comment.split("_")[1]
except Exception:
mapped_to_conn = comment
if crd_drop_outer in mapped_to_conn:
conn = 1

if 'use_alt_out_port' in edge_attr:
out_conn = 'block_rd_out'
elif ('vector_reduce_mode' in edge_attr):
if (edge_attr['vector_reduce_mode']):
out_conn = 'pos_out'
else:
out_conn = 'coord_out'

new_conns = {
f'rd_scan_to_crddrop_{conn}': [
([(pass_through, "stream_out"), (crddrop, f"coord_in_{conn}")], 17),
]
}

return new_conns
elif other_type == RepSigGenNode:
rsg = other.get_name()
new_conns = {
f'pass_through_to_rsg': [
([(pass_through, "stream_out"), (rsg, f"base_data_in")], 17),
]
}
elif other_type == FiberAccessNode:
# Only could be using the write scanner portion of the fiber access
# fa = other.get_name()
conns_original = self.connect(other.get_write_scanner(), edge=edge)
print(conns_original)
conns_remapped = other.remap_conns(conns_original, "write_scanner")
print(conns_remapped)

return conns_remapped
print("PASSTHROUGH TO FIBER ACCESS")
assert kwargs is not None
assert 'flavor_that' in kwargs
that_flavor = other.get_flavor(kwargs['flavor_that'])
print(kwargs)
init_conns = self.connect(that_flavor, edge)
print(init_conns)
final_conns = other.remap_conns(init_conns, kwargs['flavor_that'])
return final_conns

else:
raise NotImplementedError(f'Cannot connect GLBNode to {other_type}')
raise NotImplementedError(f'Cannot connect Pass Through Node to {other_type}')

return new_conns

def get_connection_from_tensor(self, tensor):
print(self.tensor_to_conn)
return self.tensor_to_conn[tensor]

def update_input_connections(self):
self.num_inputs_connected += 1

Expand Down
23 changes: 23 additions & 0 deletions sam/onyx/hw_nodes/read_scanner_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def connect(self, other, edge, kwargs=None):
from sam.onyx.hw_nodes.repsiggen_node import RepSigGenNode
from sam.onyx.hw_nodes.crdhold_node import CrdHoldNode
from sam.onyx.hw_nodes.stream_arbiter_node import StreamArbiterNode
from sam.onyx.hw_nodes.pass_through_node import PassThroughNode

new_conns = None
rd_scan = self.get_name()
Expand Down Expand Up @@ -268,6 +269,28 @@ def connect(self, other, edge, kwargs=None):
}
other.update_input_connections()
return new_conns
elif other_type == PassThroughNode:
pass_through = other.get_name()
e_attr = edge.get_attributes()
e_type = e_attr['type'].strip('"')
if "crd" in e_type:
new_conns = {
f'rd_scan_to_pass_through_crd': [
# send output to rd scanner
([(rd_scan, "coord_out"), (pass_through, "stream_in")], 17),
]
}
elif 'ref' in e_type:
rd_scan_out_port = "pos_out"
if 'val' in e_attr and e_attr['val'].strip('"') == 'true':
rd_scan_out_port = "coord_out"
new_conns = {
f'rd_scan_to_pass_through_pos': [
# send output to rd scanner
([(rd_scan, rd_scan_out_port), (pass_through, "stream_in")], 17),
]
}
return new_conns
else:
raise NotImplementedError(f'Cannot connect ReadScannerNode to {other_type}')

Expand Down
Loading

0 comments on commit 115014e

Please sign in to comment.