Skip to content

Commit

Permalink
Merge pull request #96 from weiya711/mapping_to_cgra_opal
Browse files Browse the repository at this point in the history
Update mapping code and sam simulation model for compression
  • Loading branch information
kalhankoul96 authored Oct 18, 2023
2 parents a50f69d + 5dd14cd commit aa3d1ed
Show file tree
Hide file tree
Showing 4 changed files with 222 additions and 8 deletions.
10 changes: 7 additions & 3 deletions sam/onyx/hw_nodes/merge_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@


class MergeNode(HWNode):
def __init__(self, name=None, outer=None, inner=None) -> None:
def __init__(self, name=None, outer=None, inner=None, mode=None) -> None:
super().__init__(name=name)
self.outer = outer
self.inner = inner
self.mode = mode

def get_outer(self):
return self.outer
Expand Down Expand Up @@ -127,9 +128,12 @@ def configure(self, attributes):
# TODO what is this supposed to be?
cmrg_stop_lvl = 1
op = 0
# 0 for compression, 1 for crddrop
cmrg_mode = self.mode
cfg_kwargs = {
'cmrg_enable': cmrg_enable,
'cmrg_stop_lvl': cmrg_stop_lvl,
'op': op
'op': op,
'cmrg_mode': cmrg_mode
}
return (cmrg_enable, cmrg_stop_lvl, op), cfg_kwargs
return (cmrg_enable, cmrg_stop_lvl, op, cmrg_mode), cfg_kwargs
24 changes: 24 additions & 0 deletions sam/sim/src/base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
from abc import ABC, abstractmethod
import numpy as np
import warnings

# warnings.simplefilter(action='ignore', category=FutureWarning)


def gen_stkns(dim=10):
Expand All @@ -24,6 +28,10 @@ def is_valid_crdpt(elem):
return isinstance(elem, int) or elem in valid_tkns


def is_valid_num(elem, dim=10):
return isinstance(elem, int) or isinstance(elem, float)


def is_valid_val(elem, dim=10):
valid_tkns = ['', 'D'] + gen_stkns(dim)
return isinstance(elem, int) or isinstance(elem, float) or elem in valid_tkns
Expand All @@ -33,12 +41,23 @@ def is_0tkn(elem):
return elem == 'N'


# Checks if a token is a non-control (numerical) token
def is_nc_tkn(elem, datatype=int):
return isinstance(elem, datatype)


def is_stkn(elem):
if isinstance(elem, str):
return elem.startswith('S') and (len(elem) == 2)
return False


def is_dtkn(elem):
if isinstance(elem, str):
return elem == 'D'
return False


def stkn_order(elem):
assert is_stkn(elem)
return int(elem[1])
Expand Down Expand Up @@ -89,6 +108,11 @@ def is_debug(self):
def update(self):
pass

# Check the input token of something
def valid_token(self, element, datatype=int):
return element != "" and element is not None and \
(is_dtkn(element) or is_stkn(element) or is_nc_tkn(element, datatype) or is_0tkn(element))

def reset(self):
self.done = False

Expand Down
111 changes: 106 additions & 5 deletions sam/sim/src/compression.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,26 @@
from .base import *
from .token import EmptyFiberStknDrop, StknDrop


class Compression(Primitive):
def __init__(self, **kwargs):
class ValDropper(Primitive):
def __init__(self, drop_refs=False, **kwargs):
super().__init__(**kwargs)

self.in_val = []
self.in_crd = []
self.in_ref = []

self.curr_crd = ''
self.curr_ref = ''
self.curr_val = ''
self.out_crds = ''
self.out_refs = ''
self.out_vals = ''
self.val_stkn_dropper = EmptyFiberStknDrop()
self.crd_stkn_dropper = EmptyFiberStknDrop()
self.ref_stkn_dropper = EmptyFiberStknDrop()
self.drop_refs = drop_refs

if self.backpressure_en:
self.ready_backpressure = True
self.data_valid = True
Expand Down Expand Up @@ -49,13 +61,88 @@ def update(self):
if len(self.in_val) > 0 or len(self.in_crd) > 0:
self.block_start = False

icrd = ""
icrd = ''
ival = ''

if self.done:
self.curr_crd = ''
self.curr_val = ''
self.out_crds = ''
self.out_vals = ''
if self.drop_refs:
self.curr_ref = ''
self.out_refs = ''
return
elif (len(self.in_val) > 0 and len(self.in_crd) == 0) \
or (len(self.in_crd) > 0 and len(self.in_val) == 0) \
or (len(self.in_val) == 0 and len(self.in_crd) == 0):
self.out_crds = ''
self.out_vals = ''
if self.drop_refs:
self.out_refs = ''
elif len(self.in_val) > 0 and len(self.in_crd) > 0:
icrd = self
ival = self.in_val.pop(0)
icrd = self.in_crd.pop(0)
iref = ''
if self.drop_refs:
iref = self.in_ref.pop(0)

# print("ival:", ival)
# print("icrd:", icrd)

assert ival != '', "ival is an empty str"

if is_valid_num(ival):
# assert isinstance(icrd, int), "both val and crd need ot match"
if not isinstance(icrd, int) and icrd != 'N':
print("Both val and icrd need to match")
print(icrd, ival)
exit(1)
if ival == 0.0:
self.curr_crd = ''
self.curr_ref = ''
self.curr_val = ''
else:
self.curr_crd = icrd
self.curr_val = ival
if self.drop_refs:
self.curr_ref = iref
elif isinstance(ival, str) and ival != 'D':
assert isinstance(icrd, str), "both val and coord need to match"
self.curr_crd = icrd
self.curr_val = ival
if self.drop_refs:
self.curr_ref = iref
elif ival == 'D':
assert icrd == 'D'
self.curr_val = ival
self.curr_crd = icrd
if self.drop_refs:
self.curr_ref = iref
self.done = True
else:
self.curr_crd = icrd
self.curr_val = ival
if self.drop_refs:
self.curr_ref = iref

# if self.curr_crd == self.out_crds or icrd == self.out_crds:
# self.out_crds = ''
# self.out_vals = ''
# self.out_refs = ''
# else:
self.val_stkn_dropper.set_in_stream(self.curr_val)
self.crd_stkn_dropper.set_in_stream(self.curr_crd)
self.val_stkn_dropper.update()
self.crd_stkn_dropper.update()
# self.out_crds = self.crd_stkn_dropper.out_val()
# self.out_vals = self.val_stkn_dropper.out_val()
self.out_crds = self.curr_crd
self.out_vals = self.curr_val
if self.drop_refs:
self.ref_stkn_dropper.set_in_stream(self.curr_ref)
self.ref_stkn_dropper.update()
self.out_refs = self.curr_ref

if self.debug:
print("Curr OuterCrd:", self.curr_ocrd, "\tCurr InnerCrd:", icrd, "\t Curr OutputCrd:", self.curr_crd,
Expand All @@ -74,6 +161,20 @@ def set_crd(self, crd, parent=None):
if self.backpressure_en:
parent.set_backpressure(self.fifo_avail_crd)

def set_ref(self, ref, parent=None):
if ref != '' and ref is not None:
self.in_ref.append(ref)
if self.backpressure_en:
parent.set_backpressure(self.fifo_avail_crd)

def out_crd(self):
if (self.backpressure_en and self.data_valid) or not self.backpressure_en:
return self.curr_crd
return self.out_crds

def out_ref(self):
if (self.backpressure_en and self.data_valid) or not self.backpressure_en:
return self.out_refs

def out_val(self):
if (self.backpressure_en and self.data_valid) or not self.backpressure_en:
return self.out_vals
85 changes: 85 additions & 0 deletions sam/sim/test/primitives/test_compression.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import pytest
from sam.sim.src.compression import ValDropper
from sam.sim.src.base import remove_emptystr
from sam.sim.test.test import TIMEOUT
import numpy as np


@pytest.mark.parametrize("dim1", [8])
def test_compress_1d(dim1, debug_sim):
nums = np.random.choice([0, 1], size=dim1, p=[.4, .6])
in1 = nums.tolist() + ['S0', 'D']
crd_nums = np.arange(dim1)
crd = crd_nums.tolist() + ['S0', 'D']
# assert (len(in1) == len(in1))

gold_val = nums[nums != 0].tolist() + ['S0', 'D']
gold_crd = np.delete(crd_nums, np.where(nums == 0)).tolist() + ['S0', 'D']

comp = ValDropper(debug=debug_sim)

done = False
time = 0

out_val = []
out_crd = []

while not done and time < TIMEOUT:

if len(in1) > 0:
comp.set_val(in1.pop(0))
comp.set_crd(crd.pop(0))

comp.update()
out_val.append(comp.out_val())
out_crd.append(comp.out_crd())

if debug_sim:
print("Timestep", time, "\t Out:", comp.out_val())

done = comp.out_done()
time += 1

out_val = remove_emptystr(out_val)
out_crd = remove_emptystr(out_crd)
print("Ref val:", gold_val)
print("Out val:", out_val)

print("Ref crd:", gold_crd)
print("Out crd:", out_crd)

assert (out_val == gold_val)
assert (out_crd == gold_crd)

# @pytest.mark.parametrize("dim1", [4, 16, 32, 64])
# def test_exp_1d(dim1, debug_sim):
# in1 = [x for x in range(dim1)] + ['S0', 'D']
# in2 = None
# # assert (len(in1) == len(in1))

# gold_val = np.exp(np.arange(dim1)).tolist() + ['S0', 'D']

# exp1 = Exp(debug=debug_sim)

# done = False
# time = 0
# out_val = []
# exp1.set_in2(in2)
# while not done and time < TIMEOUT:
# if len(in1) > 0:
# exp1.set_in1(in1.pop(0))

# exp1.update()

# out_val.append(exp1.out_val())

# print("Timestep", time, "\t Out:", exp1.out_val())

# done = exp1.out_done()
# time += 1

# out_val = remove_emptystr(out_val)
# print("Ref:", gold_val)
# print("Out:", out_val)

# assert (out_val == gold_val)

0 comments on commit aa3d1ed

Please sign in to comment.