Skip to content

Commit

Permalink
Merge pull request #136 from weiya711/time_multiplexing
Browse files Browse the repository at this point in the history
Time multiplexing
  • Loading branch information
kalhankoul96 authored Jun 27, 2024
2 parents eaf80ee + b22d902 commit 78ff1ac
Show file tree
Hide file tree
Showing 10 changed files with 376 additions and 97 deletions.
8 changes: 6 additions & 2 deletions sam/onyx/generate_matrices.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,7 +489,7 @@ def create_matrix_from_point_list(name, pt_list, shape, use_fp=False) -> MatrixG
return mg


def convert_aha_glb_output_file(glbfile, output_dir, tiles, batches):
def convert_aha_glb_output_file(glbfile, output_dir, tiles, batches, glb_mem_stride=500):

glbfile_s = os.path.basename(glbfile).rstrip(".txt")

Expand Down Expand Up @@ -531,7 +531,9 @@ def convert_aha_glb_output_file(glbfile, output_dir, tiles, batches):
tile = 0
batch = 0
block = 0
cur_base = 0
for file_path in files:
assert sl_ptr < len(straightline)
num_items = straightline[sl_ptr]
sl_ptr += 1
with open(file_path, "w+") as fh_:
Expand All @@ -546,11 +548,13 @@ def convert_aha_glb_output_file(glbfile, output_dir, tiles, batches):
if block == num_blocks:
block = 0
tile += 1
sl_ptr = cur_base + tile * glb_mem_stride
if tile == tiles:
tile = 0
batch = batch + 1
# TODO hardcoded value for now
# TODO hardcoded value for now, need to consider a larger case
sl_ptr = 32768 * batch # size of glb
cur_base = 32768 * batch


def find_file_based_on_sub_string(files_dir, sub_string_list):
Expand Down
11 changes: 11 additions & 0 deletions sam/onyx/hw_nodes/fiberaccess_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ def connect(self, other, edge, kwargs=None):
from sam.onyx.hw_nodes.repsiggen_node import RepSigGenNode
from sam.onyx.hw_nodes.merge_node import MergeNode
from sam.onyx.hw_nodes.crdhold_node import CrdHoldNode
from sam.onyx.hw_nodes.stream_arbiter_node import StreamArbiterNode

new_conns = None
other_type = type(other)
Expand Down Expand Up @@ -219,6 +220,16 @@ def connect(self, other, edge, kwargs=None):
final_conns_2 = other.remap_conns(final_conns_1, kwargs['flavor_that'])
print(final_conns_2)
return final_conns_2
elif other_type == StreamArbiterNode:
assert kwargs is not None
assert 'flavor_this' in kwargs
this_flavor = self.get_flavor(kwargs['flavor_this'])
print(kwargs)
print("FIBER ACCESS TO Stream Arbiter")
init_conns = this_flavor.connect(other, edge)
print(init_conns)
final_conns = self.remap_conns(init_conns, kwargs['flavor_this'])
return final_conns
else:
raise NotImplementedError(f'Cannot connect FiberAccessNode to {other_type}')

Expand Down
15 changes: 14 additions & 1 deletion sam/onyx/hw_nodes/glb_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

class GLBNode(HWNode):
def __init__(self, name=None, data=None, valid=None, ready=None,
direction=None, num_blocks=None, file_number=None, tx_size=None, IO_id=0,
direction=None, num_blocks=None, seg_mode=None, file_number=None, tx_size=None, IO_id=0,
bespoke=False, tensor=None, mode=None, format=None) -> None:
super().__init__(name=name)

Expand All @@ -12,6 +12,7 @@ def __init__(self, name=None, data=None, valid=None, ready=None,
self.ready = ready
self.direction = direction
self.num_blocks = num_blocks
self.seg_mode = seg_mode
self.file_number = file_number
self.tx_size = tx_size
self.IO_id = IO_id
Expand All @@ -37,6 +38,9 @@ def get_tx_size(self):
def get_num_blocks(self):
return self.num_blocks

def get_seg_mode(self):
return self.seg_mode

def get_data(self):
return self.data

Expand Down Expand Up @@ -74,6 +78,7 @@ def connect(self, other, edge, kwargs=None):
from sam.onyx.hw_nodes.repsiggen_node import RepSigGenNode
from sam.onyx.hw_nodes.crdhold_node import CrdHoldNode
from sam.onyx.hw_nodes.fiberaccess_node import FiberAccessNode
from sam.onyx.hw_nodes.pass_through_node import PassThroughNode

other_type = type(other)

Expand Down Expand Up @@ -120,6 +125,14 @@ def connect(self, other, edge, kwargs=None):
print(conns_remapped)

return conns_remapped
elif other_type == PassThroughNode:
pass_through = other.get_name()
new_conns = {
'glb_to_pass_through': [
([(self.data, "io2f_17"), (pass_through, "stream_in")], 17),
]
}
return new_conns

else:
raise NotImplementedError(f'Cannot connect GLBNode to {other_type}')
Expand Down
2 changes: 2 additions & 0 deletions sam/onyx/hw_nodes/hw_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ class HWNodeType(Enum):
CrdHold = 14
VectorReducer = 15
FiberAccess = 16
StreamArbiter = 17
PassThrough = 18


class HWNode():
Expand Down
68 changes: 68 additions & 0 deletions sam/onyx/hw_nodes/pass_through_node.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from sam.onyx.hw_nodes.hw_node import *


class PassThroughNode(HWNode):
def __init__(self, name=None) -> None:
super().__init__(name=name)

def connect(self, other, edge, kwargs=None):

from sam.onyx.hw_nodes.broadcast_node import BroadcastNode
from sam.onyx.hw_nodes.compute_node import ComputeNode
from sam.onyx.hw_nodes.glb_node import GLBNode
from sam.onyx.hw_nodes.buffet_node import BuffetNode
from sam.onyx.hw_nodes.memory_node import MemoryNode
from sam.onyx.hw_nodes.read_scanner_node import ReadScannerNode
from sam.onyx.hw_nodes.write_scanner_node import WriteScannerNode
from sam.onyx.hw_nodes.intersect_node import IntersectNode
from sam.onyx.hw_nodes.reduce_node import ReduceNode
from sam.onyx.hw_nodes.lookup_node import LookupNode
from sam.onyx.hw_nodes.merge_node import MergeNode
from sam.onyx.hw_nodes.repeat_node import RepeatNode
from sam.onyx.hw_nodes.repsiggen_node import RepSigGenNode
from sam.onyx.hw_nodes.crdhold_node import CrdHoldNode
from sam.onyx.hw_nodes.fiberaccess_node import FiberAccessNode

new_conns = None
pass_through = self.get_name()

other_type = type(other)

if other_type == WriteScannerNode:
wr_scan = other.get_name()
new_conns = {
'pass_through_to_wr_scan': [
([(pass_through, "stream_out"), (wr_scan, "block_wr_in")], 17),
]
}
return new_conns
elif other_type == FiberAccessNode:
# Only could be using the write scanner portion of the fiber access
# fa = other.get_name()
conns_original = self.connect(other.get_write_scanner(), edge=edge)
print(conns_original)
conns_remapped = other.remap_conns(conns_original, "write_scanner")
print(conns_remapped)

return conns_remapped

else:
raise NotImplementedError(f'Cannot connect GLBNode to {other_type}')

return new_conns

def update_input_connections(self):
self.num_inputs_connected += 1

def get_num_inputs(self):
return self.num_inputs_connected

def configure(self, attributes):
# print("Pass Through CONFIGURE")
# print(attributes)

placeholder = 1
cfg_kwargs = {
'placeholder': placeholder
}
return (placeholder), cfg_kwargs
23 changes: 22 additions & 1 deletion sam/onyx/hw_nodes/read_scanner_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def connect(self, other, edge, kwargs=None):
from sam.onyx.hw_nodes.repeat_node import RepeatNode
from sam.onyx.hw_nodes.repsiggen_node import RepSigGenNode
from sam.onyx.hw_nodes.crdhold_node import CrdHoldNode
from sam.onyx.hw_nodes.stream_arbiter_node import StreamArbiterNode

new_conns = None
rd_scan = self.get_name()
Expand Down Expand Up @@ -255,6 +256,17 @@ def connect(self, other, edge, kwargs=None):
]
}

return new_conns
elif other_type == StreamArbiterNode:
cur_inputs = other.get_num_inputs()
assert cur_inputs <= other.max_num_inputs - 1, f"Cannot connect ReadScannerNode to {other_type}, too many inputs"
down_stream_arb = other.get_name()
new_conns = {
f'rd_scan_to_stream_arbiter_{cur_inputs}': [
([(rd_scan, "block_rd_out"), (down_stream_arb, f"stream_in_{cur_inputs}")], 17),
]
}
other.update_input_connections()
return new_conns
else:
raise NotImplementedError(f'Cannot connect ReadScannerNode to {other_type}')
Expand Down Expand Up @@ -313,6 +325,13 @@ def configure(self, attributes):
else:
vr_mode = 0

glb_addr_base = 0
glb_addr_stride = 0
if 'glb_addr_base' in attributes:
glb_addr_base = int(attributes['glb_addr_base'])
if 'glb_addr_stride' in attributes:
glb_addr_stride = int(attributes['glb_addr_stride'])

cfg_kwargs = {
'dense': dense,
'dim_size': dim_size,
Expand All @@ -328,7 +347,9 @@ def configure(self, attributes):
'block_mode': block_mode,
'lookup': lookup,
# 'spacc_mode': spacc_mode
'vr_mode': vr_mode
'vr_mode': vr_mode,
'glb_addr_base': glb_addr_base,
'glb_addr_stride': glb_addr_stride
}

return (inner_offset, max_outer_dim, strides, ranges, is_root, do_repeat,
Expand Down
80 changes: 80 additions & 0 deletions sam/onyx/hw_nodes/stream_arbiter_node.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
from sam.onyx.hw_nodes.hw_node import *


class StreamArbiterNode(HWNode):
def __init__(self, name=None) -> None:
super().__init__(name=name)
self.max_num_inputs = 4
self.num_inputs_connected = 0
self.num_outputs = 1
self.num_outputs_connected = 0

def connect(self, other, edge, kwargs=None):

from sam.onyx.hw_nodes.broadcast_node import BroadcastNode
from sam.onyx.hw_nodes.compute_node import ComputeNode
from sam.onyx.hw_nodes.glb_node import GLBNode
from sam.onyx.hw_nodes.buffet_node import BuffetNode
from sam.onyx.hw_nodes.memory_node import MemoryNode
from sam.onyx.hw_nodes.read_scanner_node import ReadScannerNode
from sam.onyx.hw_nodes.write_scanner_node import WriteScannerNode
from sam.onyx.hw_nodes.intersect_node import IntersectNode
from sam.onyx.hw_nodes.reduce_node import ReduceNode
from sam.onyx.hw_nodes.lookup_node import LookupNode
from sam.onyx.hw_nodes.merge_node import MergeNode
from sam.onyx.hw_nodes.repeat_node import RepeatNode
from sam.onyx.hw_nodes.repsiggen_node import RepSigGenNode
from sam.onyx.hw_nodes.crdhold_node import CrdHoldNode
from sam.onyx.hw_nodes.fiberaccess_node import FiberAccessNode

new_conns = None
stream_arb = self.get_name()

other_type = type(other)

if other_type == GLBNode:
other_data = other.get_data()
other_ready = other.get_ready()
other_valid = other.get_valid()
new_conns = {
'stream_arbiter_to_glb': [
([(stream_arb, "stream_out"), (other_data, "f2io_17")], 17),
]
}
return new_conns
elif other_type == StreamArbiterNode:
cur_inputs = other.get_num_inputs()
assert cur_inputs < self.max_num_inputs - 1, f"Cannot connect StreamArbiterNode to {other_type}, too many inputs"
down_stream_arb = other.get_name()
new_conns = {
f'stream_arbiter_to_stream_arbiter_{cur_inputs}': [
([(stream_arb, "stream_out"), (down_stream_arb, f"stream_in_{cur_inputs}")], 17),
]
}
other.update_input_connections()
return new_conns
else:
raise NotImplementedError(f'Cannot connect IntersectNode to {other_type}')

return new_conns

def update_input_connections(self):
self.num_inputs_connected += 1

def get_num_inputs(self):
return self.num_inputs_connected

def configure(self, attributes):
# print("STREAM ARBITER CONFIGURE")
# print(attributes)

seg_mode = attributes['seg_mode']
num_requests = self.num_inputs_connected
assert num_requests > 0, "StreamArbiterNode must have at least one input"
num_requests = num_requests - 1 # remap to the range of 0-3

cfg_kwargs = {
'num_requests': num_requests,
'seg_mode': seg_mode
}
return (num_requests, seg_mode), cfg_kwargs
7 changes: 6 additions & 1 deletion sam/onyx/hw_nodes/write_scanner_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,13 +106,18 @@ def configure(self, attributes):
else:
vr_mode = 0

stream_id = 0
if 'stream_id' in attributes:
stream_id = int(attributes['stream_id'])

cfg_tuple = (compressed, lowest_level, stop_lvl, block_mode, vr_mode, init_blank)
cfg_kwargs = {
'compressed': compressed,
'lowest_level': lowest_level,
'stop_lvl': stop_lvl,
'block_mode': block_mode,
'vr_mode': vr_mode,
'init_blank': init_blank
'init_blank': init_blank,
'stream_id': stream_id
}
return cfg_tuple, cfg_kwargs
Loading

0 comments on commit 78ff1ac

Please sign in to comment.