Skip to content

Commit

Permalink
Add VR SAM updates
Browse files Browse the repository at this point in the history
  • Loading branch information
mcoduoza committed Nov 16, 2023
1 parent cade108 commit 54c66b7
Show file tree
Hide file tree
Showing 10 changed files with 637 additions and 770 deletions.
12 changes: 10 additions & 2 deletions sam/onyx/hw_nodes/buffet_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,16 @@ def configure(self, attributes):
cap0 = kratos.clog2(capacity_0) - fetch_width_log
cap1 = kratos.clog2(capacity_1) - fetch_width_log

if 'vector_reduce_mode' in attributes:
is_in_vr_mode = attributes['vector_reduce_mode'].strip('"')
if is_in_vr_mode == "true":
vr_mode = 1
else:
vr_mode = 0

cfg_kwargs = {
'capacity_0': cap0,
'capacity_1': cap1
'capacity_1': cap1,
'vr_mode': vr_mode
}
return (capacity_0, capacity_1), cfg_kwargs
return (capacity_0, capacity_1, vr_mode), cfg_kwargs
14 changes: 9 additions & 5 deletions sam/onyx/hw_nodes/compute_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,16 @@ def connect(self, other, edge, kwargs=None):
pe = self.get_name()
# isect_conn = other.get_num_inputs()

if 'tensor' not in edge.get_attributes():
# Taking some liberties here - but technically this is the combo val
# isect_conn = other.get_connection_from_tensor('B')
isect_conn = other.get_connection_from_tensor('C')
if 'vector_reduce_mode' in edge.get_attributes():
if edge.get_attributes()['vector_reduce_mode'] == True:
isect_conn = 0
else:
isect_conn = other.get_connection_from_tensor(edge.get_tensor())
if 'tensor' not in edge.get_attributes():
# Taking some liberties here - but technically this is the combo val
# isect_conn = other.get_connection_from_tensor('B')
isect_conn = other.get_connection_from_tensor('C')
else:
isect_conn = other.get_connection_from_tensor(edge.get_tensor())

new_conns = {
f'pe_to_isect_{in_str}_{isect_conn}': [
Expand Down
9 changes: 6 additions & 3 deletions sam/onyx/hw_nodes/fiberaccess_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,9 +228,12 @@ def configure(self, attributes, flavor):

cfg_tuple, cfg_kwargs = self.get_flavor(flavor=flavor).configure(attributes)
cfg_kwargs['flavor'] = flavor
print("THESE ARE MY CONFIG KWARGS")
print(cfg_kwargs)
#breakpoint()

vr_mode = 0
cfg_tuple += (vr_mode,)
cfg_kwargs["vr_mode"] = vr_mode
#vr_mode = 0
#cfg_tuple += (vr_mode,)
#cfg_kwargs["vr_mode"] = vr_mode

return cfg_tuple, cfg_kwargs
2 changes: 1 addition & 1 deletion sam/onyx/hw_nodes/hw_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class HWNodeType(Enum):
Broadcast = 12
RepSigGen = 13
CrdHold = 14
SpAccumulator = 15
VectorReducer = 15
FiberAccess = 16


Expand Down
14 changes: 12 additions & 2 deletions sam/onyx/hw_nodes/intersect_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@ def connect(self, other, edge, kwargs=None):
print(edge.get_attributes())
edge_comment = edge.get_attributes()['comment'].strip('"')
tensor = edge_comment.split('-')[1]
print(self.tensor_to_conn)
out_conn = self.tensor_to_conn[tensor]
compute_conn = compute.get_num_inputs()
new_conns = {
Expand Down Expand Up @@ -248,6 +249,15 @@ def configure(self, attributes):
cmrg_enable = 0
cmrg_stop_lvl = 0
type_op = attributes['type'].strip('"')


if 'vector_reduce_mode' in attributes:
is_in_vr_mode = attributes['vector_reduce_mode'].strip('"')
if is_in_vr_mode == "true":
vr_mode = 1
else:
vr_mode = 0

if type_op == "intersect":
op = JoinerOp.INTERSECT.value
elif type_op == "union":
Expand All @@ -258,6 +268,6 @@ def configure(self, attributes):
'cmrg_enable': cmrg_enable,
'cmrg_stop_lvl': cmrg_stop_lvl,
'op': op,
'vr_mode': 0
'vr_mode': vr_mode
}
return (cmrg_enable, cmrg_stop_lvl, op, 0), cfg_kwargs
return (cmrg_enable, cmrg_stop_lvl, op, vr_mode), cfg_kwargs
11 changes: 10 additions & 1 deletion sam/onyx/hw_nodes/merge_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,16 @@ def connect(self, other, edge, kwargs=None):

return new_conns
elif other_type == IntersectNode:
raise NotImplementedError(f'Cannot connect MergeNode to {other_type}')
isect = other.get_name()
print("MERGE TO UNION FOR VECTOR REDUCE")
new_conns = {
f'merge_to_union_inner': [
([(merge, f"cmrg_coord_out_{0}"), (isect, f"coord_in_{0}")], 17),
]
}

return new_conns
#raise NotImplementedError(f'Cannot connect MergeNode to {other_type}')
elif other_type == ReduceNode:
# raise NotImplementedError(f'Cannot connect MergeNode to {other_type}')
other_red = other.get_name()
Expand Down
39 changes: 29 additions & 10 deletions sam/onyx/hw_nodes/read_scanner_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,9 @@ def connect(self, other, edge, kwargs=None):
edge_attr = edge.get_attributes()
if 'use_alt_out_port' in edge_attr:
out_conn = 'block_rd_out'
elif ('vector_reduce_mode' in edge_attr):
if (edge_attr['vector_reduce_mode'] == True):
out_conn = 'pos_out'
else:
out_conn = 'coord_out'

Expand All @@ -102,7 +105,13 @@ def connect(self, other, edge, kwargs=None):
elif other_type == IntersectNode:
# Send both....
isect = other.get_name()
isect_conn = other.get_connection_from_tensor(self.get_tensor())
if 'vector_reduce_mode' in edge.get_attributes():
if edge.get_attributes()['vector_reduce_mode'] == True:
isect_conn = 1
elif 'special' in edge.get_attributes():
isect_conn = 0
else:
isect_conn = other.get_connection_from_tensor(self.get_tensor())

e_attr = edge.get_attributes()
# isect_conn = 0
Expand Down Expand Up @@ -247,12 +256,12 @@ def configure(self, attributes):
dim_size = 1
stop_lvl = 0

if 'spacc' in attributes:
spacc_mode = 1
assert 'stop_lvl' in attributes
stop_lvl = int(attributes['stop_lvl'].strip('"'))
else:
spacc_mode = 0
#if 'spacc' in attributes:
# spacc_mode = 1
# assert 'stop_lvl' in attributes
# stop_lvl = int(attributes['stop_lvl'].strip('"'))
#else:
# spacc_mode = 0

# This is a fiberwrite's opposing read scanner for comms with GLB
if attributes['type'].strip('"') == 'fiberwrite':
Expand Down Expand Up @@ -283,6 +292,15 @@ def configure(self, attributes):
lookup = 0
block_mode = int(attributes['type'].strip('"') == 'fiberwrite')


if 'vector_reduce_mode' in attributes:
is_in_vr_mode = attributes['vector_reduce_mode'].strip('"')
if is_in_vr_mode == "true":
vr_mode = 1
else:
vr_mode = 0


cfg_kwargs = {
'dense': dense,
'dim_size': dim_size,
Expand All @@ -294,11 +312,12 @@ def configure(self, attributes):
'do_repeat': do_repeat,
'repeat_outer': repeat_outer,
'repeat_factor': repeat_factor,
'stop_lvl': stop_lvl,
#'stop_lvl': stop_lvl,
'block_mode': block_mode,
'lookup': lookup,
'spacc_mode': spacc_mode
#'spacc_mode': spacc_mode
'vr_mode': vr_mode
}

return (inner_offset, max_outer_dim, strides, ranges, is_root, do_repeat,
repeat_outer, repeat_factor, stop_lvl, block_mode, lookup, spacc_mode), cfg_kwargs
repeat_outer, repeat_factor, block_mode, lookup, vr_mode), cfg_kwargs
31 changes: 20 additions & 11 deletions sam/onyx/hw_nodes/write_scanner_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ def connect(self, other, edge, kwargs=None):
def configure(self, attributes):

stop_lvl = 0
init_blank = 0

# compressed = int(attributes['format'] == 'compressed')
if 'format' in attributes and 'vals' in attributes['format'].strip('"'):
Expand All @@ -89,14 +90,14 @@ def configure(self, attributes):
else:
compressed = 1

if 'spacc' in attributes:
spacc_mode = 1
init_blank = 1
assert 'stop_lvl' in attributes
stop_lvl = int(attributes['stop_lvl'].strip('"'))
else:
spacc_mode = 0
init_blank = 0
#if 'spacc' in attributes:
# spacc_mode = 1
# init_blank = 1
# assert 'stop_lvl' in attributes
# stop_lvl = int(attributes['stop_lvl'].strip('"'))
#else:
# spacc_mode = 0
# init_blank = 0

# compressed = int(attributes['format'] == 'compressed')
if attributes['type'].strip('"') == 'arrayvals':
Expand All @@ -111,17 +112,25 @@ def configure(self, attributes):
block_mode = 1
else:
block_mode = 0

if 'vector_reduce_mode' in attributes:
is_in_vr_mode = attributes['vector_reduce_mode'].strip('"')
if is_in_vr_mode == "true":
vr_mode = 1
else:
vr_mode = 0

# block_mode = int(attributes['type'].strip('"') == 'fiberlookup')
# cfg_tuple = (inner_offset, compressed, lowest_level, stop_lvl, block_mode)
cfg_tuple = (compressed, lowest_level, stop_lvl, block_mode, init_blank, spacc_mode)
cfg_tuple = (compressed, lowest_level, stop_lvl, block_mode, vr_mode, init_blank)
cfg_kwargs = {
# 'inner_offset': inner_offset,
'compressed': compressed,
'lowest_level': lowest_level,
'stop_lvl': stop_lvl,
'block_mode': block_mode,
'init_blank': init_blank,
'spacc_mode': spacc_mode
'vr_mode': vr_mode,
'init_blank': init_blank
#'spacc_mode': spacc_mode
}
return cfg_tuple, cfg_kwargs
Loading

0 comments on commit 54c66b7

Please sign in to comment.