# Recovered from patch d089dfb4d2ca2c439a733c3b28c69821b15555c7
# ("Add missing primitives"), which adds the two classes below to
# sam/sim/src/accumulator.py. `Primitive` and `is_stkn` are not defined here;
# they are expected to already be in scope in that module.


class SparseCrdPtAccumulator1(Primitive):
    """Accumulate a stream of (outer crd, inner crd, value) points per outer coordinate.

    One point is consumed per ``update()`` once all three input FIFOs hold a
    token. Values that share the same (outer, inner) pair are summed in a
    dict-of-dicts scratchpad. Whenever the incoming outer coordinate differs
    from the previous one, the previous outer coordinate is queued for
    emission; queued fibers are written out one point per ``update()`` in
    increasing inner-coordinate order. A ``'D'`` token (which must appear on
    all three inputs at once) is forwarded only after every queued fiber has
    been written out.

    Args:
        maxdim: Maximum possible dimension for this index level (stored only).
        valtype: Callable used to convert every incoming value before it is
            stored or added.
        fifos: Optional sequence of exactly three FIFOs
            ``(outer_crdpt, inner_crdpt, in_val)`` to use as the input queues.
            Any other length is ignored and fresh lists are used.
        **kwargs: Forwarded to ``Primitive.__init__``.
    """

    def __init__(self, maxdim=100, valtype=float, fifos=None, **kwargs):
        super().__init__(**kwargs)

        # Input FIFOs
        self.outer_crdpt = []
        self.inner_crdpt = []
        self.in_val = []

        # Most recently consumed input tokens
        self.curr_in_val = None
        self.curr_in_inner_crdpt = None
        self.curr_in_outer_crdpt = None

        # Queue of [outer crd, last emitted inner crd] fibers awaiting output
        self.emit_output = []
        # Current output tokens ('' means "nothing this cycle")
        self.curr_inner_crdpt = ''
        self.curr_outer_crdpt = ''
        self.curr_val = ''

        # Maximum possible dimension for this index level
        self.maxdim = maxdim
        self.order = 1

        self.seen_done = False
        # Accumulation scratchpad storage: {outer crd: {inner crd: value}}
        self.storage = dict()
        self.valtype = valtype

        if fifos is not None and len(fifos) == 3:
            self.outer_crdpt = fifos[0]
            self.inner_crdpt = fifos[1]
            self.in_val = fifos[2]

        if self.get_stats:
            self.hits_tracker = {}
            self.stop_token_out = 0
            self.drop_token_out = 0
            self.valid_token_out = 0
            self.zero_out = 0
            self.nonzero_out = 0
            # High-water marks of the three input FIFOs
            self.out_crd_fifo = 0
            self.in_crd_fifo = 0
            self.in_val_fifo = 0

    def return_fifo(self):
        """Return the three input FIFOs as ``(outer_crdpt, inner_crdpt, in_val)``."""
        return self.outer_crdpt, self.inner_crdpt, self.in_val

    def _debug_trace_inputs(self):
        """Print extra state when done is pending or a 'D' token heads an input FIFO."""
        snapshot = (self.outer_crdpt, self.inner_crdpt, self.in_val, self.emit_output,
                    self.curr_in_outer_crdpt, self.curr_in_inner_crdpt, self.curr_val)
        if self.seen_done or self.done:
            print(self.seen_done, self.done)
            print("@@@", *snapshot)
            self.print_debug()
        tagged_fifos = (("val", self.in_val),
                        ("innercrd", self.inner_crdpt),
                        ("outercrd", self.outer_crdpt))
        for tag, fifo in tagged_fifos:
            if len(fifo) > 0 and fifo[0] == "D":
                print(tag, *snapshot)
                self.print_debug()

    def _accumulate(self, outer, inner, val):
        """Add ``val`` (converted with ``valtype``) into ``storage[outer][inner]``."""
        row = self.storage.get(outer)
        if row is None:
            self.storage[outer] = {inner: self.valtype(val)}
            return

        if self.get_stats:
            # NOTE(review): every inner key of this row is reset to 1 before
            # the hit is counted, so a tracked count never exceeds 2 -- confirm
            # that this is the intended statistic.
            for crd in row:
                self.hits_tracker[crd] = 1
        if inner in row:
            if self.get_stats:
                self.hits_tracker[inner] += 1
            row[inner] += self.valtype(val)
        else:
            if self.get_stats:
                self.hits_tracker[inner] = 1
            row[inner] = self.valtype(val)

    def _record_output_stats(self):
        """Classify the current output value into the drop/stop/valid counters."""
        if self.curr_val == "":
            self.drop_token_out += 1
        elif is_stkn(self.curr_val):
            self.stop_token_out += 1
        else:
            if isinstance(self.curr_val, (float, int)) and self.curr_val == 0:
                self.zero_out += 1
            else:
                self.nonzero_out += 1
            self.valid_token_out += 1

    def update(self):
        """Advance one cycle: consume at most one input point and produce one output."""
        self.update_done()
        if self.debug:
            self._debug_trace_inputs()

        if len(self.outer_crdpt) > 0 or len(self.inner_crdpt) > 0:
            self.block_start = False

        if self.get_stats:
            self.out_crd_fifo = max(self.out_crd_fifo, len(self.outer_crdpt))
            self.in_crd_fifo = max(self.in_crd_fifo, len(self.inner_crdpt))
            self.in_val_fifo = max(self.in_val_fifo, len(self.in_val))

        if self.done:
            self.curr_outer_crdpt = ''
            self.curr_inner_crdpt = ''
            self.curr_val = ''
            if self.get_stats:
                self.drop_token_out += 1
            return

        inputs_ready = len(self.in_val) > 0 and len(self.outer_crdpt) > 0 and len(self.inner_crdpt) > 0
        if inputs_ready and not self.seen_done:
            self.curr_in_val = self.in_val.pop(0)
            self.curr_in_inner_crdpt = self.inner_crdpt.pop(0)
            ocrd = self.outer_crdpt.pop(0)

            # A change of outer coordinate closes the previous fiber: queue it
            # for output, starting below any stored inner coordinate (-1).
            prev_outer = self.curr_in_outer_crdpt
            if prev_outer is not None and prev_outer != "D" and ocrd != prev_outer:
                self.emit_output.append([prev_outer, -1])
            self.curr_in_outer_crdpt = ocrd

            if ocrd == 'D':
                # If a done token is seen, cannot emit done until all
                # coordinates have been written out
                assert self.curr_in_inner_crdpt == 'D' and self.curr_in_val == 'D', \
                    "If one item is a 'D' token, then all inputs must be"
                self.seen_done = True
            else:
                self._accumulate(ocrd, self.curr_in_inner_crdpt, self.curr_in_val)

        if len(self.emit_output) > 0:
            fiber = self.emit_output[0]
            self.curr_outer_crdpt = fiber[0]
            row = self.storage[self.curr_outer_crdpt]
            # Next inner coordinate strictly above the last one emitted
            self.curr_inner_crdpt = min(crd for crd in row if crd > fiber[1])
            self.curr_val = row[self.curr_inner_crdpt]

            if any(crd > self.curr_inner_crdpt for crd in row):
                fiber[1] = self.curr_inner_crdpt
            else:
                # Fiber fully written out
                self.emit_output.pop(0)
        elif self.seen_done:
            self.done = True
            self.seen_done = False
            self.curr_outer_crdpt = 'D'
            self.curr_inner_crdpt = 'D'
            self.curr_val = 'D'
        else:
            self.curr_outer_crdpt = ''
            self.curr_inner_crdpt = ''
            self.curr_val = ''

        if self.get_stats:
            self._record_output_stats()

        if self.debug:
            print("Done ptaccum:", self.out_done(), self.done,
                  "\n Curr in ocrd: ", self.curr_in_outer_crdpt, "\t Curr in icrd", self.curr_in_inner_crdpt,
                  "\t Curr in val", self.curr_in_val,
                  "\n Curr out ocrd: ", self.curr_outer_crdpt, "\t Curr out icrd: ", self.curr_inner_crdpt,
                  "\t Curr out val: ", self.curr_val,
                  "\n Emit crds: ", self.emit_output,
                  "\n Storage: ", self.storage,
                  "\n f: ", self.outer_crdpt, self.inner_crdpt, self.in_val)

    def print_debug(self):
        """Print the full internal state (inputs, outputs, queue, storage, FIFOs)."""
        print("Crdptaccum_debug Done:", self.out_done(), self.done,
              "\n Curr in ocrd: ", self.curr_in_outer_crdpt, "\t Curr in icrd", self.curr_in_inner_crdpt,
              "\t Curr in val", self.curr_in_val,
              "\n Curr out ocrd: ", self.curr_outer_crdpt, "\t Curr out icrd: ", self.curr_inner_crdpt,
              "\t Curr out val: ", self.curr_val,
              "\n Emit crds: ", self.emit_output,
              "\n Storage: ", self.storage,
              "\n Fifos: ", self.outer_crdpt, self.inner_crdpt, self.in_val)

    def set_inner_crdpt(self, crdpt):
        """Enqueue an inner coordinate; '' and None are ignored, stop tokens are rejected."""
        assert not is_stkn(crdpt), 'Coordinate points should not have stop tokens'
        if crdpt != '' and crdpt is not None:
            self.inner_crdpt.append(crdpt)

    def set_outer_crdpt(self, crdpt):
        """Enqueue an outer coordinate; '' and None are ignored, stop tokens are rejected."""
        assert not is_stkn(crdpt), 'Coordinate points should not have stop tokens'
        if crdpt != '' and crdpt is not None:
            self.outer_crdpt.append(crdpt)

    def set_val(self, val):
        """Enqueue a value; '' and None are ignored, stop tokens are rejected."""
        assert not is_stkn(val), 'Values associated with points should not have stop tokens'
        if val != '' and val is not None:
            self.in_val.append(val)

    def out_outer_crdpt(self):
        """Return the outer coordinate produced by the last ``update()``."""
        return self.curr_outer_crdpt

    def out_inner_crdpt(self):
        """Return the inner coordinate produced by the last ``update()``."""
        return self.curr_inner_crdpt

    def out_val(self):
        """Return the value produced by the last ``update()``."""
        return self.curr_val

    def return_statistics(self):
        """Return the output-token counters merged with the base-class statistics.

        Returns an empty dict when statistics collection is disabled.
        """
        if not self.get_stats:
            return {}
        stats_dict = {"stkn_outs": self.stop_token_out,
                      "drop_outs": self.drop_token_out, "valid_outs": self.valid_token_out,
                      "zero_outs": self.zero_out, "nonzero_outs": self.nonzero_out}
        stats_dict.update(super().return_statistics())
        return stats_dict

    def return_hits(self):
        """Summarize ``hits_tracker``.

        Returns:
            ``(max hits, number of keys with more than one hit, number of
            keys, sum of all hits)``; all zeros when statistics collection is
            disabled.
        """
        if not self.get_stats:
            return 0, 0, 0, 0
        hits = self.hits_tracker.values()
        return max(hits, default=0), sum(1 for h in hits if h > 1), len(hits), sum(hits)


class SparseCrdPtAccumulator2(Primitive):
    """Accumulate a whole stream of (outer crd, inner crd, value) points, then emit it sorted.

    One point is consumed per ``update()`` once all three input FIFOs hold a
    token. Values that share the same (outer, inner) pair are summed in a
    dict-of-dicts scratchpad. Nothing is written out until a ``'D'`` token
    arrives (on all three inputs at once); after that, one stored point is
    emitted per ``update()`` in increasing outer, then increasing inner,
    coordinate order, followed by ``'D'``.

    Args:
        maxdim: Maximum possible dimension for this index level (stored only).
        valtype: Callable used to convert every incoming value before it is
            stored or added.
        **kwargs: Forwarded to ``Primitive.__init__``.
    """

    def __init__(self, maxdim=100, valtype=float, **kwargs):
        super().__init__(**kwargs)
        # Input FIFOs: crdpt0 is the inner coordinate, crdpt1 the outer one
        self.in_crdpt0 = []
        self.in_crdpt1 = []
        self.in_val = []

        # Most recently consumed input tokens
        self.curr_in_val = None
        self.curr_in0_crdpt = None
        self.curr_in1_crdpt = None

        # Holds at most one [outer lower bound, inner lower bound] cursor
        self.emit_output = []
        # Current output tokens ('' means "nothing this cycle")
        self.curr_crdpt0 = ''
        self.curr_crdpt1 = ''
        self.curr_val = ''

        # Maximum possible dimension for this index level
        self.maxdim = maxdim
        self.order = 1

        self.seen_done = False
        # Accumulation scratchpad storage: {outer crd: {inner crd: value}}
        self.storage = dict()
        self.valtype = valtype

        if self.get_stats:
            self.hits_tracker = {}
            self.stop_token_out = 0
            self.drop_token_out = 0
            self.valid_token_out = 0
            self.zero_out = 0
            self.nonzero_out = 0

    def return_fifo(self):
        """Return the three input FIFOs as ``(in_crdpt0, in_crdpt1, in_val)``."""
        return self.in_crdpt0, self.in_crdpt1, self.in_val

    def _accumulate(self, outer, inner, val):
        """Add ``val`` (converted with ``valtype``) into ``storage[outer][inner]``."""
        row = self.storage.get(outer)
        if row is None:
            self.storage[outer] = {inner: self.valtype(val)}
            return

        if self.get_stats:
            # NOTE(review): every inner key of this row is reset to 1 before
            # the hit is counted, so a tracked count never exceeds 2 -- confirm
            # that this is the intended statistic.
            for crd in row:
                self.hits_tracker[crd] = 1
        if inner in row:
            if self.get_stats:
                self.hits_tracker[inner] += 1
            row[inner] += self.valtype(val)
        else:
            if self.get_stats:
                self.hits_tracker[inner] = 1
            row[inner] = self.valtype(val)

    def _record_output_stats(self):
        """Classify the current output value into the drop/stop/valid counters."""
        if self.curr_val == "":
            self.drop_token_out += 1
        elif is_stkn(self.curr_val):
            self.stop_token_out += 1
        else:
            if isinstance(self.curr_val, (float, int)) and self.curr_val == 0:
                self.zero_out += 1
            else:
                self.nonzero_out += 1
            self.valid_token_out += 1

    def update(self):
        """Advance one cycle: consume at most one input point and produce one output."""
        self.update_done()
        # Any pending input on any of the three FIFOs clears block_start.
        # (The original tested in_crdpt0 twice and never looked at in_crdpt1.)
        if len(self.in_crdpt0) > 0 or len(self.in_crdpt1) > 0 or len(self.in_val) > 0:
            self.block_start = False

        if self.done:
            self.curr_crdpt0 = ''
            self.curr_crdpt1 = ''
            self.curr_val = ''
            if self.get_stats:
                self.drop_token_out += 1
            return

        if len(self.in_val) > 0 and len(self.in_crdpt1) > 0 and len(self.in_crdpt0) > 0:
            self.curr_in_val = self.in_val.pop(0)
            self.curr_in0_crdpt = self.in_crdpt0.pop(0)
            self.curr_in1_crdpt = self.in_crdpt1.pop(0)

            if self.curr_in1_crdpt == 'D':
                assert self.curr_in0_crdpt == 'D' and self.curr_in_val == 'D', \
                    "If one item is a 'D' token, then all inputs must be"
                # Start the write-out below every stored coordinate. With an
                # empty scratchpad there is nothing to write out, and queuing
                # a cursor would make min() below fail on an empty sequence,
                # so 'D' is forwarded directly instead.
                if self.storage:
                    self.emit_output.append([-1, -1])
                self.seen_done = True
            else:
                self._accumulate(self.curr_in1_crdpt, self.curr_in0_crdpt, self.curr_in_val)

        if len(self.emit_output) > 0:
            outer_bound, inner_bound = self.emit_output.pop(0)
            # Smallest stored point strictly above the cursor
            key1 = min(crd for crd in self.storage if crd > outer_bound)
            row = self.storage[key1]
            key0 = min(crd for crd in row if crd > inner_bound)

            self.curr_crdpt1 = key1
            self.curr_crdpt0 = key0
            self.curr_val = row[key0]

            if any(crd > key0 for crd in row):
                # More inner coordinates remain in this row
                self.emit_output.append([outer_bound, key0])
            elif any(crd > key1 for crd in self.storage):
                # Finished inner coordinates, increment outer coordinate
                # (nothing is queued after the last outer coordinate)
                self.emit_output.append([key1, -1])
        elif self.seen_done:
            self.done = True
            self.seen_done = False
            self.curr_crdpt0 = 'D'
            self.curr_crdpt1 = 'D'
            self.curr_val = 'D'
        else:
            self.curr_crdpt0 = ''
            self.curr_crdpt1 = ''
            self.curr_val = ''

        if self.get_stats:
            self._record_output_stats()

        if self.debug:
            print("Done:", self.out_done(),
                  "\n Curr in crd1: ", self.curr_in1_crdpt,
                  "\t Curr in crd0", self.curr_in0_crdpt,
                  "\t Curr in val", self.curr_in_val,
                  "\n Curr out crd1: ", self.curr_crdpt1,
                  "\t Curr out crd0: ", self.curr_crdpt0,
                  "\t Curr out val: ", self.curr_val,
                  "\n Emit crds: ", self.emit_output,
                  "\n Storage: ", self.storage)

    def set_inner_crdpt(self, crdpt):
        """Enqueue an inner coordinate; '' and None are ignored, stop tokens are rejected."""
        assert not is_stkn(crdpt), 'Coordinate points should not have stop tokens'
        if crdpt != '' and crdpt is not None:
            self.in_crdpt0.append(crdpt)

    def set_outer_crdpt(self, crdpt):
        """Enqueue an outer coordinate; '' and None are ignored, stop tokens are rejected."""
        assert not is_stkn(crdpt), 'Coordinate points should not have stop tokens'
        if crdpt != '' and crdpt is not None:
            self.in_crdpt1.append(crdpt)

    def set_val(self, val):
        """Enqueue a value; '' and None are ignored, stop tokens are rejected."""
        assert not is_stkn(val), 'Values associated with points should not have stop tokens'
        if val != '' and val is not None:
            self.in_val.append(val)

    def out_outer_crdpt(self):
        """Return the outer coordinate produced by the last ``update()``."""
        return self.curr_crdpt1

    def out_inner_crdpt(self):
        """Return the inner coordinate produced by the last ``update()``."""
        return self.curr_crdpt0

    def out_val(self):
        """Return the value produced by the last ``update()``."""
        return self.curr_val

    def return_hits(self):
        """Summarize ``hits_tracker``.

        Returns:
            ``(max hits, number of keys with more than one hit, number of
            keys)``; all zeros when statistics collection is disabled (the
            tracker does not exist in that case).
        """
        if not self.get_stats:
            return 0, 0, 0
        hits = self.hits_tracker.values()
        return max(hits, default=0), sum(1 for h in hits if h > 1), len(hits)

    def return_statistics(self):
        """Return the output-token counters merged with the base-class statistics.

        Returns an empty dict when statistics collection is disabled.
        """
        if not self.get_stats:
            return {}
        stats_dict = {"stkn_outs": self.stop_token_out,
                      "drop_outs": self.drop_token_out, "valid_outs": self.valid_token_out,
                      "zero_outs": self.zero_out, "nonzero_outs": self.nonzero_out}
        stats_dict.update(super().return_statistics())
        return stats_dict