Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

initial commit, hooking up logic #138

Merged
merged 5 commits into from
Jul 20, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions sam/onyx/hw_nodes/fiberaccess_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ def connect(self, other, edge, kwargs=None):
from sam.onyx.hw_nodes.merge_node import MergeNode
from sam.onyx.hw_nodes.crdhold_node import CrdHoldNode
from sam.onyx.hw_nodes.stream_arbiter_node import StreamArbiterNode
from sam.onyx.hw_nodes.pass_through_node import PassThroughNode

new_conns = None
other_type = type(other)
Expand Down Expand Up @@ -230,6 +231,16 @@ def connect(self, other, edge, kwargs=None):
print(init_conns)
final_conns = self.remap_conns(init_conns, kwargs['flavor_this'])
return final_conns
elif other_type == PassThroughNode:
assert kwargs is not None
assert 'flavor_this' in kwargs
this_flavor = self.get_flavor(kwargs['flavor_this'])
print(kwargs)
print("FIBER ACCESS TO Pass Through")
init_conns = this_flavor.connect(other, edge)
print(init_conns)
final_conns = self.remap_conns(init_conns, kwargs['flavor_this'])
return final_conns
else:
raise NotImplementedError(f'Cannot connect FiberAccessNode to {other_type}')

Expand Down
34 changes: 34 additions & 0 deletions sam/onyx/hw_nodes/intersect_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def connect(self, other, edge, kwargs=None):
from sam.onyx.hw_nodes.repsiggen_node import RepSigGenNode
from sam.onyx.hw_nodes.crdhold_node import CrdHoldNode
from sam.onyx.hw_nodes.fiberaccess_node import FiberAccessNode
from sam.onyx.hw_nodes.pass_through_node import PassThroughNode

new_conns = None
isect = self.get_name()
Expand Down Expand Up @@ -221,7 +222,40 @@ def connect(self, other, edge, kwargs=None):
print(init_conns)
final_conns = other.remap_conns(init_conns, kwargs['flavor_that'])
return final_conns
elif other_type == PassThroughNode:
pass_through = other.get_name()
comment = edge.get_attributes()['comment'].strip('"')
try:
tensor = comment.split("-")[1]
except Exception:
try:
tensor = comment.split("_")[1]
except Exception:
tensor = comment
edge_type = edge.get_attributes()['type'].strip('"')

if 'crd' in edge_type:
new_conns = {
f'isect_to_isect': [
# send output to rd scanner
([(isect, f"coord_out"), (pass_through, "stream_in")], 17),
# ([(isect, f"eos_out_0"), (wr_scan, f"eos_in_0")], 1),
# ([(wr_scan, f"ready_out_0"), (isect, f"ready_in_0")], 1),
# ([(isect, f"valid_out_0"), (wr_scan, f"valid_in_0")], 1),
]
}
elif 'ref' in edge_type:
isect_conn = self.get_connection_from_tensor(tensor)
new_conns = {
f'isect_to_isect': [
# send output to rd scanner
([(isect, f"pos_out_{isect_conn}"), (pass_through, "stream_in")], 17),
# ([(isect, f"eos_out_0"), (wr_scan, f"eos_in_0")], 1),
# ([(wr_scan, f"ready_out_0"), (isect, f"ready_in_0")], 1),
# ([(isect, f"valid_out_0"), (wr_scan, f"valid_in_0")], 1),
]
}
return new_conns
else:
raise NotImplementedError(f'Cannot connect IntersectNode to {other_type}')

Expand Down
20 changes: 20 additions & 0 deletions sam/onyx/hw_nodes/merge_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def connect(self, other, edge, kwargs=None):
from sam.onyx.hw_nodes.repsiggen_node import RepSigGenNode
from sam.onyx.hw_nodes.crdhold_node import CrdHoldNode
from sam.onyx.hw_nodes.fiberaccess_node import FiberAccessNode
from sam.onyx.hw_nodes.pass_through_node import PassThroughNode

new_conns = None
other_type = type(other)
Expand Down Expand Up @@ -106,6 +107,25 @@ def connect(self, other, edge, kwargs=None):
([(merge, f"coord_out_{out_conn}"), (other_merge, f"coord_in_{in_conn}")], 17),
]
}
return new_conns
elif other_type == PassThroughNode:
pass_through = other.get_name()
# Use inner to process outer
comment = edge.get_attributes()['comment'].strip('"')
tensor_lvl = None
if self.get_inner() in comment:
out_conn = 0
tensor_lvl = self.get_inner()
else:
out_conn = 1
tensor_lvl = self.get_outer()

new_conns = {
f'merger_to_merger_{out_conn}_to_pass_through': [
([(merge, f"coord_out_{out_conn}"), (pass_through, "stream_in")], 17),
]
}
return new_conns

elif other_type == RepeatNode:
raise NotImplementedError(f'Cannot connect MergeNode to {other_type}')
Expand Down
121 changes: 112 additions & 9 deletions sam/onyx/hw_nodes/pass_through_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,15 @@


class PassThroughNode(HWNode):
def __init__(self, name=None) -> None:
def __init__(self, name=None, conn_to_tensor=None) -> None:
super().__init__(name=name)

self.conn_to_tensor = conn_to_tensor
self.tensor_to_conn = {}
if conn_to_tensor is not None:
for conn, tensor in self.conn_to_tensor.items():
self.tensor_to_conn[tensor] = conn

def connect(self, other, edge, kwargs=None):

from sam.onyx.hw_nodes.broadcast_node import BroadcastNode
Expand All @@ -27,6 +33,7 @@ def connect(self, other, edge, kwargs=None):
pass_through = self.get_name()

other_type = type(other)
print(other_type)

if other_type == WriteScannerNode:
wr_scan = other.get_name()
Expand All @@ -36,21 +43,117 @@ def connect(self, other, edge, kwargs=None):
]
}
return new_conns
elif other_type == ReadScannerNode:
print("PASSTHORUGH TO REPEAT EDGE!")
rd_scan = other.get_name()
new_conns = {
'pass_through_to_rd_scan': [
([(pass_through, "stream_out"), (rd_scan, f"us_pos_in")], 17),
]
}
return new_conns
elif other_type == RepeatNode:
repeat = other.get_name()
print("PASSTHROUGH TO REPEAT EDGE!")
new_conns = {
'pass_through_to_repeat': [
# send output to rd scanner
([(pass_through, "stream_out"), (repeat, "proc_data_in")], 17),
]
}
return new_conns
elif other_type == IntersectNode:
comment = edge.get_attributes()['comment'].strip('"')
try:
tensor = comment.split("-")[1]
except Exception:
try:
tensor = comment.split("_")[1]
except Exception:
tensor = comment

other_isect = other.get_name()
isect_conn = self.get_connection_from_tensor(tensor)
other_isect_conn = other.get_connection_from_tensor(tensor)

edge_type = edge.get_attributes()['type'].strip('"')

if 'crd' in edge_type:
new_conns = {
f'pass_through_to_isect': [
([(pass_through, "stream_out"), (other_isect, f"coord_in_{other_isect_conn}")], 17),
]
}
elif 'ref' in edge_type:
new_conns = {
f'pass_through_to_isect': [
([(pass_through, "stream_out"), (other_isect, f"pos_in_{other_isect_conn}")], 17),
]
}
return new_conns

elif other_type == MergeNode:
edge_attr = edge.get_attributes()
crddrop = other.get_name()
print("CHECKING READ TENSOR - CRDDROP")
print(edge)
crd_drop_outer = other.get_outer()
comment = edge_attr['comment'].strip('"')
conn = 0
# okay this is dumb, stopgap until we can have super consistent output
try:
mapped_to_conn = comment.split("-")[1]
except Exception:
try:
mapped_to_conn = comment.split("_")[1]
except Exception:
mapped_to_conn = comment
if crd_drop_outer in mapped_to_conn:
conn = 1

if 'use_alt_out_port' in edge_attr:
out_conn = 'block_rd_out'
elif ('vector_reduce_mode' in edge_attr):
if (edge_attr['vector_reduce_mode']):
out_conn = 'pos_out'
else:
out_conn = 'coord_out'

new_conns = {
f'rd_scan_to_crddrop_{conn}': [
([(pass_through, "stream_out"), (crddrop, f"coord_in_{conn}")], 17),
]
}

return new_conns
elif other_type == RepSigGenNode:
rsg = other.get_name()
new_conns = {
f'pass_through_to_rsg': [
([(pass_through, "stream_out"), (rsg, f"base_data_in")], 17),
]
}
elif other_type == FiberAccessNode:
# Only could be using the write scanner portion of the fiber access
# fa = other.get_name()
conns_original = self.connect(other.get_write_scanner(), edge=edge)
print(conns_original)
conns_remapped = other.remap_conns(conns_original, "write_scanner")
print(conns_remapped)

return conns_remapped
print("PASSTHROUGH TO FIBER ACCESS")
assert kwargs is not None
assert 'flavor_that' in kwargs
that_flavor = other.get_flavor(kwargs['flavor_that'])
print(kwargs)
init_conns = self.connect(that_flavor, edge)
print(init_conns)
final_conns = other.remap_conns(init_conns, kwargs['flavor_that'])
return final_conns

else:
raise NotImplementedError(f'Cannot connect GLBNode to {other_type}')
raise NotImplementedError(f'Cannot connect Pass Through Node to {other_type}')

return new_conns

def get_connection_from_tensor(self, tensor):
print(self.tensor_to_conn)
return self.tensor_to_conn[tensor]

def update_input_connections(self):
self.num_inputs_connected += 1

Expand Down
23 changes: 23 additions & 0 deletions sam/onyx/hw_nodes/read_scanner_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def connect(self, other, edge, kwargs=None):
from sam.onyx.hw_nodes.repsiggen_node import RepSigGenNode
from sam.onyx.hw_nodes.crdhold_node import CrdHoldNode
from sam.onyx.hw_nodes.stream_arbiter_node import StreamArbiterNode
from sam.onyx.hw_nodes.pass_through_node import PassThroughNode

new_conns = None
rd_scan = self.get_name()
Expand Down Expand Up @@ -268,6 +269,28 @@ def connect(self, other, edge, kwargs=None):
}
other.update_input_connections()
return new_conns
elif other_type == PassThroughNode:
pass_through = other.get_name()
e_attr = edge.get_attributes()
e_type = e_attr['type'].strip('"')
if "crd" in e_type:
new_conns = {
f'rd_scan_to_pass_through_crd': [
# send output to rd scanner
([(rd_scan, "coord_out"), (pass_through, "stream_in")], 17),
]
}
elif 'ref' in e_type:
rd_scan_out_port = "pos_out"
if 'val' in e_attr and e_attr['val'].strip('"') == 'true':
rd_scan_out_port = "coord_out"
new_conns = {
f'rd_scan_to_pass_through_pos': [
# send output to rd scanner
([(rd_scan, rd_scan_out_port), (pass_through, "stream_in")], 17),
]
}
return new_conns
else:
raise NotImplementedError(f'Cannot connect ReadScannerNode to {other_type}')

Expand Down
Loading
Loading