From 06e0e6f58e4278b65b90414a88be707e2eb8f0a5 Mon Sep 17 00:00:00 2001 From: jack-melchert <56329001+jack-melchert@users.noreply.github.com> Date: Wed, 31 Jan 2024 16:36:57 -0800 Subject: [PATCH] Compute kernel naming improvements (#86) * Added input specific kernel delay * added indent for compute kernel cycle * Fixed rewrite rule ordering bug * New mem naming scheme implemented * Fixed ordering of ports * Fixed sparse rom mapping * metamapper fixes for rv * Added better naming * Added better PE names * Updated rewrite rule naming * Special case to avoid compiling cgralib * Fixing bug in rewrite rule naming * starting working on tests * Updating pond mapping * Changing PIPELINED default behaviour * small bug fix * Changing to bottom up branch delay matching, not register insertion yet * working * Added black box checking * Added return for print_dag * Remove extra prints * Use pono for kernel verification * Bboxing with pipelined kernels, fully connected bbox formula * Bipartite style black box formula * Multiple verification runs for varying black box constraints * Add code comments * Remove unnecessary imports * Rewrite incremental black box formulas * Cleaning up and fixing middle mult bug * not working yet * Unintepreted function optimization working * Trying to fix cam pipe 2x2 * cleaning up * Turn off compute mapping verification for now * Turned back on formal proof * Added abs to custom IR to fix mapping issue * Added more complex compute kernel latencies * Added extra stage of design_top mapping * Fixing instruction selection for coreir.mux * fixed casting from bv to float * adding fp_max * Added ability to do second compute mapping and fp_max * Fix issue with regs * Added fix for naming second stage of mapping muxes * Small change to compute kernel latency naming * Small final fix for merge * Delete .travis.yml --------- Co-authored-by: root Co-authored-by: Caleb Terrill Co-authored-by: yuchen-mei --- .travis.yml | 22 - metamapper/common_passes.py | 453 +++++++--- metamapper/coreir_mapper.py | 32 +- metamapper/coreir_util.py | 3 +- metamapper/delay_matching.py | 114 ++- .../instruction_selection/dag_rewrite.py | 17 +- metamapper/irs/coreir/__init__.py | 43 +- metamapper/irs/coreir/custom_ops_ir.py | 771 ++++++------------ metamapper/irs/coreir/ir.py | 24 + metamapper/map_design_top.py | 127 +++ metamapper/rewrite_table.py | 17 +- scripts/map_app.py | 88 +- scripts/map_dse.py | 59 +- tests/test_kernel_mapping.py | 14 +- tests/test_mem_header.py | 14 - 15 files changed, 1085 insertions(+), 713 deletions(-) delete mode 100644 .travis.yml create mode 100644 metamapper/map_design_top.py delete mode 100644 tests/test_mem_header.py diff --git a/.travis.yml b/.travis.yml deleted file mode 100644 index 193bce7..0000000 --- a/.travis.yml +++ /dev/null @@ -1,22 +0,0 @@ -dist: xenial -language: python -python: -- 3.8 -cache: - apt: true -addons: - apt: - packages: - - libgmp-dev - - libmpfr-dev - - libmpc-dev - - verilator -install: -- pip install coreir -- pip install -r requirements.txt --src ./src -- pip install pytest-cov -- pip install python-coveralls -- pip install -e . -- cd ./src/peak/peak && ln -s ../examples . && cd - -script: -- pytest -s -x tests/ diff --git a/metamapper/common_passes.py b/metamapper/common_passes.py index 8e37149..cef2c5e 100644 --- a/metamapper/common_passes.py +++ b/metamapper/common_passes.py @@ -8,43 +8,58 @@ from hwtypes.modifiers import strip_modifiers from peak.mapper.utils import Unbound from peak import family +from peak.black_box import BlackBox +from peak.family import _RegFamily, SMTFamily +from peak.register import gen_register from .node import DagNode import hwtypes as ht from graphviz import Digraph - - +from collections import defaultdict +import pono +import smt_switch.pysmt_frontend as fe +import smt_switch.primops as switch_ops +from peak.mapper.utils import rebind_type +import smt_switch as ss def is_unbound_const(node): return isinstance(node, Constant) and node.value is Unbound +def n2s(node): + return f"{str(node)}_{node._id_}" + class DagToPdf(Visitor): def __init__(self, no_unbound): self.no_unbound = no_unbound + def doit(self, dag: Dag): AddID().run(dag) self.graph = Digraph() self.run(dag) + if hasattr(dag, "non_output_sinks"): + for sink in dag.non_output_sinks: + self.graph.edge(n2s(sink), n2s(sink.source)) return self.graph def generic_visit(self, node): Visitor.generic_visit(self, node) - def n2s(node): - return f"{str(node)}_{node._id_}" + if self.no_unbound and not is_unbound_const(node): self.graph.node(n2s(node)) for i, child in enumerate(node.children()): if self.no_unbound and not is_unbound_const(child): self.graph.edge(n2s(child), n2s(node), label=str(i)) + def gen_dag_img(dag, file, no_unbound=True): DagToPdf(no_unbound).doit(dag).render(filename=file) + class DagToPdfSimp(Visitor): def doit(self, dag: Dag): AddID().run(dag) - self.plotted_nodes = {"global.PE", "Input", "Output","PipelineRegister"} + self.plotted_nodes = {"global.PE", "Input", "Output", "PipelineRegister"} self.child_list = [] self.graph = Digraph() self.run(dag) @@ -52,6 +67,7 @@ def doit(self, dag: Dag): def generic_visit(self, node): Visitor.generic_visit(self, node) + def n2s(node): op = node.iname.split("_")[0] return f"{str(node)}_{node._id_}\n{op}" @@ -60,12 +76,12 @@ def find_child(node): if len(node.children()) == 0: return for child in node.children(): - if str(child) in self.plotted_nodes: + if str(child) in self.plotted_nodes: self.child_list.append(child) else: child_f = find_child(child) - if str(node) in self.plotted_nodes: + if str(node) in self.plotted_nodes: find_child(node) for child in self.child_list: self.graph.edge(n2s(child), n2s(node)) @@ -75,7 +91,8 @@ def find_child(node): def gen_dag_img_simp(dag, file): DagToPdfSimp().doit(dag).render(filename=file) -#Translates DagNode + +# Translates DagNode class Constant2CoreIRConstant(Transformer): def __init__(self, nodes: Nodes): self.nodes = nodes @@ -99,19 +116,36 @@ def __init__(self, nodes, rv, Inst2): def visit_Riscv2(self, node): Transformer.generic_visit(self, node) assert node.num_children == 3 - inst2, rs1, rs2, = node.children() + ( + inst2, + rs1, + rs2, + ) = node.children() assert isinstance(inst2, Constant) riscv_node = self.nodes.dag_nodes["R32I_mappable"] - BV= fam().PyFamily().BitVector + BV = fam().PyFamily().BitVector Inst = self.rv.isa.ISA_fc.Py.Inst i0 = Constant(type=Inst, value=inst2.value[:30]) i1 = Constant(type=Inst, value=inst2.value[30:]) - n0 = riscv_node(i0, Constant(type=BV[32],value=Unbound), rs1, rs2, Constant(type=BV[32],value=Unbound)) - n1 = riscv_node(i1, Constant(type=BV[32],value=Unbound), n0.select("rd"), n0.select("rd"), Constant(type=BV[32],value=Unbound)) + n0 = riscv_node( + i0, + Constant(type=BV[32], value=Unbound), + rs1, + rs2, + Constant(type=BV[32], value=Unbound), + ) + n1 = riscv_node( + i1, + Constant(type=BV[32], value=Unbound), + n0.select("rd"), + n0.select("rd"), + Constant(type=BV[32], value=Unbound), + ) return n1 + class TypeLegalize(Transformer): - def __init__(self, WasmNodes:Nodes): + def __init__(self, WasmNodes: Nodes): self.WasmNodes = WasmNodes self.BV = fam().PyFamily().BitVector @@ -128,7 +162,7 @@ def const1(self, value): def constn1(self, value): if value == self.BV[32](-1): constn1 = self.WasmNodes.dag_nodes["constn1"] - return constn1(Constant(value=Unbound,type=self.BV[32])).select("out") + return constn1(Constant(value=Unbound, type=self.BV[32])).select("out") def const12(self, value): if value[:12].sext(20) == value: @@ -141,7 +175,6 @@ def const20(self, value): const20 = self.WasmNodes.dag_nodes["const20"] c = Constant(value=value[:20], type=self.BV[20]) return const20(c).select("out") - def constOther(self, value): lsb = self.const20(value[:16].zext(16)) @@ -162,13 +195,14 @@ def visit_Constant(self, node): self.constn1, self.const12, self.const20, - self.constOther + self.constOther, ): new = f(value) if new is not None: return new raise NotImplementedError() + class Unbound2Const(Visitor): def visit_Constant(self, node): if node.value is Unbound: @@ -188,7 +222,8 @@ def generic_visit(self, node): Visitor.generic_visit(self, node) if node.nodes == self.nodes: self.ops.setdefault(node.node_name, 0) - self.ops[node.node_name] +=1 + self.ops[node.node_name] += 1 + class DagNumNodes(Visitor): def __init__(self): @@ -216,93 +251,276 @@ def verify(self, dag: Dag): def generic_visit(self, node): if hasattr(node, "node_name"): - if node.node_name != "coreir.reg" and node.node_name != "memory.rom2" and node.node_name != "memory.fprom2": + if ( + node.node_name != "coreir.reg" + and node.node_name != "memory.rom2" + and node.node_name != "memory.fprom2" + ): nodes = type(node).nodes - if nodes != self.nodes and nodes != Common: + if nodes != Common and node.node_name not in self.nodes._node_names: self.wrong_nodes.add(node) Visitor.generic_visit(self, node) -from peak.mapper.utils import rebind_type, solved_to_bv -import pysmt.shortcuts as smt -from pysmt.logics import QF_BV -def prove_formula(formula, solver, i1): - with smt.Solver(solver, logic=QF_BV) as solver: - solver.add_assertion(formula.value) - verified = not solver.solve() - if verified: - return None - else: - return solved_to_bv(i1._value_, solver) +def pysmt_to_pono(i, o, regs, solver, convert, cycles, bboxes): + i = convert(i._value_.value) + o = convert(o._value_.value) + + fts = pono.FunctionalTransitionSystem(solver) + mapping = ( + {} + ) # mapping from converted pysmt inputs/registers to pono inputvars/statevars + + mapping[i] = fts.make_inputvar(f"IVAR_{repr(i)}", i.get_sort()) + i = mapping[i] + + # make pono statevars for all registers + for reg, _ in regs: + reg = convert(reg.value) + statevar = fts.make_statevar(f"SVAR_{repr(reg)}", reg.get_sort()) + mapping[reg] = statevar + + # make pono inputvars for all black box outputs + for op_bboxes in list(bboxes.values()): + for bbox in op_bboxes: + outs = bbox[1] + if not isinstance(outs, tuple): + outs = (outs,) + + for out in outs: + out = convert(out.value) + inputvar = fts.make_inputvar(f"IVAR_{repr(out)}", out.get_sort()) + mapping[out] = inputvar + + # convert black box inputs/outputs to corresponding pono/smt-switch terms + for op_bboxes in list(bboxes.values()): + for idx in range(len(op_bboxes)): + ins, outs = op_bboxes[idx] + if not isinstance(ins, tuple): + ins = (ins,) + if not isinstance(outs, tuple): + outs = (outs,) + + ins = tuple([solver.substitute(convert(x.value), mapping) for x in ins]) + outs = tuple([solver.substitute(convert(x.value), mapping) for x in outs]) + + op_bboxes[idx] = (ins, outs) + + # set pono register next values + for reg, reg_next in regs: + reg = convert(reg.value) + reg_next = convert(reg_next.value) + reg_next = solver.substitute(reg_next, mapping) + fts.assign_next(mapping[reg], reg_next) + + o = solver.substitute(o, mapping) + + ur = pono.Unroller(fts) + i = ur.at_time(i, 0) + o = ur.at_time(o, cycles) + + solver.assert_formula(ur.at_time(fts.init, 0)) + + # assert state transitions for each cycle of delay + for cycle in range(cycles): + solver.assert_formula(ur.at_time(fts.trans, cycle)) + + # create new black box dict with entries for each black box at each cycle + bboxes_ur = defaultdict(list) + for cycle in range(cycles + 1): + for op, op_bboxes in list(bboxes.items()): + for ins, outs in op_bboxes: + ins = tuple([ur.at_time(x, cycle) for x in ins]) + outs = tuple([ur.at_time(x, cycle) for x in outs]) + bboxes_ur[op].append((ins, outs)) + + return i, o, bboxes_ur + + +def check_sat(solver, bbox_types_to_ins_outs, i0): + print("\t\tFormally verifying premapped and mapped dags") + res = solver.check_sat() + if res.is_unsat(): + return None -#Returns None if equal, counter example for one input otherwise -def prove_equal(dag0: Dag, dag1: Dag, solver_name="z3"): + return solver.get_value(i0) + + +def prove_equal(dag0: Dag, dag1: Dag, cycles, solver_name="bitwuzla"): if dag0.input.type != dag1.input.type: raise ValueError("Input types are not the same") if dag0.output.type != dag1.output.type: raise ValueError("Output types are not the same") - i0, o0 = SMT().get(dag0) - i1, o1 = SMT().get(dag1) - formula = o0._value_.substitute((i0._value_, i1._value_)) != o1._value_ - return prove_formula(formula, solver_name, i1) + i0, o0, regs0, bboxes0 = SMT().get(dag0) + i1, o1, regs1, bboxes1 = SMT().get(dag1) + + if regs0: + raise ValueError(f"Unmapped dag should not have registers: {regs0}") + + s = fe.Solver(solver_name) + solver = s.solver + convert = s.converter.convert + + i0, o0, bboxes0 = pysmt_to_pono(i0, o0, [], solver, convert, 0, bboxes0) + i1, o1, bboxes1 = pysmt_to_pono(i1, o1, regs1, solver, convert, cycles, bboxes1) + + bbox_types_to_ins_outs = bboxes0 + for k, v in bboxes1.items(): + if k in bbox_types_to_ins_outs: + bbox_types_to_ins_outs[k] += v + else: + bbox_types_to_ins_outs[k] = v + + for idx, (k, v) in enumerate(bbox_types_to_ins_outs.items()): + bvs = v[0][0][0].get_sort() + func = solver.make_sort(ss.sortkinds.FUNCTION, [bvs, bvs, bvs]) + f = solver.make_symbol(f"bb{idx}", func) + for (ins, outs) in v: + func_form = solver.make_term(switch_ops.Apply, f, ins[0], ins[1]) + solver.assert_formula( + solver.make_term(switch_ops.Equal, outs[0], func_form) + ) + + solver.assert_formula(solver.make_term(switch_ops.Equal, i0, i1)) + solver.assert_formula( + solver.make_term(switch_ops.Not, solver.make_term(switch_ops.Equal, o0, o1)) + ) + + return check_sat(solver, bbox_types_to_ins_outs, i0) + def _get_aadt(T): T = rebind_type(T, fam().SMTFamily()) return fam().SMTFamily().get_adt_t(T) + +# TODO: this would recurse forever if two objects reference eachother +def _recursive_filter_fc(obj, cond, fc): + if cond(obj): + fc(obj) + elif hasattr(obj, "__dict__"): + for _, sub_obj in obj.__dict__.items(): + _recursive_filter_fc(sub_obj, cond, fc) + + class SMT(Visitor): def __init__(self): pass def get(self, dag: Dag): self.values = {} - if len(dag.sources) !=1: + self.regs = [] + self.regs_next = [] + self.bboxes = defaultdict(list) + + if len(dag.sources) != 1: raise NotImplementedError + self.run(dag) - return self.values[dag.input], self.values[dag.output] - def visit_Input(self, node : Input): + if dag.input not in self.values: + aadt = _get_aadt(dag.input.type) + val = fam().SMTFamily().BitVector[1]() + self.values[dag.input] = aadt(val) + return ( + self.values[dag.input], + self.values[dag.output], + list(zip(self.regs, self.regs_next)), + self.bboxes, + ) + + def visit_Input(self, node: Input): aadt = _get_aadt(node.type) val = fam().SMTFamily().BitVector[aadt._assembler_.width]() self.values[node] = aadt(val) def visit_Constant(self, node: Constant): val = node.assemble(fam().SMTFamily()) - #aadt = _get_aadt(node.type) - #if node.value is Unbound: - # value = 0 - #else: - # value = node.value - #from hwtypes import AbstractBitVector, AbstractBit - #if issubclass(aadt, (AbstractBit, AbstractBitVector)): - # val = aadt(value) - #else: - # val = aadt(fam().SMTFamily().BitVector[aadt._assembler_.width](value)) self.values[node] = val def visit_Select(self, node: Select): Visitor.generic_visit(self, node) - val =self.values[node.children()[0]] + val = self.values[node.children()[0]] self.values[node] = val[node.field] def visit_Combine(self, node: Combine): Visitor.generic_visit(self, node) - vals = {field: self.values[child] for field, child in zip(node.type.field_dict.keys(), node.children())} + vals = { + field: self.values[child] + for field, child in zip(node.type.field_dict.keys(), node.children()) + } aadt = _get_aadt(node.type) self.values[node] = aadt.from_fields(**vals) def visit_Output(self, node: Output): Visitor.generic_visit(self, node) - vals = {field: self.values[child] for field, child in zip(node.type.field_dict.keys(), node.children())} + vals = { + field: self.values[child] + for field, child in zip(node.type.field_dict.keys(), node.children()) + } aadt = _get_aadt(node.type) self.values[node] = aadt.from_fields(**vals) def generic_visit(self, node: DagNode): Visitor.generic_visit(self, node) - peak_fc = node.nodes.peak_nodes[node.node_name] - vals = {field: self.values[child] for field, child in zip(peak_fc.Py.input_t.field_dict.keys(), node.children())} - outputs = peak_fc.SMT()(**vals) + if node.node_name == "PipelineRegister": + # TODO this is a temporary fix for now + peak_fc = gen_register(node.type) + vals = { + field: self.values[child] + for field, child in zip( + peak_fc.Py.input_t.field_dict.keys(), node.children() + ) + } + vals["en"] = fam().SMTFamily().Bit(1) + else: + peak_fc = node.nodes.peak_nodes[node.node_name] + vals = { + field: self.values[child] + for field, child in zip( + peak_fc.Py.input_t.field_dict.keys(), node.children() + ) + } + peak_fc_smt = peak_fc.SMT() + + def is_reg(x): + return isinstance(x, _RegFamily.RegBase) or isinstance( + x, _RegFamily.AttrRegBase + ) + + def make_freevar(x): + x.value = x.value.__class__() + + def is_bbox(x): + return isinstance(x, BlackBox) + + def set_bbox_outputs(x): + output_t = type(x).output_t + # TODO should make this generalize for types other than bitvector + outputs = tuple([SMTFamily().BitVector[t().num_bits]() for t in output_t]) + if len(outputs) == 1: + outputs = outputs[0] + x._set_outputs(outputs) + + _recursive_filter_fc(peak_fc_smt, is_reg, make_freevar) + _recursive_filter_fc(peak_fc_smt, is_reg, lambda x: self.regs.append(x.value)) + _recursive_filter_fc(peak_fc_smt, is_bbox, set_bbox_outputs) + + outputs = peak_fc_smt(**vals) + + def record_bbox_io(x): + self.bboxes[type(x)].append((x._get_inputs(), x._output_vals)) + + _recursive_filter_fc( + peak_fc_smt, is_reg, lambda x: self.regs_next.append(x.value) + ) + _recursive_filter_fc(peak_fc_smt, is_bbox, record_bbox_io) + + if node.node_name == "PipelineRegister": + self.values[node] = outputs + return + if not isinstance(outputs, tuple): outputs = (outputs,) @@ -320,6 +538,7 @@ def generic_visit(self, node): node._id_ = self.curid self.curid += 1 + class CountPEs(Visitor): def __init__(self): self.res = 0 @@ -329,7 +548,6 @@ def generic_visit(self, node): if hasattr(node, "node_name"): if node.node_name == "global.PE": self.res += 1 - def visit_PE(self, node): Visitor.generic_visit(self, node) @@ -339,6 +557,7 @@ def visit_PE_wrapped(self, node): Visitor.generic_visit(self, node) self.res += 1 + class Printer(Visitor): def __init__(self): self.res = "\n" @@ -372,13 +591,15 @@ def visit_Select(self, node): def visit_Input(self, node): Visitor.generic_visit(self, node) - self.res += f"{node._id_}{hex(id(node))}\n" + self.res += f"{node._id_}\n" def visit_InstanceInput(self, node): self.res += f"{node._id_}\n" def visit_Constant(self, node): - self.res += f"{node._id_}({node.value}{type(node.value)}, {node.type})>\n" + self.res += ( + f"{node._id_}({node.value}{type(node.value)}, {node.type})>\n" + ) def visit_Output(self, node): Visitor.generic_visit(self, node) @@ -393,26 +614,29 @@ def visit_InstanceOutput(self, node): def visit_Combine(self, node: Bind): Visitor.generic_visit(self, node) child_ids = ", ".join([str(child._id_) for child in node.children()]) - self.res += f"{node._id_}({child_ids})\n" + self.res += ( + f"{node._id_}({child_ids})\n" + ) + class BindsToCombines(Transformer): def gen_combine(self, node: Bind): if len(node.paths) == 1 and len(node.paths[0]) == 0: return node.children()[0] - #print("Trying to Bind {") - #print(f" type={list(node.type.field_dict.items())}") - #print(f" paths={node.paths}") - #assert len(node.type.field_dict) <= len(node.paths) - #sort paths based off of first field + # print("Trying to Bind {") + # print(f" type={list(node.type.field_dict.items())}") + # print(f" paths={node.paths}") + # assert len(node.type.field_dict) <= len(node.paths) + # sort paths based off of first field field_info = {} for path, child in zip(node.paths, node.children()): assert len(path) > 0 field = path[0] assert field in node.type.field_dict - field_info.setdefault(field, {"paths":[], "children":[]}) + field_info.setdefault(field, {"paths": [], "children": []}) field_info[field]["paths"].append(path[1:]) field_info[field]["children"].append(child) - #assert field_info.keys() == node.type.field_dict.keys() + # assert field_info.keys() == node.type.field_dict.keys() children = [] tu_field = None for field, T in node.type.field_dict.items(): @@ -422,23 +646,30 @@ def gen_combine(self, node: Bind): tu_field = field sub_paths = field_info[field]["paths"] sub_children = field_info[field]["children"] - sub_bind = Bind(*sub_children, paths=sub_paths, type=T, iname=node.iname + str(field)) + sub_bind = Bind( + *sub_children, paths=sub_paths, type=T, iname=node.iname + str(field) + ) new_child = self.gen_combine(sub_bind) children.append(new_child) - #print(f" children={children}") - #print("}") - return Combine(*children, type=node.type, iname= node.iname, tu_field=tu_field) + # print(f" children={children}") + # print("}") + return Combine(*children, type=node.type, iname=node.iname, tu_field=tu_field) + def visit_Bind(self, node: Bind): Transformer.generic_visit(self, node) return self.gen_combine(node) + from hwtypes.adt import Sum, TaggedUnion, Tuple, Product + # Consolidates constants into a simpler Bind node class SimplifyCombines(Transformer): def visit_Combine(self, node: Combine): Transformer.generic_visit(self, node) - aadt = AssembledADT[strip_modifiers(node.type), Assembler, fam().PyFamily().BitVector] + aadt = AssembledADT[ + strip_modifiers(node.type), Assembler, fam().PyFamily().BitVector + ] if issubclass(node.type, (Product, Tuple)): const_dict = OrderedDict() for child, field in zip(node.children(), node.type.field_dict.keys()): @@ -466,6 +697,7 @@ def visit_Combine(self, node: Combine): raise NotImplementedError() return Constant(value=val._value_, type=node.type) + class CloneInline(Visitor): def clone(self, dag: Dag, input_nodes, iname_prefix: str = ""): assert dag is not None @@ -477,7 +709,7 @@ def clone(self, dag: Dag, input_nodes, iname_prefix: str = ""): dag_copy = Dag( sources=[self.node_map[node] for node in dag.sources], - sinks=[self.node_map[node] for node in dag.sinks] + sinks=[self.node_map[node] for node in dag.sinks], ) return dag_copy, input_nodes_copy @@ -492,14 +724,17 @@ def generic_visit(self, node): new_node.iname = self.iname_prefix + new_node.iname self.node_map[node] = new_node -class CustomInline(Transformer): + +class CustomInline(Transformer): def __init__(self, rewrite_rules): self.rrs = rewrite_rules def visit_Select(self, node: Select): Transformer.generic_visit(self, node) if node.child.node_name in self.rrs: - replace_dag, input_nodes = CloneInline().clone(*self.rrs[node.child.node_name], iname_prefix=node.iname) + replace_dag, input_nodes = CloneInline().clone( + *self.rrs[node.child.node_name], iname_prefix=node.iname + ) for in_node in input_nodes: new_children = list(in_node.children()) for child_idx, child_node in enumerate(in_node.children()): @@ -510,11 +745,10 @@ def visit_Select(self, node: Select): in_node.set_children(*new_children) return replace_dag.output.child - return node - + return node -#Finds Opportunities to skip selecting from a Combine node +# Finds Opportunities to skip selecting from a Combine node class RemoveSelects(Transformer): def visit_Select(self, node: Select): Transformer.generic_visit(self, node) @@ -531,15 +765,20 @@ def visit_Select(self, node: Select): def print_dag(dag: Dag): AddID().run(dag) - print(Printer().run(dag).res) + res = Printer().run(dag).res + print(res) + return res + def count_pes(dag: Dag): return CountPEs().run(dag).res + def dag_to_pdf(dag: Dag, filename): AddID().run(dag) DagToPdf().run(dag).graph.render(filename, view=False) + class CheckIfTree(Visitor): def __init__(self): self.parent_cnt = {} @@ -552,7 +791,7 @@ def is_tree(self, dag: Dag): self.run(dag) return all(cnt < 2 for cnt in self.parent_cnt.values()) - #If it is an input or a select of an input + # If it is an input or a select of an input def is_input(self, node: DagNode): if isinstance(node, Input): return True @@ -560,6 +799,7 @@ def is_input(self, node: DagNode): return self.is_input(node.children()[0]) else: return False + def generic_visit(self, node): for child in node.children(): if self.is_input(child): @@ -578,7 +818,7 @@ def clone(self, dag: Dag, iname_prefix: str = ""): dag_copy = Dag( sources=[self.node_map[node] for node in dag.sources], - sinks=[self.node_map[node] for node in dag.sinks] + sinks=[self.node_map[node] for node in dag.sinks], ) return dag_copy @@ -593,6 +833,7 @@ def generic_visit(self, node): new_node.iname = self.iname_prefix + new_node.iname self.node_map[node] = new_node + class Uses(Visitor): def uses(self, dag: Dag): self.uses = {} @@ -608,9 +849,9 @@ def generic_visit(self, node: DagNode): inst, _, rs1, rs2, _ = node.children() assert isinstance(inst, Constant) self.uses.setdefault(node, {}) - for rs, idx in ((rs1,'rs1'), (rs2,'rs2')): + for rs, idx in ((rs1, "rs1"), (rs2, "rs2")): if isinstance(rs, Constant): - #if rs.value is not Unbound: + # if rs.value is not Unbound: # raise ValueError(f"expected Unbound, not {rs.value}") continue self.uses[node][idx] = self.uses[rs] @@ -643,7 +884,8 @@ def visit_Combine(self, node: DagNode): def visit_Input(self, node): pass -#This will naively linearize the code + +# This will naively linearize the code class Schedule(Visitor): def schedule(self, dag: Dag): self.insts = [] @@ -654,6 +896,22 @@ def visit_R32I_mappable(self, node): Visitor.generic_visit(self, node) self.insts.append(node) +class GetSinks(Visitor): + def __init__(self): + self.sinks = {} + + def doit(self, dag: Dag): + self.run(dag) + for sink in dag.sinks: + self.sinks[sink] = [] + return self.sinks + + def generic_visit(self, node: DagNode): + for child in node.children(): + if child not in self.sinks: + self.sinks[child] = [] + self.sinks[child].append(node) + Visitor.generic_visit(self, node) class ConstantPacking(Transformer): def __init__(self, pe_reg_info): @@ -663,43 +921,52 @@ def pack_constant(self, node, value, port): if not hasattr(node, "assemble"): return False instr = node.assemble(family.PyFamily()) - aadt = AssembledADT[strip_modifiers(node.type), Assembler, family.PyFamily().BitVector] + aadt = AssembledADT[ + strip_modifiers(node.type), Assembler, family.PyFamily().BitVector + ] reg = self.pe_reg_info["port_to_reg"][port] reg_instr = getattr(instr, reg) const_instr = getattr(instr, port) - if reg_instr._value_.value == self.pe_reg_info['instrs']['bypass'] or \ - reg_instr._value_.value == self.pe_reg_info['instrs']['reg']: + if ( + reg_instr._value_.value == self.pe_reg_info["instrs"]["bypass"] + or reg_instr._value_.value == self.pe_reg_info["instrs"]["reg"] + ): # Can constant pack # Change register mode to const instr_size = reg_instr._to_bitvector_().size - new_reg_instr = reg_instr.from_fields(ht.BitVector[instr_size](self.pe_reg_info['instrs']['const'])) + new_reg_instr = reg_instr.from_fields( + ht.BitVector[instr_size](self.pe_reg_info["instrs"]["const"]) + ) setattr(instr, reg, new_reg_instr) # Set value of const setattr(instr, port, value) - + const_dict = OrderedDict() for field in node.type.field_dict.keys(): const_dict[field] = getattr(instr, field) node.value = aadt(**const_dict)._value_ - + return True return False - def generic_visit(self, node): Transformer.generic_visit(self, node) - if node.node_name == "global.PE": + if node.node_name == "global.PE" and hasattr(node, "_metadata_"): ports = node._metadata_ new_children = [child for child in node.children()] for port_idx, child in enumerate(node.children()): if child.node_name == "Select": for child_ in child.children(): if child_.node_name == "coreir.const": - if self.pack_constant(new_children[0], child_.child.value, ports[port_idx][0]): - new_children[port_idx] = Constant(type=ht.BitVector[16],value=Unbound) + if self.pack_constant( + new_children[0], child_.child.value, ports[port_idx][0] + ): + new_children[port_idx] = Constant( + type=ht.BitVector[16], value=Unbound + ) node.set_children(*new_children) return node diff --git a/metamapper/coreir_mapper.py b/metamapper/coreir_mapper.py index ce258d9..88a71a8 100644 --- a/metamapper/coreir_mapper.py +++ b/metamapper/coreir_mapper.py @@ -1,16 +1,15 @@ from metamapper.common_passes import VerifyNodes, print_dag, count_pes, CustomInline, SimplifyCombines, RemoveSelects, prove_equal, \ - Clone, ExtractNames, Unbound2Const, gen_dag_img, ConstantPacking + Clone, ExtractNames, Unbound2Const, gen_dag_img, ConstantPacking, GetSinks import metamapper.coreir_util as cutil from metamapper.rewrite_table import RewriteTable from metamapper.node import Nodes, Dag -from metamapper.delay_matching import DelayMatching, KernelDelay +from metamapper.delay_matching import DelayMatching, branch_delay_match, KernelDelay from metamapper.instruction_selection import GreedyCovering from peak.mapper import RewriteRule as PeakRule, read_serialized_bindings import typing as tp import coreir import json - class DefaultLatency: @staticmethod @@ -20,19 +19,20 @@ def get(node): class Mapper: # Lazy # Discover at mapping time # ops (if lazy=False, search for these) - def __init__(self, CoreIRNodes: Nodes, ArchNodes: Nodes, alg=GreedyCovering, lazy=True, ops=None, rrules=None): + def __init__(self, CoreIRNodes: Nodes, ArchNodes: Nodes, alg=GreedyCovering, lazy=True, ops=None, rrules=None, kernel_name_prefix=False): self.CoreIRNodes = CoreIRNodes self.ArchNodes = ArchNodes self.table = RewriteTable(CoreIRNodes, ArchNodes) self.num_pes = 0 + self.num_regs = 0 self.kernel_cycles = {} self.const_rr = None self.bit_const_rr = None self.gen_rules(ops, rrules) self.compile_time_rule_gen = lambda dag : None - self.inst_sel = alg(self.table) + self.inst_sel = alg(self.table, kernel_name_prefix) def gen_rules(self, ops, rrules=None): @@ -51,6 +51,7 @@ def gen_rules(self, ops, rrules=None): for ind, peak_rule in enumerate(rrules): if ops != None: op = ops[ind] + print(f"Loading {op} ", end=" ", flush=True) if "fp" in op and "pipelined" in op: op = op.split("_pipelined")[0] self.table.add_peak_rule(peak_rule, op) @@ -60,7 +61,6 @@ def gen_rules(self, ops, rrules=None): def do_mapping(self, dag, kname="", convert_unbound=True, prove_mapping=True, node_cycles=None, pe_reg_info=None) -> coreir.Module: self.compile_time_rule_gen(dag) - use_constant_packing = pe_reg_info != None if use_constant_packing: @@ -95,22 +95,24 @@ def do_mapping(self, dag, kname="", convert_unbound=True, prove_mapping=True, no RemoveSelects().run(mapped_dag) self.num_pes += count_pes(mapped_dag) - print("Used", count_pes(mapped_dag), "PEs") + print("\tUsed", count_pes(mapped_dag), "PEs") unmapped = VerifyNodes(self.ArchNodes).verify(mapped_dag) if unmapped is not None: raise ValueError(f"Following nodes were unmapped: {unmapped}") - assert VerifyNodes(self.CoreIRNodes).verify(original_dag) is None if node_cycles is not None: - DelayMatching(node_cycles).run(mapped_dag) - self.kernel_cycles[kname] = KernelDelay(node_cycles).doit(mapped_dag) - - if prove_mapping: - counter_example = prove_equal(original_dag, mapped_dag) + sinks = GetSinks().doit(mapped_dag) + self.kernel_cycles[kname], added_regs = branch_delay_match(mapped_dag, node_cycles, sinks) + print("\tAdded", added_regs, "during branch delay matching") + self.num_regs += added_regs + + if prove_mapping and count_pes(mapped_dag) != 0: + verify_dag = Clone().clone(mapped_dag, iname_prefix=f"verification_") + DelayMatching(node_cycles).run(verify_dag) + counter_example = prove_equal(original_dag, verify_dag, KernelDelay(node_cycles).doit(verify_dag)) if counter_example is not None: - raise ValueError(f"Mapped is not the same {counter_example}") - #Create a new module representing the mapped_dag + raise ValueError(f"Mapped dag is not the same {counter_example}") if convert_unbound: Unbound2Const().run(mapped_dag) diff --git a/metamapper/coreir_util.py b/metamapper/coreir_util.py index 7d21906..d67aec7 100644 --- a/metamapper/coreir_util.py +++ b/metamapper/coreir_util.py @@ -181,7 +181,8 @@ def __init__(self, cmod: coreir.Module, nodes: Nodes, allow_unknown_instances=Fa node_name = self.nodes.name_from_coreir(inst.module) if node_name is None: - print(self.nodes.coreir_modules[f'coreir.{inst.module.name}'].print_()) + breakpoint() + print(self.nodes.coreir_modules[f'{inst.module.namespace.name}.{inst.module.name}'].print_()) print(inst.module.print_()) raise ValueError(f"Unknown module {inst.module.name}") diff --git a/metamapper/delay_matching.py b/metamapper/delay_matching.py index 49d0d30..ebb7347 100755 --- a/metamapper/delay_matching.py +++ b/metamapper/delay_matching.py @@ -1,6 +1,6 @@ from DagVisitor import Transformer, Visitor from metamapper.node import Constant, PipelineRegister - +from metamapper.common_passes import print_dag, GetSinks class DelayMatching(Transformer): def __init__(self, node_latencies): @@ -41,7 +41,6 @@ def generic_visit(self, node): self.aggregate_latencies[node] = max_latency + this_latency return node -#Verifies that a kernel is branch-delay matched class KernelDelay(Visitor): def __init__(self, node_latencies): self.node_latencies = node_latencies @@ -85,3 +84,114 @@ def visit_PipelineRegister(self, node): raise ValueError("Child of pipe register is constant") self.aggregate_latencies[node] = self.aggregate_latencies[child] + 1 +def topological_sort_helper(dag, node, stack, visited): + visited.add(node) + for ns in node.children(): + if ns not in visited: + topological_sort_helper(dag, ns, stack, visited) + stack.append(node) + +def topological_sort(dag): + visited = set() + stack = [] + for n in dag.roots(): + if n not in visited: + topological_sort_helper(dag, n, stack, visited) + return stack[::-1] + +def is_input_sel(node): + fields = [] + curr_node = node + + while True: + if curr_node.node_name == "Select": + fields.append(str(curr_node.field)) + else: + if curr_node.node_name == "Input": + return fields + else: + return None + assert len(curr_node.children()) == 1 + curr_node = curr_node.child + +def get_connected_pe_name(ret_list, source, node, sinks): + if len(sinks[node]) == 0: + return + elif node.node_name == "global.PE": + ret_list.append((node.iname, node._metadata_[node.children().index(source)][0])) + return + elif node.node_name == "PipelineRegister": + ret_list.append((node.iname, "reg")) + return + else: + for sink in sinks[node]: + get_connected_pe_name(ret_list, node, sink, sinks) + + +def branch_delay_match(dag, node_latencies, sinks): + + sorted_nodes = topological_sort(dag) + + added_regs = 0 + node_cycles = {} + input_latencies = {} + + for node in sorted_nodes: + cycles = set() + + if len(sinks[node]) == 0: + cycles = {0} + + for sink in sinks[node]: + if sink not in node_cycles: + c = 0 + else: + c = node_cycles[sink] + + if c != None: + c += node_latencies.get(node) + + cycles.add(c) + + if None in cycles: + cycles.remove(None) + + if len(cycles) > 1: + print(f"\t\tIncorrect node delay: {node} {cycles}") + + max_cycles = max(cycles) + for sink in sinks[node]: + new_child = node + pipeline_type = node.type + new_children = [child for child in sink.children()] + for idx, c in enumerate(new_children): + if c == new_child: + for _ in range(max_cycles - node_cycles[sink]): + print("\t\tbreak", node, sink) + new_child = PipelineRegister(new_child, type=pipeline_type) + added_regs += 1 + + new_children[idx] = new_child + sink.set_children(*new_children) + node_cycles[node] = max_cycles + elif len(cycles) == 1: + node_cycles[node] = max(cycles) + else: + node_cycles[node] = None + + sinks = GetSinks().doit(dag) + + fields = is_input_sel(node) + if fields is not None: + if len(cycles) > 0: + fields.reverse() + + latenciy_dict_key = "_".join(fields) + + connected_pes = [] + get_connected_pe_name(connected_pes, node, node, sinks) + input_latencies[latenciy_dict_key] = {"latency": node_cycles[node], "pe_port": connected_pes} + node_cycles[node] = None + + return input_latencies, added_regs + diff --git a/metamapper/instruction_selection/dag_rewrite.py b/metamapper/instruction_selection/dag_rewrite.py index ac7a984..df3cb25 100644 --- a/metamapper/instruction_selection/dag_rewrite.py +++ b/metamapper/instruction_selection/dag_rewrite.py @@ -14,12 +14,13 @@ def visit_Select(self, node): #Given a Dag, greedly apply the rewrite rule class GreedyReplace(Transformer): - def __init__(self, rr: RewriteRule): + def __init__(self, rr: RewriteRule, kernel_name_prefix=False): self.rr = rr #Match needs to match all output_selects up to but not including input_selects self.output_selects = rr.tile.output.children() self.input_selects = set(rr.tile.input.select(field) for field in rr.tile.input._selects) self.state_roots = rr.tile.sinks[1:] + self.kernel_name_prefix = kernel_name_prefix if len(self.output_selects) > 1 or self.state_roots != []: raise NotImplementedError("TODO") @@ -39,7 +40,7 @@ def match_node(self, tile_node, dag_node, cur_matches): return {tile_node.field: dag_node}, {tile_node: dag_node} # Verify node types are identical - if type(tile_node) != type(dag_node): + if type(tile_node).node_name != type(dag_node).node_name: return None matched_inputs = {} @@ -56,6 +57,8 @@ def match_node(self, tile_node, dag_node, cur_matches): def visit_Select(self, node): #visit all children first Transformer.generic_visit(self, node) + + matched = self.match_node(self.output_selects[0], node, {}) if matched is None: return None @@ -64,22 +67,24 @@ def visit_Select(self, node): #What this is doing is pointing the matched inputs of the dag to the body of the tile. #Then replacing the body of the tile to this node #TODO verify and call with the matched dag - rr_name = node.children()[0].iname + rr_name = str(self.rr.name).replace(".", "_") + if self.kernel_name_prefix: + rr_name = f"{node.child.iname}${rr_name}" replace_dag_copy = Clone().clone(self.rr.replace(None), iname_prefix=f"{rr_name}_{node.iname}_") ReplaceInputs(matched_inputs).run(replace_dag_copy) return replace_dag_copy.output.children()[0] class GreedyCovering: - def __init__(self, rrt: RewriteTable): + def __init__(self, rrt: RewriteTable, kernel_name_prefix=False): self.rrt = rrt + self.kernel_name_prefix = kernel_name_prefix def __call__(self, dag: Dag): #Make a unique copy dag = Clone().clone(dag) for rr in self.rrt.rules: #Will update dag in place - cnt = GreedyReplace(rr).replace(dag) - #print(f"RR {rr.name} used {cnt} times") + cnt = GreedyReplace(rr, self.kernel_name_prefix).replace(dag) return dag diff --git a/metamapper/irs/coreir/__init__.py b/metamapper/irs/coreir/__init__.py index db1049b..5aa0a48 100644 --- a/metamapper/irs/coreir/__init__.py +++ b/metamapper/irs/coreir/__init__.py @@ -18,6 +18,11 @@ def gen_CoreIRNodes(width): CoreIRNodes = Nodes("CoreIR") peak_ir = gen_peak_CoreIR(width) c = CoreIRContext() + cgralib = True + try: + c.load_library("cgralib") + except: + cgralib = False basic = ("mul", "add", "const", "and_", "or_", "neg") other = ("ashr", "eq", "neq", "lshr", "mux", "sub", "slt", "sle", "sgt", "sge", "ult", "ule", "ugt", "uge", "shl") @@ -45,7 +50,6 @@ def gen_CoreIRNodes(width): assert name_ == name assert name in CoreIRNodes.coreir_modules assert CoreIRNodes.name_from_coreir(cmod) == name - name = f"float_DW.fp_add" peak_fc = peak_ir.instructions[name] cmod = None @@ -139,9 +143,44 @@ def gen_CoreIRNodes(width): cmod = None name_ = load_from_peak(CoreIRNodes, peak_fc, cmod=cmod, name="commonlib.mult_middle", modparams=()) - CoreIRNodes.custom_nodes = ["coreir.neq", "commonlib.mult_middle", "float.max", "float.min", "float.div", "float_DW.fp_mul", "float_DW.fp_add", "float.sub", "fp_getmant", "fp_addiexp", "fp_subexp", "fp_cnvexp2f", "fp_getfint", "fp_getffrac", "fp_cnvint2f", "fp_gt", "fp_lt", "float.exp", "float.mux"] + name = f"commonlib.abs" + peak_fc = peak_ir.instructions[name] + cmod = None + name_ = load_from_peak(CoreIRNodes, peak_fc, cmod=cmod, name="commonlib.abs", modparams=()) + + if cgralib: + name = f"cgralib.Mem" + peak_fc = peak_ir.instructions[name] + cmod = c.get_namespace('cgralib').generators['Mem'](ctrl_width=16, has_chain_en=False, has_external_addrgen=False, has_flush=True, has_read_valid=False, has_reset=False, has_stencil_valid=True, has_valid=False, is_rom=True, num_inputs=2, num_outputs=2, use_prebuilt_mem=True, width=16) + name_ = load_from_peak(CoreIRNodes, peak_fc, cmod=cmod, stateful=True, name="cgralib.Mem", modparams=()) + name = f"cgralib.Pond" + peak_fc = peak_ir.instructions[name] + cmod = c.get_namespace('cgralib').generators['Pond'](num_inputs=2, num_outputs=2, width=16) + name_ = load_from_peak(CoreIRNodes, peak_fc, cmod=cmod, stateful=True, name="cgralib.Pond", modparams=()) + + CoreIRNodes.custom_nodes = ["coreir.neq", "commonlib.abs", "commonlib.absd", "commonlib.mult_middle", "float.max", "float.min", "float.div", "float_DW.fp_mul", "float_DW.fp_add", "float.sub", "fp_getmant", "fp_addiexp", "fp_subexp", "fp_cnvexp2f", "fp_getfint", "fp_getffrac", "fp_cnvint2f", "fp_gt", "fp_lt", "float.exp", "float.mux"] + + class Mem_amber(DagNode): + def __init__(self, clk_en, data_in_0, data_in_1, wen_in_0, wen_in_1, *, iname): + super().__init__(clk_en, data_in_0, data_in_1, wen_in_0, wen_in_1, iname=iname) + self.modparams=() + @property + def attributes(self): + return ("iname") + + #Hack to get correct port name + #def select(self, field, original=None): + # self._selects.add("data_out_0") + # return Select(self, field="rdata",type=BitVector[16]) + + nodes = CoreIRNodes + static_attributes = {} + node_name = "cgralib.Mem_amber" + num_children = 3 + type = Product.from_fields("Output",{"data_out_0":BitVector[16], "data_out_1":BitVector[16], "stencil_valid":BitVector[1]}) + class FPRom(DagNode): def __init__(self, raddr, ren, *, init, iname): super().__init__(raddr, ren, init=init, iname=iname) diff --git a/metamapper/irs/coreir/custom_ops_ir.py b/metamapper/irs/coreir/custom_ops_ir.py index 9f2ba86..56cf96b 100755 --- a/metamapper/irs/coreir/custom_ops_ir.py +++ b/metamapper/irs/coreir/custom_ops_ir.py @@ -2,17 +2,24 @@ from hwtypes import BitVector, Bit from hwtypes.adt import Product from peak import Peak, name_outputs, family_closure, Const +from peak.black_box import BlackBox from peak.family import AbstractFamily, MagmaFamily, SMTFamily +from peak.float import float_lib_gen, RoundingMode +from hwtypes import RoundingMode as RoundingMode_hw from ...node import Nodes, Constant, DagNode, Select -from hwtypes import SMTFPVector, FPVector, RoundingMode +from hwtypes import SMTFPVector, FPVector import magma def gen_custom_ops_peak_CoreIR(width): + + + float_lib = float_lib_gen(8, 7) + CoreIR = IR() DATAWIDTH = 16 def BFloat16_fc(family): if isinstance(family, MagmaFamily): - BFloat16 = magma.BFloat[16] + BFloat16 = magma.BFloat[8, 7, RoundingMode.RNE, False] BFloat16.reinterpret_from_bv = lambda bv: BFloat16(bv) BFloat16.reinterpret_as_bv = lambda f: magma.Bits[16](f) return BFloat16 @@ -20,45 +27,114 @@ def BFloat16_fc(family): FPV = SMTFPVector else: FPV = FPVector - BFloat16 = FPV[8, 7, RoundingMode.RNE, False] + BFloat16 = FPV[8, 7, RoundingMode_hw.RNE, False] return BFloat16 @family_closure - def fp_getffrac_fc(family: AbstractFamily): + def fp_add_fc(family: AbstractFamily): + FPAdd = float_lib.const_rm(RoundingMode.RNE).Add_fc(family) + + BitVector = family.BitVector Data = family.BitVector[16] - Data32 = family.Unsigned[32] - SInt = family.Signed - UInt = family.Unsigned[16] Bit = family.Bit + SInt = family.Signed + SData = SInt[16] + UInt = family.Unsigned + UData = UInt[16] + UData32 = UInt[32] - BFloat16 = BFloat16_fc(family) FPExpBV = family.BitVector[8] FPFracBV = family.BitVector[7] - BitVector = family.BitVector - def bv2float_DW(bv): - return BFloat16.reinterpret_from_bv(bv) + @family.assemble(locals(), globals()) + class fp_add(Peak, BlackBox): + def __init__(self): + self.Add: FPAdd = FPAdd() + + @name_outputs(out=Data) + def __call__(self, in0 : Data, in1 : Data) -> Data: + + return Data(self.Add(in0, in1)) + + # a_fpadd = reinterpret_from_bv(in0) + # b_fpadd = reinterpret_from_bv(in1) + # return Data((a_fpadd + b_fpadd)) + + return fp_add + + CoreIR.add_instruction("float_DW.fp_add", fp_add_fc) + + @family_closure + def fp_sub_fc(family: AbstractFamily): + Data = family.BitVector[16] + Data32 = family.Unsigned[32] + SInt = family.Signed[16] + UInt = family.Unsigned[16] + Bit = family.Bit + + + @family.assemble(locals(), globals()) + class fp_sub(Peak, BlackBox): + def __init__(self): + self.Add: FPAdd = FPAdd() + + @name_outputs(out=Data) + def __call__(self, in0 : Data, in1 : Data) -> Data: + + in1 = in1 ^ (2 ** (16 - 1)) + return Data(self.Add(in0, in1)) + + return fp_sub + + CoreIR.add_instruction("float.sub", fp_sub_fc) + + + @family_closure + def fp_mul_fc(family: AbstractFamily): - def float_DW2bv(bvf): - return BFloat16.reinterpret_as_bv(bvf) + FPMul = float_lib.const_rm(RoundingMode.RNE).Mul_fc(family) - def fp_get_exp(val : Data): - return val[7:15] + Data = family.BitVector[16] + Data32 = family.Unsigned[32] + SInt = family.Signed[16] + UInt = family.Unsigned[16] + Bit = family.Bit - def fp_get_frac(val : Data): - return val[:7] - def fp_is_zero(val : Data): - return (fp_get_exp(val) == FPExpBV(0)) & (fp_get_frac(val) == FPFracBV(0)) + @family.assemble(locals(), globals()) + class fp_mul(Peak, BlackBox): + def __init__(self): + self.Mul: FPMul = FPMul() - def fp_is_inf(val : Data): - return (fp_get_exp(val) == FPExpBV(-1)) & (fp_get_frac(val) == FPFracBV(0)) + @name_outputs(out=Data) + def __call__(self, in0 : Data, in1 : Data) -> Data: + return Data(self.Mul(in0, in1)) + # a_fpadd = reinterpret_from_bv(in0) + # b_fpadd = reinterpret_from_bv(in1) + # return Data((a_fpadd - b_fpadd)) + + return fp_mul + + CoreIR.add_instruction("float_DW.fp_mul", fp_mul_fc) - def fp_is_neg(val : Data): - return Bit(val[-1]) + + @family_closure + def fp_getffrac_fc(family: AbstractFamily): + BitVector = family.BitVector + BFloat = BFloat16_fc(family) + Data = family.BitVector[16] + Bit = family.Bit + SInt = family.Signed + SData = SInt[16] + UInt = family.Unsigned + UData = UInt[16] + UData32 = UInt[32] + + FPExpBV = family.BitVector[8] + FPFracBV = family.BitVector[7] @family.assemble(locals(), globals()) - class fp_getffrac(Peak): + class fp_getffrac(Peak, BlackBox): @name_outputs(out=Data) def __call__(self, in0 : Data, in1 : Data) -> Data: signa = BitVector[16]((in0 & 0x8000)) @@ -88,40 +164,21 @@ def __call__(self, in0 : Data, in1 : Data) -> Data: @family_closure def fp_getfint_fc(family: AbstractFamily): + BitVector = family.BitVector + BFloat = BFloat16_fc(family) Data = family.BitVector[16] - Data32 = family.Unsigned[32] - SInt = family.Signed - UInt = family.Unsigned[16] Bit = family.Bit + SInt = family.Signed + SData = SInt[16] + UInt = family.Unsigned + UData = UInt[16] + UData32 = UInt[32] - BFloat16 = BFloat16_fc(family) FPExpBV = family.BitVector[8] FPFracBV = family.BitVector[7] - BitVector = family.BitVector - - def bv2float_DW(bv): - return BFloat16.reinterpret_from_bv(bv) - - def float_DW2bv(bvf): - return BFloat16.reinterpret_as_bv(bvf) - - def fp_get_exp(val : Data): - return val[7:15] - - def fp_get_frac(val : Data): - return val[:7] - - def fp_is_zero(val : Data): - return (fp_get_exp(val) == FPExpBV(0)) & (fp_get_frac(val) == FPFracBV(0)) - - def fp_is_inf(val : Data): - return (fp_get_exp(val) == FPExpBV(-1)) & (fp_get_frac(val) == FPFracBV(0)) - - def fp_is_neg(val : Data): - return Bit(val[-1]) @family.assemble(locals(), globals()) - class fp_getfint(Peak): + class fp_getfint(Peak, BlackBox): @name_outputs(out=Data) def __call__(self, in0 : Data, in1 : Data) -> Data: signa = BitVector[16]((in0 & 0x8000)) @@ -149,40 +206,21 @@ def __call__(self, in0 : Data, in1 : Data) -> Data: @family_closure def fp_cnvint2f_fc(family: AbstractFamily): + BitVector = family.BitVector + BFloat = BFloat16_fc(family) Data = family.BitVector[16] - Data32 = family.Unsigned[32] - SInt = family.Signed - UInt = family.Unsigned[16] Bit = family.Bit + SInt = family.Signed + SData = SInt[16] + UInt = family.Unsigned + UData = UInt[16] + UData32 = UInt[32] - BFloat16 = BFloat16_fc(family) FPExpBV = family.BitVector[8] FPFracBV = family.BitVector[7] - BitVector = family.BitVector - - def bv2float_DW(bv): - return BFloat16.reinterpret_from_bv(bv) - - def float_DW2bv(bvf): - return BFloat16.reinterpret_as_bv(bvf) - - def fp_get_exp(val : Data): - return val[7:15] - - def fp_get_frac(val : Data): - return val[:7] - - def fp_is_zero(val : Data): - return (fp_get_exp(val) == FPExpBV(0)) & (fp_get_frac(val) == FPFracBV(0)) - - def fp_is_inf(val : Data): - return (fp_get_exp(val) == FPExpBV(-1)) & (fp_get_frac(val) == FPFracBV(0)) - - def fp_is_neg(val : Data): - return Bit(val[-1]) @family.assemble(locals(), globals()) - class fp_cnvint2f(Peak): + class fp_cnvint2f(Peak, BlackBox): @name_outputs(out=Data) def __call__(self, in0 : Data, in1 : Data) -> Data: @@ -251,40 +289,21 @@ def __call__(self, in0 : Data, in1 : Data) -> Data: @family_closure def fp_cnvexp2f_fc(family: AbstractFamily): + BitVector = family.BitVector + BFloat = BFloat16_fc(family) Data = family.BitVector[16] - Data32 = family.Unsigned[32] - SInt = family.Signed - UInt = family.Unsigned[16] Bit = family.Bit + SInt = family.Signed + SData = SInt[16] + UInt = family.Unsigned + UData = UInt[16] + UData32 = UInt[32] - BFloat16 = BFloat16_fc(family) FPExpBV = family.BitVector[8] FPFracBV = family.BitVector[7] - BitVector = family.BitVector - - def bv2float_DW(bv): - return BFloat16.reinterpret_from_bv(bv) - - def float_DW2bv(bvf): - return BFloat16.reinterpret_as_bv(bvf) - - def fp_get_exp(val : Data): - return val[7:15] - - def fp_get_frac(val : Data): - return val[:7] - - def fp_is_zero(val : Data): - return (fp_get_exp(val) == FPExpBV(0)) & (fp_get_frac(val) == FPFracBV(0)) - - def fp_is_inf(val : Data): - return (fp_get_exp(val) == FPExpBV(-1)) & (fp_get_frac(val) == FPFracBV(0)) - - def fp_is_neg(val : Data): - return Bit(val[-1]) @family.assemble(locals(), globals()) - class fp_cnvexp2f(Peak): + class fp_cnvexp2f(Peak, BlackBox): @name_outputs(out=Data) def __call__(self, in0 : Data, in1 : Data) -> Data: expa0 = BitVector[8](in0[7:15]) @@ -337,40 +356,21 @@ def __call__(self, in0 : Data, in1 : Data) -> Data: @family_closure def fp_subexp_fc(family: AbstractFamily): + BitVector = family.BitVector + BFloat = BFloat16_fc(family) Data = family.BitVector[16] - Data32 = family.Unsigned[32] - SInt = family.Signed - UInt = family.Unsigned[16] Bit = family.Bit + SInt = family.Signed + SData = SInt[16] + UInt = family.Unsigned + UData = UInt[16] + UData32 = UInt[32] - BFloat16 = BFloat16_fc(family) FPExpBV = family.BitVector[8] FPFracBV = family.BitVector[7] - BitVector = family.BitVector - - def bv2float_DW(bv): - return BFloat16.reinterpret_from_bv(bv) - - def float_DW2bv(bvf): - return BFloat16.reinterpret_as_bv(bvf) - - def fp_get_exp(val : Data): - return val[7:15] - - def fp_get_frac(val : Data): - return val[:7] - - def fp_is_zero(val : Data): - return (fp_get_exp(val) == FPExpBV(0)) & (fp_get_frac(val) == FPFracBV(0)) - - def fp_is_inf(val : Data): - return (fp_get_exp(val) == FPExpBV(-1)) & (fp_get_frac(val) == FPFracBV(0)) - - def fp_is_neg(val : Data): - return Bit(val[-1]) @family.assemble(locals(), globals()) - class fp_subexp(Peak): + class fp_subexp(Peak, BlackBox): @name_outputs(out=Data) def __call__(self, in0 : Data, in1 : Data) -> Data: signa = BitVector[16]((in0 & 0x8000)) @@ -390,40 +390,21 @@ def __call__(self, in0 : Data, in1 : Data) -> Data: @family_closure def fp_addiexp_fc(family: AbstractFamily): + BitVector = family.BitVector + BFloat = BFloat16_fc(family) Data = family.BitVector[16] - Data32 = family.Unsigned[32] - SInt = family.Signed - UInt = family.Unsigned[16] Bit = family.Bit + SInt = family.Signed + SData = SInt[16] + UInt = family.Unsigned + UData = UInt[16] + UData32 = UInt[32] - BFloat16 = BFloat16_fc(family) FPExpBV = family.BitVector[8] FPFracBV = family.BitVector[7] - BitVector = family.BitVector - - def bv2float_DW(bv): - return BFloat16.reinterpret_from_bv(bv) - - def float_DW2bv(bvf): - return BFloat16.reinterpret_as_bv(bvf) - - def fp_get_exp(val : Data): - return val[7:15] - - def fp_get_frac(val : Data): - return val[:7] - - def fp_is_zero(val : Data): - return (fp_get_exp(val) == FPExpBV(0)) & (fp_get_frac(val) == FPFracBV(0)) - - def fp_is_inf(val : Data): - return (fp_get_exp(val) == FPExpBV(-1)) & (fp_get_frac(val) == FPFracBV(0)) - - def fp_is_neg(val : Data): - return Bit(val[-1]) @family.assemble(locals(), globals()) - class fp_addiexp(Peak): + class fp_addiexp(Peak, BlackBox): @name_outputs(out=Data) def __call__(self, in0 : Data, in1 : Data) -> Data: @@ -448,14 +429,21 @@ def __call__(self, in0 : Data, in1 : Data) -> Data: @family_closure def fp_getmant_fc(family: AbstractFamily): + BitVector = family.BitVector + BFloat = BFloat16_fc(family) Data = family.BitVector[16] - Data32 = family.Unsigned[32] - SInt = family.Signed[16] - UInt = family.Unsigned[16] Bit = family.Bit + SInt = family.Signed + SData = SInt[16] + UInt = family.Unsigned + UData = UInt[16] + UData32 = UInt[32] + + FPExpBV = family.BitVector[8] + FPFracBV = family.BitVector[7] @family.assemble(locals(), globals()) - class fp_getmant(Peak): + class fp_getmant(Peak, BlackBox): @name_outputs(out=Data) def __call__(self, in0 : Data, in1 : Data) -> Data: return Data(in0 & 0x7F) @@ -466,14 +454,21 @@ def __call__(self, in0 : Data, in1 : Data) -> Data: @family_closure def fp_mux_fc(family: AbstractFamily): + BitVector = family.BitVector + BFloat = BFloat16_fc(family) Data = family.BitVector[16] - Data32 = family.Unsigned[32] - SInt = family.Signed[16] - UInt = family.Unsigned[16] Bit = family.Bit + SInt = family.Signed + SData = SInt[16] + UInt = family.Unsigned + UData = UInt[16] + UData32 = UInt[32] + + FPExpBV = family.BitVector[8] + FPFracBV = family.BitVector[7] @family.assemble(locals(), globals()) - class fp_mux(Peak): + class fp_mux(Peak, BlackBox): @name_outputs(out=Data) def __call__(self, in0 : Data, in1 : Data, sel:Bit) -> Data: return sel.ite(in0, in1) @@ -483,92 +478,57 @@ def __call__(self, in0 : Data, in1 : Data, sel:Bit) -> Data: @family_closure - def fp_add_fc(family: AbstractFamily): + def fp_max_fc(family: AbstractFamily): + BitVector = family.BitVector + BFloat = BFloat16_fc(family) Data = family.BitVector[16] - Data32 = family.Unsigned[32] - SInt = family.Signed[16] - UInt = family.Unsigned[16] Bit = family.Bit + SInt = family.Signed + SData = SInt[16] + UInt = family.Unsigned + UData = UInt[16] + UData32 = UInt[32] - BFloat16 = BFloat16_fc(family) FPExpBV = family.BitVector[8] FPFracBV = family.BitVector[7] - def bv2float_DW(bv): - return BFloat16.reinterpret_from_bv(bv) - - def float_DW2bv(bvf): - return BFloat16.reinterpret_as_bv(bvf) - - def fp_get_exp(val : Data): - return val[7:15] - - def fp_get_frac(val : Data): - return val[:7] - - def fp_is_zero(val : Data): - return (fp_get_exp(val) == FPExpBV(0)) & (fp_get_frac(val) == FPFracBV(0)) - - def fp_is_inf(val : Data): - return (fp_get_exp(val) == FPExpBV(-1)) & (fp_get_frac(val) == FPFracBV(0)) - - def fp_is_neg(val : Data): - return Bit(val[-1]) - @family.assemble(locals(), globals()) - class fp_add(Peak): + class fp_max(Peak, BlackBox): @name_outputs(out=Data) - def __call__(self, in0 : Data, in1 : Data) -> Data: + def __call__(self, in0 : Data, in1 : Data) -> Bit: - a_fpadd = bv2float_DW(in0) - b_fpadd = bv2float_DW(in1) - return Data(float_DW2bv(a_fpadd + b_fpadd)) + a_fpadd = BFloat(in0) + b_fpadd = BFloat(in1) + gt = Bit(a_fpadd > b_fpadd) + return Data(gt.ite(in0, in1)) - return fp_add + return fp_max - CoreIR.add_instruction("float_DW.fp_add", fp_add_fc) + CoreIR.add_instruction("fp_max", fp_max_fc) @family_closure def fp_gt_fc(family: AbstractFamily): + BitVector = family.BitVector + BFloat = BFloat16_fc(family) Data = family.BitVector[16] - Data32 = family.Unsigned[32] - SInt = family.Signed[16] - UInt = family.Unsigned[16] Bit = family.Bit + SInt = family.Signed + SData = SInt[16] + UInt = family.Unsigned + UData = UInt[16] + UData32 = UInt[32] - BFloat16 = BFloat16_fc(family) FPExpBV = family.BitVector[8] FPFracBV = family.BitVector[7] - def bv2float_DW(bv): - return BFloat16.reinterpret_from_bv(bv) - - def float_DW2bv(bvf): - return BFloat16.reinterpret_as_bv(bvf) - - def fp_get_exp(val : Data): - return val[7:15] - - def fp_get_frac(val : Data): - return val[:7] - - def fp_is_zero(val : Data): - return (fp_get_exp(val) == FPExpBV(0)) & (fp_get_frac(val) == FPFracBV(0)) - - def fp_is_inf(val : Data): - return (fp_get_exp(val) == FPExpBV(-1)) & (fp_get_frac(val) == FPFracBV(0)) - - def fp_is_neg(val : Data): - return Bit(val[-1]) - @family.assemble(locals(), globals()) - class fp_gt(Peak): + class fp_gt(Peak, BlackBox): @name_outputs(out=Data) def __call__(self, in0 : Data, in1 : Data) -> Bit: - a_fpadd = bv2float_DW(in0) - b_fpadd = bv2float_DW(in1) + a_fpadd = BFloat(in0) + b_fpadd = BFloat(in1) return Bit(a_fpadd > b_fpadd) return fp_gt @@ -578,44 +538,26 @@ def __call__(self, in0 : Data, in1 : Data) -> Bit: @family_closure def fp_lt_fc(family: AbstractFamily): + BitVector = family.BitVector + BFloat = BFloat16_fc(family) Data = family.BitVector[16] - Data32 = family.Unsigned[32] - SInt = family.Signed[16] - UInt = family.Unsigned[16] Bit = family.Bit + SInt = family.Signed + SData = SInt[16] + UInt = family.Unsigned + UData = UInt[16] + UData32 = UInt[32] - BFloat16 = BFloat16_fc(family) FPExpBV = family.BitVector[8] FPFracBV = family.BitVector[7] - def bv2float_DW(bv): - return BFloat16.reinterpret_from_bv(bv) - - def float_DW2bv(bvf): - return BFloat16.reinterpret_as_bv(bvf) - - def fp_get_exp(val : Data): - return val[7:15] - - def fp_get_frac(val : Data): - return val[:7] - - def fp_is_zero(val : Data): - return (fp_get_exp(val) == FPExpBV(0)) & (fp_get_frac(val) == FPFracBV(0)) - - def fp_is_inf(val : Data): - return (fp_get_exp(val) == FPExpBV(-1)) & (fp_get_frac(val) == FPFracBV(0)) - - def fp_is_neg(val : Data): - return Bit(val[-1]) - @family.assemble(locals(), globals()) - class fp_lt(Peak): + class fp_lt(Peak, BlackBox): @name_outputs(out=Data) def __call__(self, in0 : Data, in1 : Data) -> Bit: - a_fpadd = bv2float_DW(in0) - b_fpadd = bv2float_DW(in1) + a_fpadd = BFloat(in0) + b_fpadd = BFloat(in1) return Bit(a_fpadd < b_fpadd) return fp_lt @@ -625,43 +567,25 @@ def __call__(self, in0 : Data, in1 : Data) -> Bit: @family_closure def fp_exp_fc(family: AbstractFamily): + BitVector = family.BitVector + BFloat = BFloat16_fc(family) Data = family.BitVector[16] - Data32 = family.Unsigned[32] - SInt = family.Signed[16] - UInt = family.Unsigned[16] Bit = family.Bit + SInt = family.Signed + SData = SInt[16] + UInt = family.Unsigned + UData = UInt[16] + UData32 = UInt[32] - BFloat16 = BFloat16_fc(family) FPExpBV = family.BitVector[8] FPFracBV = family.BitVector[7] - def bv2float_DW(bv): - return BFloat16.reinterpret_from_bv(bv) - - def float_DW2bv(bvf): - return BFloat16.reinterpret_as_bv(bvf) - - def fp_get_exp(val : Data): - return val[7:15] - - def fp_get_frac(val : Data): - return val[:7] - - def fp_is_zero(val : Data): - return (fp_get_exp(val) == FPExpBV(0)) & (fp_get_frac(val) == FPFracBV(0)) - - def fp_is_inf(val : Data): - return (fp_get_exp(val) == FPExpBV(-1)) & (fp_get_frac(val) == FPFracBV(0)) - - def fp_is_neg(val : Data): - return Bit(val[-1]) - @family.assemble(locals(), globals()) - class fp_exp(Peak): + class fp_exp(Peak, BlackBox): @name_outputs(out=Data) def __call__(self, in0 : Data) -> Data: - a_fpadd = bv2float_DW(in0) + a_fpadd = BFloat(in0) return a_fpadd return fp_exp @@ -670,141 +594,30 @@ def __call__(self, in0 : Data) -> Data: - @family_closure - def fp_sub_fc(family: AbstractFamily): - Data = family.BitVector[16] - Data32 = family.Unsigned[32] - SInt = family.Signed[16] - UInt = family.Unsigned[16] - Bit = family.Bit - - BFloat16 = BFloat16_fc(family) - FPExpBV = family.BitVector[8] - FPFracBV = family.BitVector[7] - - def bv2float_DW(bv): - return BFloat16.reinterpret_from_bv(bv) - - def float_DW2bv(bvf): - return BFloat16.reinterpret_as_bv(bvf) - - def fp_get_exp(val : Data): - return val[7:15] - - def fp_get_frac(val : Data): - return val[:7] - - def fp_is_zero(val : Data): - return (fp_get_exp(val) == FPExpBV(0)) & (fp_get_frac(val) == FPFracBV(0)) - - def fp_is_inf(val : Data): - return (fp_get_exp(val) == FPExpBV(-1)) & (fp_get_frac(val) == FPFracBV(0)) - - def fp_is_neg(val : Data): - return Bit(val[-1]) - - @family.assemble(locals(), globals()) - class fp_sub(Peak): - @name_outputs(out=Data) - def __call__(self, in0 : Data, in1 : Data) -> Data: - - a_fpadd = bv2float_DW(in0) - b_fpadd = bv2float_DW(in1) - return Data(float_DW2bv(a_fpadd - b_fpadd)) - - return fp_sub - - CoreIR.add_instruction("float.sub", fp_sub_fc) - - - @family_closure - def fp_mul_fc(family: AbstractFamily): - Data = family.BitVector[16] - Data32 = family.Unsigned[32] - SInt = family.Signed[16] - UInt = family.Unsigned[16] - Bit = family.Bit - - BFloat16 = BFloat16_fc(family) - FPExpBV = family.BitVector[8] - FPFracBV = family.BitVector[7] - - def bv2float_DW(bv): - return BFloat16.reinterpret_from_bv(bv) - - def float_DW2bv(bvf): - return BFloat16.reinterpret_as_bv(bvf) - - def fp_get_exp(val : Data): - return val[7:15] - - def fp_get_frac(val : Data): - return val[:7] - - def fp_is_zero(val : Data): - return (fp_get_exp(val) == FPExpBV(0)) & (fp_get_frac(val) == FPFracBV(0)) - - def fp_is_inf(val : Data): - return (fp_get_exp(val) == FPExpBV(-1)) & (fp_get_frac(val) == FPFracBV(0)) - - def fp_is_neg(val : Data): - return Bit(val[-1]) - - @family.assemble(locals(), globals()) - class fp_mul(Peak): - @name_outputs(out=Data) - def __call__(self, in0 : Data, in1 : Data) -> Data: - - a_fpadd = bv2float_DW(in0) - b_fpadd = bv2float_DW(in1) - return Data(float_DW2bv(a_fpadd - b_fpadd)) - - return fp_mul - - CoreIR.add_instruction("float_DW.fp_mul", fp_mul_fc) - @family_closure def fp_div_fc(family: AbstractFamily): + BitVector = family.BitVector + BFloat = BFloat16_fc(family) Data = family.BitVector[16] - Data32 = family.Unsigned[32] - SInt = family.Signed[16] - UInt = family.Unsigned[16] Bit = family.Bit + SInt = family.Signed + SData = SInt[16] + UInt = family.Unsigned + UData = UInt[16] + UData32 = UInt[32] - BFloat16 = BFloat16_fc(family) FPExpBV = family.BitVector[8] FPFracBV = family.BitVector[7] - def bv2float_DW(bv): - return BFloat16.reinterpret_from_bv(bv) - - def float_DW2bv(bvf): - return BFloat16.reinterpret_as_bv(bvf) - - def fp_get_exp(val : Data): - return val[7:15] - - def fp_get_frac(val : Data): - return val[:7] - - def fp_is_zero(val : Data): - return (fp_get_exp(val) == FPExpBV(0)) & (fp_get_frac(val) == FPFracBV(0)) - - def fp_is_inf(val : Data): - return (fp_get_exp(val) == FPExpBV(-1)) & (fp_get_frac(val) == FPFracBV(0)) - - def fp_is_neg(val : Data): - return Bit(val[-1]) - @family.assemble(locals(), globals()) - class fp_div(Peak): + class fp_div(Peak, BlackBox): @name_outputs(out=Data) def __call__(self, in0 : Data, in1 : Data) -> Data: - a_fpadd = bv2float_DW(in0) - b_fpadd = bv2float_DW(in1) - return Data(float_DW2bv(a_fpadd - b_fpadd)) + a_fpadd = BFloat(in0) + b_fpadd = BFloat(in1) + return Data((a_fpadd - b_fpadd)) return fp_div @@ -812,45 +625,27 @@ def __call__(self, in0 : Data, in1 : Data) -> Data: @family_closure def fp_max_fc(family: AbstractFamily): + BitVector = family.BitVector + BFloat = BFloat16_fc(family) Data = family.BitVector[16] - Data32 = family.Unsigned[32] - SInt = family.Signed[16] - UInt = family.Unsigned[16] Bit = family.Bit + SInt = family.Signed + SData = SInt[16] + UInt = family.Unsigned + UData = UInt[16] + UData32 = UInt[32] - BFloat16 = BFloat16_fc(family) FPExpBV = family.BitVector[8] FPFracBV = family.BitVector[7] - def bv2float_DW(bv): - return BFloat16.reinterpret_from_bv(bv) - - def float_DW2bv(bvf): - return BFloat16.reinterpret_as_bv(bvf) - - def fp_get_exp(val : Data): - return val[7:15] - - def fp_get_frac(val : Data): - return val[:7] - - def fp_is_zero(val : Data): - return (fp_get_exp(val) == FPExpBV(0)) & (fp_get_frac(val) == FPFracBV(0)) - - def fp_is_inf(val : Data): - return (fp_get_exp(val) == FPExpBV(-1)) & (fp_get_frac(val) == FPFracBV(0)) - - def fp_is_neg(val : Data): - return Bit(val[-1]) - @family.assemble(locals(), globals()) - class fp_max(Peak): + class fp_max(Peak, BlackBox): @name_outputs(out=Data) def __call__(self, in0 : Data, in1 : Data) -> Data: - a_fpadd = bv2float_DW(in0) - b_fpadd = bv2float_DW(in1) - return Data(float_DW2bv(a_fpadd - b_fpadd)) + a_fpadd = BFloat(in0) + b_fpadd = BFloat(in1) + return Data((a_fpadd - b_fpadd)) return fp_max @@ -858,44 +653,26 @@ def __call__(self, in0 : Data, in1 : Data) -> Data: @family_closure def fp_min_fc(family: AbstractFamily): + BitVector = family.BitVector + BFloat = BFloat16_fc(family) Data = family.BitVector[16] - Data32 = family.Unsigned[32] - SInt = family.Signed[16] - UInt = family.Unsigned[16] Bit = family.Bit + SInt = family.Signed + SData = SInt[16] + UInt = family.Unsigned + UData = UInt[16] + UData32 = UInt[32] - BFloat16 = BFloat16_fc(family) FPExpBV = family.BitVector[8] FPFracBV = family.BitVector[7] - def bv2float_DW(bv): - return BFloat16.reinterpret_from_bv(bv) - - def float_DW2bv(bvf): - return BFloat16.reinterpret_as_bv(bvf) - - def fp_get_exp(val : Data): - return val[7:15] - - def fp_get_frac(val : Data): - return val[:7] - - def fp_is_zero(val : Data): - return (fp_get_exp(val) == FPExpBV(0)) & (fp_get_frac(val) == FPFracBV(0)) - - def fp_is_inf(val : Data): - return (fp_get_exp(val) == FPExpBV(-1)) & (fp_get_frac(val) == FPFracBV(0)) - - def fp_is_neg(val : Data): - return Bit(val[-1]) - @family.assemble(locals(), globals()) - class fp_min(Peak): + class fp_min(Peak, BlackBox): @name_outputs(out=Data) def __call__(self, in0 : Data, in1 : Data) -> Data: - a_fpadd = bv2float_DW(in0) - b_fpadd = bv2float_DW(in1) + a_fpadd = BFloat(in0) + b_fpadd = BFloat(in1) return Data(a_fpadd - b_fpadd) return fp_min @@ -906,51 +683,44 @@ def __call__(self, in0 : Data, in1 : Data) -> Data: @family_closure def fp_cmp_fc(family: AbstractFamily): + BitVector = family.BitVector + BFloat = BFloat16_fc(family) Data = family.BitVector[16] - Data32 = family.Unsigned[32] - SInt = family.Signed[16] - UInt = family.Unsigned[16] Bit = family.Bit + SInt = family.Signed + SData = SInt[16] + UInt = family.Unsigned + UData = UInt[16] + UData32 = UInt[32] - BFloat16 = BFloat16_fc(family) - FPExpBV = family.BitVector[8] - FPFracBV = family.BitVector[7] - - def bv2float_DW(bv): - return BFloat16.reinterpret_from_bv(bv) - - def float_DW2bv(bvf): - return BFloat16.reinterpret_as_bv(bvf) - - def fp_get_exp(val : Data): - return val[7:15] - - def fp_get_frac(val : Data): - return val[:7] - - def fp_is_zero(val : Data): - return (fp_get_exp(val) == FPExpBV(0)) & (fp_get_frac(val) == FPFracBV(0)) + float_lib = float_lib.const_rm(RoundingMode.RNE) - def fp_is_inf(val : Data): - return (fp_get_exp(val) == FPExpBV(-1)) & (fp_get_frac(val) == FPFracBV(0)) + is_inf = float_lib.Is_infinite_fc(family) + is_neg = float_lib.Is_negative_fc(family) + is_zero = float_lib.Is_zero_fc(family) - def fp_is_neg(val : Data): - return Bit(val[-1]) + FPExpBV = family.BitVector[8] + FPFracBV = family.BitVector[7] @family.assemble(locals(), globals()) - class fp_cmp(Peak): + class fp_cmp(Peak, BlackBox): + def __init__(self): + self.is_inf: is_inf = is_inf() + self.is_neg: is_neg = is_neg() + self.is_zero: is_zero = is_zero() + @name_outputs(out=Bit) def __call__(self, in0 : Data, in1 : Data) -> Bit: - a_fpadd = bv2float_DW(in0) - b_fpadd = bv2float_DW(in1) - a_inf = fp_is_inf(in0) - b_inf = fp_is_inf(in1) - a_neg = fp_is_neg(in0) - b_neg = fp_is_neg(in1) - - res = Data(float_DW2bv(a_fpadd - b_fpadd)) - Z = fp_is_zero(res) + a_fpadd = BFloat(in0) + b_fpadd = BFloat(in1) + a_inf = self.is_inf(in0) + b_inf = self.is_inf(in1) + a_neg = self.is_neg(in0) + b_neg = self.is_neg(in1) + + res = Data((a_fpadd - b_fpadd)) + Z = self.is_zero(res) if (a_inf & b_inf) & (a_neg == b_neg): Z = Bit(1) @@ -966,10 +736,9 @@ def mult_middle_fc(family: AbstractFamily): Data32 = family.BitVector[32] class mult_middle(Peak): @name_outputs(out=Data) - def __call__(self, in0: Data, in1: Data) -> Data: - mul = Data32(in0) * Data32(in1) - res = mul >> 8 - return Data(res[0:16]) + def __call__(self, in1: Data, in0: Data) -> Data: + mul = Data32(in0.sext(16)) * Data32(in1.sext(16)) + return Data(mul[8:24]) return mult_middle CoreIR.add_instruction("commonlib.mult_middle", mult_middle_fc) diff --git a/metamapper/irs/coreir/ir.py b/metamapper/irs/coreir/ir.py index b30356a..166b278 100644 --- a/metamapper/irs/coreir/ir.py +++ b/metamapper/irs/coreir/ir.py @@ -12,6 +12,30 @@ def gen_peak_CoreIR(width): DATAWIDTH = 16 CoreIR = gen_custom_ops_peak_CoreIR(DATAWIDTH) + @family_closure + def mem_fc(family: AbstractFamily): + Data = family.BitVector[width] + Bit = family.Bit + class mem(Peak): + @name_outputs(data_out_0=Data, data_out_1=Data, stencil_valid=Bit) + def __call__(self, rst_n: Bit, clk_en: Bit, data_in_0: Data, chain_data_in_0: Data, data_in_1: Data, chain_data_in_1: Data, wen_in_0: Bit, ren_in_0: Bit, addr_in_0: Data, flush: Bit) -> (Data, Data, Bit): + return Data(0), Data(0), Bit(0) + return mem + + CoreIR.add_instruction("cgralib.Mem", mem_fc) + + @family_closure + def pond_fc(family: AbstractFamily): + Data = family.BitVector[width] + Bit = family.Bit + class pond(Peak): + @name_outputs(data_out_pond_0=Data, data_out_pond_1=Data, valid_out_pond=Bit) + def __call__(self, rst_n: Bit, clk_en: Bit, data_in_pond_0: Data, data_in_pond_1: Data, flush: Bit) -> (Data, Data, Bit): + return Data(0), Data(0), Bit(0) + return pond + + CoreIR.add_instruction("cgralib.Pond", pond_fc) + @family_closure def rom_fc(family: AbstractFamily): Data = family.BitVector[width] diff --git a/metamapper/map_design_top.py b/metamapper/map_design_top.py new file mode 100644 index 0000000..cbd3e5c --- /dev/null +++ b/metamapper/map_design_top.py @@ -0,0 +1,127 @@ +import glob +import sys +import importlib +import os +import json +from pathlib import Path +import delegator + +from lassen import PE_fc as lassen_fc +from metamapper.irs.coreir import gen_CoreIRNodes +import metamapper.coreir_util as cutil +import metamapper.peak_util as putil +from metamapper.node import Nodes +from metamapper import CoreIRContext +from metamapper.coreir_mapper import Mapper +from metamapper.common_passes import print_dag, gen_dag_img, Constant2CoreIRConstant +from peak.mapper import read_serialized_bindings + + +class _ArchCycles: + def get(self, node): + kind = node.kind()[0] + if kind == "Rom" or kind == "FPRom" or kind == "PipelineRegister": + return 1 + elif kind == "global.PE": + if "PIPELINED" in os.environ and os.environ["PIPELINED"].isnumeric(): + pe_cycles = int(os.environ["PIPELINED"]) + else: + pe_cycles = 1 + return pe_cycles + return 0 + + +lassen_location = os.path.join(Path(__file__).parent.parent.parent.resolve(), "lassen") +lassen_header = os.path.join( + Path(__file__).parent.parent.resolve(), "libs/lassen_header.json" +) + + +def gen_rrules(pipelined=False): + + # c = CoreIRContext() + # cmod = putil.peak_to_coreir(lassen_fc) + # c.serialize_header(lassen_header, [cmod]) + # c.serialize_definitions(pe_def, [cmod]) + mapping_funcs = [] + rrules = [] + ops = [] + + if pipelined: + rrule_files = glob.glob( + f"{lassen_location}/lassen/rewrite_rules/*_pipelined.json" + ) + else: + rrule_files = glob.glob(f"{lassen_location}/lassen/rewrite_rules/*.json") + rrule_files = [ + rrule_file for rrule_file in rrule_files if "pipelined" not in rrule_file + ] + + custom_rule_names = { + "mult_middle": "commonlib.mult_middle", + "abs": "commonlib.abs", + "fp_exp": "float.exp", + "fp_div": "float.div", + "fp_mux": "float.mux", + "fp_mul": "float_DW.fp_mul", + "fp_add": "float_DW.fp_add", + "fp_sub": "float.sub", + "fp_max": "float.max", + } + + for idx, rrule in enumerate(rrule_files): + rule_name = Path(rrule).stem + if ("fp" in rule_name and "pipelined" in rule_name) or rule_name.split( + "_pipelined" + )[0] in custom_rule_names: + rule_name = rule_name.split("_pipelined")[0] + if rule_name in custom_rule_names: + ops.append(custom_rule_names[rule_name]) + else: + ops.append(rule_name) + peak_eq = importlib.import_module(f"lassen.rewrite_rules.{rule_name}") + ir_fc = getattr(peak_eq, rule_name + "_fc") + mapping_funcs.append(ir_fc) + + with open(rrule, "r") as json_file: + rewrite_rule_in = json.load(json_file) + + rewrite_rule = read_serialized_bindings(rewrite_rule_in, ir_fc, lassen_fc) + if False: + counter_example = rewrite_rule.verify() + assert counter_example == None, f"{rule_name} failed" + rrules.append(rewrite_rule) + + return rrules, ops + + +def map_design_top(app_name, nodes, dag): + pe_reg_instrs = {} + pe_reg_instrs["const"] = 0 + pe_reg_instrs["bypass"] = 2 + pe_reg_instrs["reg"] = 3 + + pe_port_to_reg = {} + pe_port_to_reg["data0"] = "rega" + pe_port_to_reg["data1"] = "regb" + pe_port_to_reg["data2"] = "regc" + + pe_reg_info = {} + pe_reg_info['instrs'] = pe_reg_instrs + pe_reg_info['port_to_reg'] = pe_port_to_reg + + if "PIPELINED" in os.environ and os.environ["PIPELINED"].isnumeric(): + pe_cycles = int(os.environ["PIPELINED"]) + else: + pe_cycles = 1 + + + CoreIRNodes = gen_CoreIRNodes(16) + + rrules, ops = gen_rrules(pipelined = pe_cycles != 0) + + mapper = Mapper(CoreIRNodes, nodes, lazy=False, ops=ops, rrules=rrules, kernel_name_prefix=True) + + mapped_dag = mapper.do_mapping(dag, kname=app_name, node_cycles=None, convert_unbound=False, prove_mapping=False, pe_reg_info=pe_reg_info) + + return mapped_dag diff --git a/metamapper/rewrite_table.py b/metamapper/rewrite_table.py index cea37b9..5ce0000 100644 --- a/metamapper/rewrite_table.py +++ b/metamapper/rewrite_table.py @@ -63,13 +63,12 @@ def add_peak_rule(self, rule: PeakRule, name=None): from_dag = peak_to_dag(self.from_, rule.ir_fc, name=name) from_bv = rule.ir_fc(fam().PyFamily()) from_node_name = self.from_.name_from_peak(rule.ir_fc) - # print("from_dag", name) - # print_dag(from_dag) # Create to_dag by Wrapping _to_dag within ibinding and obinding # Get input/output names from peak_cls to_fc = rule.arch_fc to_node_name = self.to.name_from_peak(to_fc, name) + to_node_t = self.to.dag_nodes[to_node_name] assert issubclass(to_node_t, DagNode) to_bv = to_fc(fam().PyFamily()) @@ -141,8 +140,8 @@ def sel_from(path, node: DagNode): RemoveSelects().run(to_dag) #print("After rmSelects") #print_dag(to_dag) - #print("to_dag") - #print_dag(to_dag) + # print("to_dag") + #Verify that the io matches #TODO verify outputs match @@ -179,6 +178,7 @@ def discover(self, from_name, to_name, path_constraints={}, rr_name=None, solver def sort_rules(self): + self.rules.sort(key=lambda x: x.name) rule_nodes = [] for rule in self.rules: dag = rule.tile @@ -187,3 +187,12 @@ def sort_rules(self): keydict = dict(zip(self.rules, rule_nodes)) self.rules.sort(key=keydict.get, reverse=True) + + mul_add_rules = [] + for idx,rule in enumerate(self.rules): + if "mac" in rule.name or "muladd" in rule.name: + mul_add_rules.append(idx) + + for idx in mul_add_rules: + self.rules.insert(0, self.rules.pop(idx)) + diff --git a/scripts/map_app.py b/scripts/map_app.py index 3b98b7b..2cd5deb 100755 --- a/scripts/map_app.py +++ b/scripts/map_app.py @@ -16,17 +16,22 @@ from metamapper.common_passes import print_dag, gen_dag_img, Constant2CoreIRConstant from peak.mapper import read_serialized_bindings + class _ArchCycles: def get(self, node): kind = node.kind()[0] - if kind == "Rom" or kind == "FPRom": + if kind == "Rom" or kind == "FPRom" or kind == "PipelineRegister": return 1 elif kind == "global.PE": return pe_cycles return 0 + lassen_location = os.path.join(Path(__file__).parent.parent.parent.resolve(), "lassen") -lassen_header = os.path.join(Path(__file__).parent.parent.resolve(), "libs/lassen_header.json") +lassen_header = os.path.join( + Path(__file__).parent.parent.resolve(), "libs/lassen_header.json" +) + def gen_rrules(pipelined=False): @@ -39,16 +44,32 @@ def gen_rrules(pipelined=False): ops = [] if pipelined: - rrule_files = glob.glob(f'{lassen_location}/lassen/rewrite_rules/*_pipelined.json') + rrule_files = glob.glob( + f"{lassen_location}/lassen/rewrite_rules/*_pipelined.json" + ) else: - rrule_files = glob.glob(f'{lassen_location}/lassen/rewrite_rules/*.json') - rrule_files = [rrule_file for rrule_file in rrule_files if "pipelined" not in rrule_file] - - custom_rule_names = {"mult_middle": "commonlib.mult_middle", "fp_exp": "float.exp", "fp_div": "float.div", "fp_mux": "float.mux", "fp_mul":"float_DW.fp_mul", "fp_add":"float_DW.fp_add", "fp_sub":"float.sub"} + rrule_files = glob.glob(f"{lassen_location}/lassen/rewrite_rules/*.json") + rrule_files = [ + rrule_file for rrule_file in rrule_files if "pipelined" not in rrule_file + ] + + custom_rule_names = { + "mult_middle": "commonlib.mult_middle", + "abs": "commonlib.abs", + "fp_exp": "float.exp", + "fp_max": "float.max", + "fp_div": "float.div", + "fp_mux": "float.mux", + "fp_mul": "float_DW.fp_mul", + "fp_add": "float_DW.fp_add", + "fp_sub": "float.sub", + } for idx, rrule in enumerate(rrule_files): rule_name = Path(rrule).stem - if ("fp" in rule_name and "pipelined" in rule_name) or rule_name.split("_pipelined")[0] in custom_rule_names: + if ("fp" in rule_name and "pipelined" in rule_name) or rule_name.split( + "_pipelined" + )[0] in custom_rule_names: rule_name = rule_name.split("_pipelined")[0] if rule_name in custom_rule_names: ops.append(custom_rule_names[rule_name]) @@ -69,6 +90,7 @@ def gen_rrules(pipelined=False): return rrules, ops + pe_reg_instrs = {} pe_reg_instrs["const"] = 0 pe_reg_instrs["bypass"] = 2 @@ -80,16 +102,16 @@ def gen_rrules(pipelined=False): pe_port_to_reg["data2"] = "regc" pe_reg_info = {} -pe_reg_info['instrs'] = pe_reg_instrs -pe_reg_info['port_to_reg'] = pe_port_to_reg +pe_reg_info["instrs"] = pe_reg_instrs +pe_reg_info["port_to_reg"] = pe_port_to_reg file_name = str(sys.argv[1]) -if len(sys.argv) > 2: - pe_cycles = int(sys.argv[2]) +if "PIPELINED" in os.environ and os.environ["PIPELINED"].isnumeric(): + pe_cycles = int(os.environ["PIPELINED"]) else: - pe_cycles = 0 + pe_cycles = 1 -rrules, ops = gen_rrules(pipelined = pe_cycles != 0) +rrules, ops = gen_rrules(pipelined=pe_cycles != 0) verilog = False app = os.path.basename(file_name).split(".json")[0] output_dir = os.path.dirname(file_name) @@ -97,22 +119,23 @@ def gen_rrules(pipelined=False): c = CoreIRContext(reset=True) cutil.load_libs(["commonlib", "float_DW"]) CoreIRNodes = gen_CoreIRNodes(16) -cutil.load_from_json(file_name) #libraries=["lakelib"]) +cutil.load_from_json(file_name) # libraries=["lakelib"]) kernels = dict(c.global_namespace.modules) arch_fc = lassen_fc ArchNodes = Nodes("Arch") -putil.load_and_link_peak( - ArchNodes, - lassen_header, - {"global.PE": arch_fc} -) +putil.load_and_link_peak(ArchNodes, lassen_header, {"global.PE": arch_fc}) mr = "memory.fprom2" -ArchNodes.add(mr, CoreIRNodes.peak_nodes[mr], CoreIRNodes.coreir_modules[mr], CoreIRNodes.dag_nodes[mr]) +ArchNodes.add( + mr, + CoreIRNodes.peak_nodes[mr], + CoreIRNodes.coreir_modules[mr], + CoreIRNodes.dag_nodes[mr], +) -mapper = Mapper(CoreIRNodes, ArchNodes, lazy=False, ops = ops, rrules=rrules) +mapper = Mapper(CoreIRNodes, ArchNodes, lazy=False, ops=ops, rrules=rrules) c.run_passes(["rungenerators", "deletedeadinstances"]) mods = [] @@ -122,16 +145,29 @@ def gen_rrules(pipelined=False): dag = cutil.coreir_to_dag(CoreIRNodes, kmod, archnodes=ArchNodes) Constant2CoreIRConstant(CoreIRNodes).run(dag) - mapped_dag = mapper.do_mapping(dag, kname=kname, node_cycles=_ArchCycles(), convert_unbound=False, prove_mapping=False, pe_reg_info=pe_reg_info) - mod = cutil.dag_to_coreir(ArchNodes, mapped_dag, f"{kname}_mapped", convert_unbounds=verilog) + mapped_dag = mapper.do_mapping( + dag, + kname=kname, + node_cycles=_ArchCycles(), + convert_unbound=False, + prove_mapping=True, + pe_reg_info=pe_reg_info, + ) + + mod = cutil.dag_to_coreir( + ArchNodes, mapped_dag, f"{kname}_mapped", convert_unbounds=verilog + ) mods.append(mod) -print(f"Total num PEs used: {mapper.num_pes}") +print('\n\033[92m' + "All compute kernels passed formal checks" + '\033[0m') +print(f"Total num PEs used: {mapper.num_pes}\n") +print(f"Total num regs inserted: {mapper.num_regs}") + output_file = f"{output_dir}/{app}_mapped.json" print(f"saving to {output_file}") c.serialize_definitions(output_file, mods) with open(f'{output_dir}/{app}_kernel_latencies.json', 'w') as outfile: - json.dump(mapper.kernel_cycles, outfile) + json.dump(mapper.kernel_cycles, outfile, indent=4) diff --git a/scripts/map_dse.py b/scripts/map_dse.py index 1d4b966..196b328 100755 --- a/scripts/map_dse.py +++ b/scripts/map_dse.py @@ -19,6 +19,7 @@ from peak_gen.arch import read_arch from peak_gen.peak_wrapper import wrapped_peak_class + class _ArchCycles: def get(self, node): kind = node.kind()[0] @@ -28,9 +29,15 @@ def get(self, node): return pe_cycles return 0 -pe_location = os.path.join(Path(__file__).parent.parent.parent.resolve(), "DSEGraphAnalysis/outputs") + +pe_location = os.path.join( + Path(__file__).parent.parent.parent.resolve(), "DSEGraphAnalysis/outputs" +) pe_header = os.path.join(Path(__file__).parent.parent.resolve(), "libs/pe_header.json") -metamapper_location = os.path.join(Path(__file__).parent.parent.resolve(), "examples/peak_gen") +metamapper_location = os.path.join( + Path(__file__).parent.parent.resolve(), "examples/peak_gen" +) + def gen_rrules(): @@ -42,37 +49,43 @@ def gen_rrules(): mapping_funcs = [] rrules = [] - num_rrules = len(glob.glob(f'{pe_location}/rewrite_rules/*.json')) + num_rrules = len(glob.glob(f"{pe_location}/rewrite_rules/*.json")) - if not os.path.exists(f'{metamapper_location}'): - os.makedirs(f'{metamapper_location}') + if not os.path.exists(f"{metamapper_location}"): + os.makedirs(f"{metamapper_location}") for ind in range(num_rrules): with open(f"{pe_location}/peak_eqs/peak_eq_" + str(ind) + ".py", "r") as file: - with open(f"{metamapper_location}/peak_eq_" + str(ind) + ".py", "w") as outfile: + with open( + f"{metamapper_location}/peak_eq_" + str(ind) + ".py", "w" + ) as outfile: for line in file: - outfile.write(line.replace('mapping_function', 'mapping_function_'+str(ind))) + outfile.write( + line.replace("mapping_function", "mapping_function_" + str(ind)) + ) peak_eq = importlib.import_module("examples.peak_gen.peak_eq_" + str(ind)) ir_fc = getattr(peak_eq, "mapping_function_" + str(ind) + "_fc") mapping_funcs.append(ir_fc) - with open(f"{pe_location}/rewrite_rules/rewrite_rule_" + str(ind) + ".json", "r") as json_file: + with open( + f"{pe_location}/rewrite_rules/rewrite_rule_" + str(ind) + ".json", "r" + ) as json_file: rewrite_rule_in = json.load(json_file) rewrite_rule = read_serialized_bindings(rewrite_rule_in, ir_fc, PE_fc) counter_example = rewrite_rule.verify() - rrules.append(rewrite_rule) return PE_fc, rrules + file_name = str(sys.argv[1]) -if len(sys.argv) > 2: - pe_cycles = int(sys.argv[2]) +if "PIPELINED" in os.environ and os.environ["PIPELINED"].isnumeric(): + pe_cycles = int(os.environ["PIPELINED"]) else: - pe_cycles = 0 + pe_cycles = 1 arch_fc, rrules = gen_rrules() verilog = False @@ -83,15 +96,11 @@ def gen_rrules(): c = CoreIRContext(reset=True) cutil.load_libs(["commonlib", "float_DW"]) CoreIRNodes = gen_CoreIRNodes(16) -cutil.load_from_json(file_name) #libraries=["lakelib"]) +cutil.load_from_json(file_name) # libraries=["lakelib"]) kernels = dict(c.global_namespace.modules) ArchNodes = Nodes("Arch") -putil.load_and_link_peak( - ArchNodes, - pe_header, - {"global.PE": arch_fc} -) +putil.load_and_link_peak(ArchNodes, pe_header, {"global.PE": arch_fc}) mapper = Mapper(CoreIRNodes, ArchNodes, lazy=True, rrules=rrules) @@ -103,8 +112,16 @@ def gen_rrules(): dag = cutil.coreir_to_dag(CoreIRNodes, kmod, archnodes=ArchNodes) Constant2CoreIRConstant(CoreIRNodes).run(dag) - mapped_dag = mapper.do_mapping(dag, kname=kname, node_cycles=_ArchCycles(), convert_unbound=False, prove_mapping=False) - mod = cutil.dag_to_coreir(ArchNodes, mapped_dag, f"{kname}_mapped", convert_unbounds=verilog) + mapped_dag = mapper.do_mapping( + dag, + kname=kname, + node_cycles=_ArchCycles(), + convert_unbound=False, + prove_mapping=False, + ) + mod = cutil.dag_to_coreir( + ArchNodes, mapped_dag, f"{kname}_mapped", convert_unbounds=verilog + ) mods.append(mod) print(f"Num PEs used: {mapper.num_pes}") @@ -113,5 +130,5 @@ def gen_rrules(): c.serialize_definitions(output_file, mods) -with open(f'{output_dir}/{app}_kernel_latencies.json', 'w') as outfile: +with open(f"{output_dir}/{app}_kernel_latencies.json", "w") as outfile: json.dump(mapper.kernel_cycles, outfile) diff --git a/tests/test_kernel_mapping.py b/tests/test_kernel_mapping.py index f70bbcb..44e4a68 100644 --- a/tests/test_kernel_mapping.py +++ b/tests/test_kernel_mapping.py @@ -40,9 +40,10 @@ def get(self, node): def gen_rrules(pipelined=False): - c = CoreIRContext(reset=True) + c = CoreIRContext() cmod = putil.peak_to_coreir(lassen_fc) c.serialize_header(lassen_header, [cmod]) + # c.serialize_definitions(pe_def, [cmod]) mapping_funcs = [] rrules = [] ops = [] @@ -53,7 +54,7 @@ def gen_rrules(pipelined=False): rrule_files = glob.glob(f'{lassen_location}/lassen/rewrite_rules/*.json') rrule_files = [rrule_file for rrule_file in rrule_files if "pipelined" not in rrule_file] - custom_rule_names = {"mult_middle":"commonlib.mult_middle","fp_exp": "float.exp", "fp_div": "float.div", "fp_mux": "float.mux", "fp_mul":"float_DW.fp_mul", "fp_add":"float_DW.fp_add", "fp_sub":"float.sub"} + custom_rule_names = {"mult_middle": "commonlib.mult_middle", "fp_exp": "float.exp", "fp_div": "float.div", "fp_mux": "float.mux", "fp_mul":"float_DW.fp_mul", "fp_add":"float_DW.fp_add", "fp_sub":"float.sub"} for idx, rrule in enumerate(rrule_files): rule_name = Path(rrule).stem @@ -71,7 +72,9 @@ def gen_rrules(pipelined=False): rewrite_rule_in = json.load(json_file) rewrite_rule = read_serialized_bindings(rewrite_rule_in, ir_fc, lassen_fc) - + if False: + counter_example = rewrite_rule.verify() + assert counter_example == None, f"{rule_name} failed" rrules.append(rewrite_rule) return rrules, ops @@ -103,13 +106,12 @@ def test_kernel_mapping(pipelined, app): c = CoreIRContext(reset=True) cutil.load_libs(["commonlib", "float_DW"]) CoreIRNodes = gen_CoreIRNodes(16) - - cutil.load_from_json(app_file) - c.run_passes(["rungenerators", "deletedeadinstances"]) + cutil.load_from_json(app_file) #libraries=["lakelib"]) kernels = dict(c.global_namespace.modules) arch_fc = lassen_fc ArchNodes = Nodes("Arch") + putil.load_and_link_peak( ArchNodes, lassen_header, diff --git a/tests/test_mem_header.py b/tests/test_mem_header.py deleted file mode 100644 index 42547bc..0000000 --- a/tests/test_mem_header.py +++ /dev/null @@ -1,14 +0,0 @@ -from metamapper.lake_mem import gen_MEM_fc -from peak import family -from metamapper import peak_util as putil -from metamapper import CoreIRContext - - -def test_mem_header(): - MEM_fc = gen_MEM_fc() - MEM_py = MEM_fc(family.PyFamily()) - MEM = MEM_fc(family.MagmaFamily()) - cmod = putil.magma_to_coreir(MEM) - c = CoreIRContext() - c.serialize_header("libs/mem_header.json", [cmod]) -