From 1483b7f883916407d7d6b958ce69cf7dddfa4dd2 Mon Sep 17 00:00:00 2001 From: lekcyjna123 <34948061+lekcyjna123@users.noreply.github.com> Date: Tue, 14 May 2024 12:37:34 +0200 Subject: [PATCH] Port ContentAddressableMemory from https://github.com/kuznia-rdzeni/coreblocks/pull/395 (https://github.com/kuznia-rdzeni/coreblocks/pull/573) --- test/test_transactron_lib_storage.py | 135 +++++++++++++++++ test/utils/test_amaranth_ext.py | 92 ++++++++++++ transactron/lib/storage.py | 105 ++++++++++++- transactron/testing/__init__.py | 1 + transactron/testing/infrastructure.py | 58 +++++-- transactron/testing/input_generation.py | 97 ++++++++++++ .../utils/amaranth_ext/elaboratables.py | 142 +++++++++++++++++- transactron/utils/transactron_helpers.py | 2 +- 8 files changed, 611 insertions(+), 21 deletions(-) create mode 100644 test/test_transactron_lib_storage.py create mode 100644 test/utils/test_amaranth_ext.py create mode 100644 transactron/testing/input_generation.py diff --git a/test/test_transactron_lib_storage.py b/test/test_transactron_lib_storage.py new file mode 100644 index 0000000..1f14922 --- /dev/null +++ b/test/test_transactron_lib_storage.py @@ -0,0 +1,135 @@ +from datetime import timedelta +from hypothesis import given, settings, Phase +from transactron.testing import * +from transactron.lib.storage import ContentAddressableMemory + + +class TestContentAddressableMemory(TestCaseWithSimulator): + addr_width = 4 + content_width = 5 + test_number = 30 + nop_number = 3 + addr_layout = data_layout(addr_width) + content_layout = data_layout(content_width) + + def setUp(self): + self.entries_count = 8 + + self.circ = SimpleTestCircuit( + ContentAddressableMemory(self.addr_layout, self.content_layout, self.entries_count) + ) + + self.memory = {} + + def generic_process( + self, + method, + input_lst, + behaviour_check=None, + state_change=None, + input_verification=None, + settle_count=0, + name="", + ): + def f(): + while input_lst: + # wait till all processes will end the previous cycle + yield from self.multi_settle(4) + elem = input_lst.pop() + if isinstance(elem, OpNOP): + yield + continue + if input_verification is not None and not input_verification(elem): + yield + continue + response = yield from method.call(**elem) + yield from self.multi_settle(settle_count) + if behaviour_check is not None: + # Here accesses to circuit are allowed + ret = behaviour_check(elem, response) + if isinstance(ret, Generator): + yield from ret + if state_change is not None: + # It is standard python function by purpose to don't allow accessing circuit + state_change(elem, response) + yield + + return f + + def push_process(self, in_push): + def verify_in(elem): + return not (frozenset(elem["addr"].items()) in self.memory) + + def modify_state(elem, response): + self.memory[frozenset(elem["addr"].items())] = elem["data"] + + return self.generic_process( + self.circ.push, + in_push, + state_change=modify_state, + input_verification=verify_in, + settle_count=3, + name="push", + ) + + def read_process(self, in_read): + def check(elem, response): + addr = elem["addr"] + frozen_addr = frozenset(addr.items()) + if frozen_addr in self.memory: + assert response["not_found"] == 0 + assert response["data"] == self.memory[frozen_addr] + else: + assert response["not_found"] == 1 + + return self.generic_process(self.circ.read, in_read, behaviour_check=check, settle_count=0, name="read") + + def remove_process(self, in_remove): + def modify_state(elem, response): + if frozenset(elem["addr"].items()) in self.memory: + del self.memory[frozenset(elem["addr"].items())] + + return self.generic_process(self.circ.remove, in_remove, state_change=modify_state, settle_count=2, name="remv") + + def write_process(self, in_write): + def verify_in(elem): + ret = frozenset(elem["addr"].items()) in self.memory + return ret + + def check(elem, response): + assert response["not_found"] == int(frozenset(elem["addr"].items()) not in self.memory) + + def modify_state(elem, response): + if frozenset(elem["addr"].items()) in self.memory: + self.memory[frozenset(elem["addr"].items())] = elem["data"] + + return self.generic_process( + self.circ.write, + in_write, + behaviour_check=check, + state_change=modify_state, + input_verification=None, + settle_count=1, + name="writ", + ) + + @settings( + max_examples=10, + phases=(Phase.explicit, Phase.reuse, Phase.generate, Phase.shrink), + derandomize=True, + deadline=timedelta(milliseconds=500), + ) + @given( + generate_process_input(test_number, nop_number, [("addr", addr_layout), ("data", content_layout)]), + generate_process_input(test_number, nop_number, [("addr", addr_layout), ("data", content_layout)]), + generate_process_input(test_number, nop_number, [("addr", addr_layout)]), + generate_process_input(test_number, nop_number, [("addr", addr_layout)]), + ) + def test_random(self, in_push, in_write, in_read, in_remove): + with self.reinitialize_fixtures(): + self.setUp() + with self.run_simulation(self.circ, max_cycles=500) as sim: + sim.add_sync_process(self.push_process(in_push)) + sim.add_sync_process(self.read_process(in_read)) + sim.add_sync_process(self.write_process(in_write)) + sim.add_sync_process(self.remove_process(in_remove)) diff --git a/test/utils/test_amaranth_ext.py b/test/utils/test_amaranth_ext.py new file mode 100644 index 0000000..7b7bd46 --- /dev/null +++ b/test/utils/test_amaranth_ext.py @@ -0,0 +1,92 @@ +from transactron.testing import * +import random +from transactron.utils.amaranth_ext import MultiPriorityEncoder + + +class TestMultiPriorityEncoder(TestCaseWithSimulator): + def get_expected(self, input): + places = [] + for i in range(self.input_width): + if input % 2: + places.append(i) + input //= 2 + places += [None] * self.output_count + return places + + def process(self): + for _ in range(self.test_number): + input = random.randrange(2**self.input_width) + yield self.circ.input.eq(input) + yield Settle() + expected_output = self.get_expected(input) + for ex, real, valid in zip(expected_output, self.circ.outputs, self.circ.valids): + if ex is None: + assert (yield valid) == 0 + else: + assert (yield valid) == 1 + assert (yield real) == ex + yield Delay(1e-7) + + @pytest.mark.parametrize("input_width", [1, 5, 16, 23, 24]) + @pytest.mark.parametrize("output_count", [1, 3, 4]) + def test_random(self, input_width, output_count): + random.seed(input_width + output_count) + self.test_number = 50 + self.input_width = input_width + self.output_count = output_count + self.circ = MultiPriorityEncoder(self.input_width, self.output_count) + + with self.run_simulation(self.circ) as sim: + sim.add_process(self.process) + + @pytest.mark.parametrize("name", ["prio_encoder", None]) + def test_static_create_simple(self, name): + random.seed(14) + self.test_number = 50 + self.input_width = 7 + self.output_count = 1 + + class DUT(Elaboratable): + def __init__(self, input_width, output_count, name): + self.input = Signal(input_width) + self.output_count = output_count + self.input_width = input_width + self.name = name + + def elaborate(self, platform): + m = Module() + out, val = MultiPriorityEncoder.create_simple(m, self.input_width, self.input, name=self.name) + # Save as a list to use common interface in testing + self.outputs = [out] + self.valids = [val] + return m + + self.circ = DUT(self.input_width, self.output_count, name) + + with self.run_simulation(self.circ) as sim: + sim.add_process(self.process) + + @pytest.mark.parametrize("name", ["prio_encoder", None]) + def test_static_create(self, name): + random.seed(14) + self.test_number = 50 + self.input_width = 7 + self.output_count = 2 + + class DUT(Elaboratable): + def __init__(self, input_width, output_count, name): + self.input = Signal(input_width) + self.output_count = output_count + self.input_width = input_width + self.name = name + + def elaborate(self, platform): + m = Module() + out = MultiPriorityEncoder.create(m, self.input_width, self.input, self.output_count, name=self.name) + self.outputs, self.valids = list(zip(*out)) + return m + + self.circ = DUT(self.input_width, self.output_count, name) + + with self.run_simulation(self.circ) as sim: + sim.add_process(self.process) diff --git a/transactron/lib/storage.py b/transactron/lib/storage.py index 3bbf076..835406b 100644 --- a/transactron/lib/storage.py +++ b/transactron/lib/storage.py @@ -3,12 +3,12 @@ from transactron.utils.transactron_helpers import from_method_layout, make_layout from ..core import * -from ..utils import SrcLoc, get_src_loc +from ..utils import SrcLoc, get_src_loc, MultiPriorityEncoder from typing import Optional -from transactron.utils import assign, AssignType, LayoutList +from transactron.utils import assign, AssignType, LayoutList, MethodLayout from .reqres import ArgumentsToResultsZipper -__all__ = ["MemoryBank", "AsyncMemoryBank"] +__all__ = ["MemoryBank", "ContentAddressableMemory", "AsyncMemoryBank"] class MemoryBank(Elaboratable): @@ -37,7 +37,7 @@ def __init__( elem_count: int, granularity: Optional[int] = None, safe_writes: bool = True, - src_loc: int | SrcLoc = 0 + src_loc: int | SrcLoc = 0, ): """ Parameters @@ -138,6 +138,103 @@ def _(arg): return m +class ContentAddressableMemory(Elaboratable): + """Content addresable memory + + This module implements a content-addressable memory (in short CAM) with Transactron interface. + CAM is a type of memory where instead of predefined indexes there are used values fed in runtime + as keys (similar as in python dictionary). To insert new entry a pair `(key, value)` has to be + provided. Such pair takes an free slot which depends on internal implementation. To read value + a `key` has to be provided. It is compared with every valid key stored in CAM. If there is a hit, + a value is read. There can be many instances of the same key in CAM. In such case it is undefined + which value will be read. + + + .. warning:: + Pushing the value with index already present in CAM is an undefined behaviour. + + Attributes + ---------- + read : Method + Nondestructive read + write : Method + If index present - do update + remove : Method + Remove + push : Method + Inserts new data. + """ + + def __init__(self, address_layout: MethodLayout, data_layout: MethodLayout, entries_number: int): + """ + Parameters + ---------- + address_layout : LayoutLike + The layout of the address records. + data_layout : LayoutLike + The layout of the data. + entries_number : int + The number of slots to create in memory. + """ + self.address_layout = from_method_layout(address_layout) + self.data_layout = from_method_layout(data_layout) + self.entries_number = entries_number + + self.read = Method(i=[("addr", self.address_layout)], o=[("data", self.data_layout), ("not_found", 1)]) + self.remove = Method(i=[("addr", self.address_layout)]) + self.push = Method(i=[("addr", self.address_layout), ("data", self.data_layout)]) + self.write = Method(i=[("addr", self.address_layout), ("data", self.data_layout)], o=[("not_found", 1)]) + + def elaborate(self, platform) -> TModule: + m = TModule() + + address_array = Array( + [Signal(self.address_layout, name=f"address_array_{i}") for i in range(self.entries_number)] + ) + data_array = Array([Signal(self.data_layout, name=f"data_array_{i}") for i in range(self.entries_number)]) + valids = Signal(self.entries_number, name="valids") + + m.submodules.encoder_read = encoder_read = MultiPriorityEncoder(self.entries_number, 1) + m.submodules.encoder_write = encoder_write = MultiPriorityEncoder(self.entries_number, 1) + m.submodules.encoder_push = encoder_push = MultiPriorityEncoder(self.entries_number, 1) + m.submodules.encoder_remove = encoder_remove = MultiPriorityEncoder(self.entries_number, 1) + m.d.top_comb += encoder_push.input.eq(~valids) + + @def_method(m, self.push, ready=~valids.all()) + def _(addr, data): + id = Signal(range(self.entries_number), name="id_push") + m.d.top_comb += id.eq(encoder_push.outputs[0]) + m.d.sync += address_array[id].eq(addr) + m.d.sync += data_array[id].eq(data) + m.d.sync += valids.bit_select(id, 1).eq(1) + + @def_method(m, self.write) + def _(addr, data): + write_mask = Signal(self.entries_number, name="write_mask") + m.d.top_comb += write_mask.eq(Cat([addr == stored_addr for stored_addr in address_array]) & valids) + m.d.top_comb += encoder_write.input.eq(write_mask) + with m.If(write_mask.any()): + m.d.sync += data_array[encoder_write.outputs[0]].eq(data) + return {"not_found": ~write_mask.any()} + + @def_method(m, self.read) + def _(addr): + read_mask = Signal(self.entries_number, name="read_mask") + m.d.top_comb += read_mask.eq(Cat([addr == stored_addr for stored_addr in address_array]) & valids) + m.d.top_comb += encoder_read.input.eq(read_mask) + return {"data": data_array[encoder_read.outputs[0]], "not_found": ~read_mask.any()} + + @def_method(m, self.remove) + def _(addr): + rm_mask = Signal(self.entries_number, name="rm_mask") + m.d.top_comb += rm_mask.eq(Cat([addr == stored_addr for stored_addr in address_array]) & valids) + m.d.top_comb += encoder_remove.input.eq(rm_mask) + with m.If(rm_mask.any()): + m.d.sync += valids.bit_select(encoder_remove.outputs[0], 1).eq(0) + + return m + + class AsyncMemoryBank(Elaboratable): """AsyncMemoryBank module. diff --git a/transactron/testing/__init__.py b/transactron/testing/__init__.py index bc5d38f..aa21522 100644 --- a/transactron/testing/__init__.py +++ b/transactron/testing/__init__.py @@ -1,3 +1,4 @@ +from .input_generation import * # noqa: F401 from .functions import * # noqa: F401 from .infrastructure import * # noqa: F401 from .sugar import * # noqa: F401 diff --git a/transactron/testing/infrastructure.py b/transactron/testing/infrastructure.py index e2e768d..f902658 100644 --- a/transactron/testing/infrastructure.py +++ b/transactron/testing/infrastructure.py @@ -214,8 +214,8 @@ def register_logging_handler(self): ch.setFormatter(formatter) root_logger.handlers += [ch] - @pytest.fixture(autouse=True) - def configure_dependency_context(self, request): + @contextmanager + def configure_dependency_context(self): self.dependency_manager = DependencyManager() with DependencyContext(self.dependency_manager): yield @@ -235,20 +235,14 @@ def add_all_mocks(self, sim: PysimSimulator, frame_locals: dict) -> None: self.add_class_mocks(sim) self.add_local_mocks(sim, frame_locals) - @pytest.fixture(autouse=True) - def configure_traces(self, request): + def configure_traces(self): traces_file = None if "__TRANSACTRON_DUMP_TRACES" in os.environ: - traces_file = ".".join(request.node.nodeid.split("/")) + traces_file = self._transactron_current_output_file_name self._transactron_infrastructure_traces_file = traces_file - @pytest.fixture(autouse=True) - def fixture_sim_processes_to_add(self): - # By default return empty lists, it will be updated by other fixtures based on needs - self._transactron_sim_processes_to_add: list[Callable[[], Optional[Callable]]] = [] - - @pytest.fixture(autouse=True) - def configure_profiles(self, request, fixture_sim_processes_to_add, configure_dependency_context): + @contextmanager + def configure_profiles(self): profile = None if "__TRANSACTRON_PROFILE" in os.environ: @@ -268,12 +262,11 @@ def f(): if profile is not None: profile_dir = "test/__profiles__" - profile_file = ".".join(request.node.nodeid.split("/")) + profile_file = self._transactron_current_output_file_name os.makedirs(profile_dir, exist_ok=True) profile.encode(f"{profile_dir}/{profile_file}.json") - @pytest.fixture(autouse=True) - def configure_logging(self, fixture_sim_processes_to_add, register_logging_handler): + def configure_logging(self): def on_error(): assert False, "Simulation finished due to an error" @@ -281,6 +274,37 @@ def on_error(): log_filter = os.environ["__TRANSACTRON_LOG_FILTER"] self._transactron_sim_processes_to_add.append(lambda: make_logging_process(log_level, log_filter, on_error)) + @contextmanager + def reinitialize_fixtures(self): + # File name to be used in the current test run (either standard or hypothesis iteration) + # for standard tests it will always have the suffix "_0". For hypothesis tests, it will be suffixed + # with the current hypothesis iteration number, so that each hypothesis run is saved to a + # the different file. + self._transactron_current_output_file_name = ( + self._transactron_base_output_file_name + "_" + str(self._transactron_hypothesis_iter_counter) + ) + self._transactron_sim_processes_to_add: list[Callable[[], Optional[Callable]]] = [] + with self.configure_dependency_context(): + self.configure_traces() + with self.configure_profiles(): + self.configure_logging() + yield + self._transactron_hypothesis_iter_counter += 1 + + @pytest.fixture(autouse=True) + def fixture_initialize_testing_env(self, request): + # Hypothesis creates a single instance of a test class, which is later reused multiple times. + # This means that pytest fixtures are only run once. We can take advantage of this behaviour and + # initialise hypothesis related variables. + + # The counter for distinguishing between successive hypothesis iterations, it is incremented + # by `reinitialize_fixtures` which should be started at the beginning of each hypothesis run + self._transactron_hypothesis_iter_counter = 0 + # Base name which will be used later to create file names for particular outputs + self._transactron_base_output_file_name = ".".join(request.node.nodeid.split("/")) + with self.reinitialize_fixtures(): + yield + @contextmanager def run_simulation(self, module: HasElaborate, max_cycles: float = 10e4, add_transaction_module=True): clk_period = 1e-6 @@ -323,3 +347,7 @@ def random_wait_geom(self, prob: float = 0.5): """ while random.random() > prob: yield + + def multi_settle(self, settle_count: int = 1): + for _ in range(settle_count): + yield Settle() diff --git a/transactron/testing/input_generation.py b/transactron/testing/input_generation.py new file mode 100644 index 0000000..909da7a --- /dev/null +++ b/transactron/testing/input_generation.py @@ -0,0 +1,97 @@ +from amaranth import * +from amaranth.lib.data import StructLayout +from typing import TypeVar +import hypothesis.strategies as st +from hypothesis.strategies import composite, DrawFn, integers, SearchStrategy +from transactron.utils import MethodLayout, RecordIntDict + + +class OpNOP: + def __repr__(self): + return "OpNOP()" + + +T = TypeVar("T") + + +@composite +def generate_shrinkable_list(draw: DrawFn, length: int, generator: SearchStrategy[T]) -> list[T]: + """ + Trick based on https://github.com/HypothesisWorks/hypothesis/blob/ + 6867da71beae0e4ed004b54b92ef7c74d0722815/hypothesis-python/src/hypothesis/stateful.py#L143 + """ + hp_data = draw(st.data()) + lst = [] + if length == 0: + return lst + i = 0 + force_val = None + while True: + b = hp_data.conjecture_data.draw_boolean(p=2**-16, forced=force_val) + if b: + break + lst.append(draw(generator)) + i += 1 + if i == length: + force_val = True + return lst + + +@composite +def generate_based_on_layout(draw: DrawFn, layout: MethodLayout) -> RecordIntDict: + if isinstance(layout, StructLayout): + raise NotImplementedError("StructLayout is not supported in automatic value generation.") + d = {} + for name, sublayout in layout: + if isinstance(sublayout, list): + elem = draw(generate_based_on_layout(sublayout)) + elif isinstance(sublayout, int): + elem = draw(integers(min_value=0, max_value=sublayout)) + elif isinstance(sublayout, range): + elem = draw(integers(min_value=sublayout.start, max_value=sublayout.stop - 1)) + elif isinstance(sublayout, Shape): + if sublayout.signed: + min_value = -(2 ** (sublayout.width - 1)) + max_value = 2 ** (sublayout.width - 1) - 1 + else: + min_value = 0 + max_value = 2**sublayout.width + elem = draw(integers(min_value=min_value, max_value=max_value)) + else: + # Currently type[Enum] and ShapeCastable + raise NotImplementedError("Passed LayoutList with syntax yet unsuported in automatic value generation.") + d[name] = elem + return d + + +def insert_nops(draw: DrawFn, max_nops: int, lst: list): + nops_nr = draw(integers(min_value=0, max_value=max_nops)) + for i in range(nops_nr): + lst.append(OpNOP()) + return lst + + +@composite +def generate_nops_in_list(draw: DrawFn, max_nops: int, generate_list: SearchStrategy[list[T]]) -> list[T | OpNOP]: + lst = draw(generate_list) + out_lst = [] + out_lst = insert_nops(draw, max_nops, out_lst) + for i in lst: + out_lst.append(i) + out_lst = insert_nops(draw, max_nops, out_lst) + return out_lst + + +@composite +def generate_method_input(draw: DrawFn, args: list[tuple[str, MethodLayout]]) -> dict[str, RecordIntDict]: + out = [] + for name, layout in args: + out.append((name, draw(generate_based_on_layout(layout)))) + return dict(out) + + +@composite +def generate_process_input( + draw: DrawFn, elem_count: int, max_nops: int, layouts: list[tuple[str, MethodLayout]] +) -> list[dict[str, RecordIntDict] | OpNOP]: + return draw(generate_nops_in_list(max_nops, generate_shrinkable_list(elem_count, generate_method_input(layouts)))) diff --git a/transactron/utils/amaranth_ext/elaboratables.py b/transactron/utils/amaranth_ext/elaboratables.py index b0ddbae..6048bc7 100644 --- a/transactron/utils/amaranth_ext/elaboratables.py +++ b/transactron/utils/amaranth_ext/elaboratables.py @@ -3,7 +3,7 @@ from typing import Literal, Optional, overload from collections.abc import Iterable from amaranth import * -from transactron.utils._typing import HasElaborate, ModuleLike +from transactron.utils._typing import HasElaborate, ModuleLike, ValueLike __all__ = [ "OneHotSwitchDynamic", @@ -11,6 +11,7 @@ "ModuleConnector", "Scheduler", "RoundRobin", + "MultiPriorityEncoder", ] @@ -237,3 +238,142 @@ def elaborate(self, platform): m.d.sync += self.valid.eq(self.requests.any()) return m + + +class MultiPriorityEncoder(Elaboratable): + """Priority encoder with more outputs + + This is an extension of the `PriorityEncoder` from amaranth that supports + more than one output from an input signal. In other words + it decodes multi-hot encoded signal into lists of signals in binary + format, each with the index of a different high bit in the input. + + Attributes + ---------- + input_width : int + Width of the input signal + outputs_count : int + Number of outputs to generate at once. + input : Signal, in + Signal with 1 on `i`-th bit if `i` can be selected by encoder + outputs : list[Signal], out + Signals with selected indicies, sorted in ascending order, + if the number of ready signals is less than `outputs_count` + then valid signals are at the beginning of the list. + valids : list[Signal], out + One bit for each output signal, indicating whether the output is valid or not. + """ + + def __init__(self, input_width: int, outputs_count: int): + self.input_width = input_width + self.outputs_count = outputs_count + + self.input = Signal(self.input_width) + self.outputs = [Signal(range(self.input_width), name=f"output_{i}") for i in range(self.outputs_count)] + self.valids = [Signal(name=f"valid_{i}") for i in range(self.outputs_count)] + + @staticmethod + def create( + m: Module, input_width: int, input: ValueLike, outputs_count: int = 1, name: Optional[str] = None + ) -> list[tuple[Signal, Signal]]: + """Syntax sugar for creating MultiPriorityEncoder + + This static method allows to use MultiPriorityEncoder in a more functional + way. Instead of creating the instance manually, connecting all the signals and + adding a submodule, you can call this function to do it automatically. + + This function is equivalent to: + + .. highlight:: python + .. code-block:: python + + m.submodules += prio_encoder = PriorityEncoder(cnt) + m.d.top_comb += prio_encoder.input.eq(one_hot_singal) + idx = prio_encoder.outputs + valid = prio.encoder.valids + + Parameters + ---------- + m: Module + Module to add the MultiPriorityEncoder to. + input_width : int + Width of the one hot signal. + input : ValueLike + The one hot signal to decode. + outputs_count : int + Number of different decoder outputs to generate at once. Default: 1. + name : Optional[str] + Name to use when adding MultiPriorityEncoder to submodules. + If None, it will be added as an anonymous submodule. The given name + can not be used in a submodule that has already been added. Default: None. + + Returns + ------- + return : list[tuple[Signal, Signal]] + Returns a list with len equal to outputs_count. Each tuple contains + a pair of decoded index on the first position and a valid signal + on the second position. + """ + prio_encoder = MultiPriorityEncoder(input_width, outputs_count) + if name is None: + m.submodules += prio_encoder + else: + try: + getattr(m.submodules, name) + raise ValueError(f"Name: {name} is already in use, so MultiPriorityEncoder can not be added with it.") + except AttributeError: + setattr(m.submodules, name, prio_encoder) + m.d.comb += prio_encoder.input.eq(input) + return list(zip(prio_encoder.outputs, prio_encoder.valids)) + + @staticmethod + def create_simple( + m: Module, input_width: int, input: ValueLike, name: Optional[str] = None + ) -> tuple[Signal, Signal]: + """Syntax sugar for creating MultiPriorityEncoder + + This is the same as `create` function, but with `outputs_count` hardcoded to 1. + """ + lst = MultiPriorityEncoder.create(m, input_width, input, outputs_count=1, name=name) + return lst[0] + + def build_tree(self, m: Module, in_sig: Signal, start_idx: int): + assert len(in_sig) > 0 + level_outputs = [ + Signal(range(self.input_width), name=f"_lvl_out_idx{start_idx}_{i}") for i in range(self.outputs_count) + ] + level_valids = [Signal(name=f"_lvl_val_idx{start_idx}_{i}") for i in range(self.outputs_count)] + if len(in_sig) == 1: + with m.If(in_sig): + m.d.comb += level_outputs[0].eq(start_idx) + m.d.comb += level_valids[0].eq(1) + else: + middle = len(in_sig) // 2 + r_in = Signal(middle, name=f"_r_in_idx{start_idx}") + l_in = Signal(len(in_sig) - middle, name=f"_l_in_idx{start_idx}") + m.d.comb += r_in.eq(in_sig[0:middle]) + m.d.comb += l_in.eq(in_sig[middle:]) + r_out, r_val = self.build_tree(m, r_in, start_idx) + l_out, l_val = self.build_tree(m, l_in, start_idx + middle) + + with m.Switch(Cat(r_val)): + for i in range(self.outputs_count + 1): + with m.Case((1 << i) - 1): + for j in range(i): + m.d.comb += level_outputs[j].eq(r_out[j]) + m.d.comb += level_valids[j].eq(r_val[j]) + for j in range(i, self.outputs_count): + m.d.comb += level_outputs[j].eq(l_out[j - i]) + m.d.comb += level_valids[j].eq(l_val[j - i]) + return level_outputs, level_valids + + def elaborate(self, platform): + m = Module() + + level_outputs, level_valids = self.build_tree(m, self.input, 0) + + for k in range(self.outputs_count): + m.d.comb += self.outputs[k].eq(level_outputs[k]) + m.d.comb += self.valids[k].eq(level_valids[k]) + + return m diff --git a/transactron/utils/transactron_helpers.py b/transactron/utils/transactron_helpers.py index 048a2bb..0bb55f3 100644 --- a/transactron/utils/transactron_helpers.py +++ b/transactron/utils/transactron_helpers.py @@ -133,7 +133,7 @@ def from_layout_field(shape: ShapeLike | LayoutList) -> ShapeLike: return shape -def make_layout(*fields: LayoutListField): +def make_layout(*fields: LayoutListField) -> StructLayout: return from_method_layout(fields)