diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index bee80912d..1fc7fac1d 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -38,7 +38,7 @@ jobs: - name: Synthesize run: | . venv/bin/activate - PYTHONHASHSEED=0 ./scripts/synthesize.py --verbose --config ${{ matrix.config }} + PYTHONHASHSEED=0 ./scripts/synthesize.py --verbose --strip-debug --config ${{ matrix.config }} - name: Print synthesis information run: cat ./build/top.tim @@ -63,7 +63,7 @@ jobs: build-perf-benchmarks: name: Build performance benchmarks runs-on: ubuntu-latest - container: ghcr.io/kuznia-rdzeni/riscv-toolchain:2023.11.19_v + container: ghcr.io/kuznia-rdzeni/riscv-toolchain:2024.03.12 steps: - name: Checkout uses: actions/checkout@v3 @@ -82,7 +82,7 @@ jobs: run-perf-benchmarks: name: Run performance benchmarks runs-on: ubuntu-latest - timeout-minutes: 60 + timeout-minutes: 30 container: ghcr.io/kuznia-rdzeni/verilator:v5.008-2023.11.19_v needs: build-perf-benchmarks steps: diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index f450a431b..06ceb129d 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -30,7 +30,7 @@ jobs: python3 -m venv venv . 
venv/bin/activate python3 -m pip install --upgrade pip - python3 -m pip install -r requirements-dev.txt + python3 -m pip install -r requirements.txt - name: Generate Verilog run: | @@ -40,13 +40,15 @@ jobs: - uses: actions/upload-artifact@v3 with: name: "verilog-full-core" - path: core.v + path: | + core.v + core.v.json build-riscof-tests: name: Build regression tests (riscv-arch-test) runs-on: ubuntu-latest - container: ghcr.io/kuznia-rdzeni/riscv-toolchain:2023.11.19_v + container: ghcr.io/kuznia-rdzeni/riscv-toolchain:2024.03.12 timeout-minutes: 10 env: PYENV_ROOT: "/root/.pyenv" @@ -55,35 +57,77 @@ jobs: defaults: run: working-directory: test/external/riscof/ + steps: - name: Checkout uses: actions/checkout@v3 - - name: Setup PATH + - name: Get submodules HEAD hash + working-directory: . + run: | + # ownership workaround + git config --global --add safe.directory /__w/coreblocks/coreblocks + # paths in command are relative! + git submodule > .gitmodules-hash + + - name: Cache compiled and reference riscv-arch-test + id: cache-riscv-arch-test + uses: actions/cache@v3 + env: + cache-name: cache-riscv-arch-test + with: + path: | + test/external/riscof/riscv-arch-test/**/*.elf + test/external/riscof/riscof_work/**/*.signature + test/external/riscof/**/*Makefile* + + key: ${{ env.cache-name }}-${{ runner.os }}-${{ hashFiles( + '**/test/external/riscof/coreblocks/**', + '**/test/external/riscof/spike_simple/**', + '**/test/external/riscof/config.ini', + '**/.gitmodules-hash', + '**/docker/riscv-toolchain.Dockerfile', + '**/.github/workflows/main.yml' + ) }} + lookup-only: true + + - if: ${{ steps.cache-riscv-arch-test.outputs.cache-hit != 'true' }} + name: Checkout with submodules + uses: actions/checkout@v3 + with: + submodules: recursive + + - if: ${{ steps.cache-riscv-arch-test.outputs.cache-hit != 'true' }} + name: Setup PATH run: echo "/.pyenv/bin" >> $GITHUB_PATH - - name: Setup pyenv python + - if: ${{ steps.cache-riscv-arch-test.outputs.cache-hit != 'true' }} 
+ name: Setup pyenv python run: | eval "$(pyenv init --path)" pyenv global 3.6.15 . /venv3.6/bin/activate - - name: Setup arch test suite + - if: ${{ steps.cache-riscv-arch-test.outputs.cache-hit != 'true' }} + name: Setup arch test suite run: | . /venv3.6/bin/activate - riscof --verbose info arch-test --clone riscof testlist --config=config.ini --suite=riscv-arch-test/riscv-test-suite/ --env=riscv-arch-test/riscv-test-suite/env - - name: Build and run tests on reference and generate Makefiles + - if: ${{ steps.cache-riscv-arch-test.outputs.cache-hit != 'true' }} + name: Build and run tests on reference and generate Makefiles run: | . /venv3.6/bin/activate riscof run --config=config.ini --suite=riscv-arch-test/riscv-test-suite/ --env=riscv-arch-test/riscv-test-suite/env - - name: Build tests for Coreblocks + - if: ${{ steps.cache-riscv-arch-test.outputs.cache-hit != 'true' }} + name: Build tests for Coreblocks run: | MAKEFILE_PATH=riscof_work/Makefile.build-DUT-coreblocks ../../../ci/riscof_run_makefile.sh - - uses: actions/upload-artifact@v3 + - if: ${{ steps.cache-riscv-arch-test.outputs.cache-hit != 'true' }} + name: Upload compiled and reference tests artifact + uses: actions/upload-artifact@v3 with: name: "riscof-tests" path: | @@ -96,7 +140,7 @@ jobs: runs-on: ubuntu-latest container: ghcr.io/kuznia-rdzeni/verilator:v5.008-2023.11.19_v needs: [ build-riscof-tests, build-core ] - timeout-minutes: 20 + timeout-minutes: 30 steps: - name: Checkout uses: actions/checkout@v3 @@ -114,14 +158,35 @@ jobs: python3 -m pip install -r requirements-dev.txt - uses: actions/download-artifact@v3 + name: Download full verilog core with: name: "verilog-full-core" path: . 
- - uses: actions/download-artifact@v3 + - name: Get submodules HEAD hash + run: | + git config --global --add safe.directory /__w/coreblocks/coreblocks + git submodule > .gitmodules-hash + + - uses: actions/cache@v3 + name: Download tests from cache + env: + cache-name: cache-riscv-arch-test with: - name: "riscof-tests" - path: test/external/riscof/ + path: | + test/external/riscof/riscv-arch-test/**/*.elf + test/external/riscof/riscof_work/**/*.signature + test/external/riscof/**/*Makefile* + + key: ${{ env.cache-name }}-${{ runner.os }}-${{ hashFiles( + '**/test/external/riscof/coreblocks/**', + '**/test/external/riscof/spike_simple/**', + '**/test/external/riscof/config.ini', + '**/.gitmodules-hash', + '**/docker/riscv-toolchain.Dockerfile', + '**/.github/workflows/main.yml' + ) }} + fail-on-cache-miss: true - name: Run tests on Coreblocks run: | @@ -135,14 +200,16 @@ jobs: build-regression-tests: name: Build regression tests (riscv-tests) runs-on: ubuntu-latest - container: ghcr.io/kuznia-rdzeni/riscv-toolchain:2023.11.19_v - outputs: - cache_hit: ${{ steps.cache-regression.outputs.cache-hit }} + container: ghcr.io/kuznia-rdzeni/riscv-toolchain:2024.03.12 + timeout-minutes: 10 steps: - name: Checkout uses: actions/checkout@v3 - with: - submodules: recursive + + - name: Get submodules HEAD hash + run: | + git config --global --add safe.directory /__w/coreblocks/coreblocks + git submodule > .gitmodules-hash - name: Cache regression-tests id: cache-regression @@ -151,15 +218,20 @@ jobs: cache-name: cache-regression-tests with: path: test/external/riscv-tests/test-* - key: ${{ env.cache-name }}-${{ runner.os }}-${{ hashFiles( - '**/test/external/riscv-tests/environment/**', + '**/test/external/riscv-tests/environment/custom/**', '**/test/external/riscv-tests/Makefile', - '**/.git/modules/test/external/riscv-tests/riscv-tests/HEAD', - '**/docker/riscv-toolchain.Dockerfile' + '**/.gitmodules-hash', + '**/docker/riscv-toolchain.Dockerfile', + 
'**/.github/workflows/main.yml' ) }} - restore-keys: | - ${{ env.cache-name }}-${{ runner.os }}- + lookup-only: true + + - if: ${{ steps.cache-regression.outputs.cache-hit != 'true' }} + name: Checkout with submodules + uses: actions/checkout@v3 + with: + submodules: recursive - if: ${{ steps.cache-regression.outputs.cache-hit != 'true' }} run: cd test/external/riscv-tests && make @@ -179,8 +251,6 @@ jobs: steps: - name: Checkout uses: actions/checkout@v3 - with: - submodules: recursive - name: Set up Python uses: actions/setup-python@v4 @@ -195,21 +265,29 @@ jobs: python3 -m pip install -r requirements-dev.txt - uses: actions/download-artifact@v3 + name: Download full verilog core with: name: "verilog-full-core" path: . + - name: Get submodules HEAD hash + run: | + git config --global --add safe.directory /__w/coreblocks/coreblocks + git submodule > .gitmodules-hash + - uses: actions/cache@v3 + name: Download tests from cache env: cache-name: cache-regression-tests with: path: test/external/riscv-tests/test-* key: ${{ env.cache-name }}-${{ runner.os }}-${{ hashFiles( - '**/test/external/riscv-tests/environment/**', - '**/test/external/riscv-tests/Makefile', - '**/.git/modules/test/external/riscv-tests/riscv-tests/HEAD', - '**/docker/riscv-toolchain.Dockerfile' - ) }} + '**/test/external/riscv-tests/environment/custom/**', + '**/test/external/riscv-tests/Makefile', + '**/.gitmodules-hash', + '**/docker/riscv-toolchain.Dockerfile', + '**/.github/workflows/main.yml' + ) }} fail-on-cache-miss: true - name: Run tests @@ -217,13 +295,16 @@ jobs: . venv/bin/activate scripts/run_tests.py -a regression - - name: Check for test failure - run: ./scripts/check_test_results.py + - name: Check regression with pysim + run: | + . 
venv/bin/activate + ./scripts/run_tests.py -c 1 -a -b pysim regression + unit-test: name: Run unit tests runs-on: ubuntu-latest - timeout-minutes: 10 + timeout-minutes: 15 steps: - name: Checkout uses: actions/checkout@v3 @@ -244,11 +325,14 @@ jobs: sudo apt-get install -y binutils-riscv64-unknown-elf - name: Run tests - run: ./scripts/run_tests.py --verbose + run: ./scripts/run_tests.py -v - name: Check traces and profiles run: ./scripts/run_tests.py -t -p -c 1 TestCore + - name: Check listing tests + run: ./scripts/run_tests.py -l + lint: name: Check code formatting and typing runs-on: ubuntu-latest diff --git a/.gitignore b/.gitignore index a27638d1d..c40fe28de 100644 --- a/.gitignore +++ b/.gitignore @@ -15,6 +15,9 @@ venv.bak/ # Verilog files *.v +# Verilog generation debug files +*.v.json + # Waveform dumps *.vcd *.gtkw @@ -22,6 +25,9 @@ venv.bak/ # Tests outputs test/__traces__ test/__profiles__/*.json +pytestdebug.log +_coreblocks_regression.lock +_coreblocks_regression.counter # cocotb build /test/regression/cocotb/build diff --git a/.gitmodules b/.gitmodules index 6ecd7035d..8dea05eb8 100644 --- a/.gitmodules +++ b/.gitmodules @@ -5,3 +5,6 @@ path = test/external/embench/embench-iot url = https://github.com/embench/embench-iot.git ignore = dirty +[submodule "test/external/riscof/riscv-arch-test"] + path = test/external/riscof/riscv-arch-test + url = https://github.com/riscv-non-isa/riscv-arch-test.git diff --git a/README.md b/README.md index 8f679b7a4..dcd7fb056 100644 --- a/README.md +++ b/README.md @@ -15,14 +15,14 @@ Coreblocks is an experimental, modular out-of-order [RISC-V](https://riscv.org/s In the future, we would like to achieve the following goals: - * Performace (up to a point, on FPGAs). We would like Coreblocks not to be too sluggish, without compromising the simplicity goal. + * Performance (up to a point, on FPGAs). We would like Coreblocks not to be too sluggish, without compromising the simplicity goal. 
We don't wish to compete with high performance cores like [BOOM](https://github.com/riscv-boom/riscv-boom) though. * Wide(r) RISC-V support. Currently, we are focusing on getting the support for the core RV32I ISA right, but the ambitious long term plan is to be able to run full operating systems (e.g. Linux) on the core. ## State of the project The core currently supports the full RV32I instruction set and several extensions, including M (multiplication and division) and C (compressed instructions). -Interrupts and exceptions are currently not supported. +Exceptions and some of machine-mode CSRs are supported, the support for interrupts is currently rudimentary and incompatible with the RISC-V spec. Coreblocks can be used with [LiteX](https://github.com/enjoy-digital/litex) (currently using a [patched version](https://github.com/kuznia-rdzeni/litex/tree/coreblocks)). The transaction system we use as the foundation for the core is well-tested and usable. diff --git a/constants/ecp5_platforms.py b/constants/ecp5_platforms.py index 0e3545690..9aade96d5 100644 --- a/constants/ecp5_platforms.py +++ b/constants/ecp5_platforms.py @@ -2,7 +2,7 @@ from itertools import chain from typing import TypeAlias from amaranth.build.dsl import Subsignal -from amaranth.vendor.lattice_ecp5 import LatticeECP5Platform +from amaranth.vendor import LatticeECP5Platform from amaranth.build import Resource, Attrs, Pins, Clock, PinsN from constants.ecp5_pinout import ecp5_bg756_pins, ecp5_bg756_pclk @@ -97,8 +97,8 @@ def make_resources(pins: PinManager) -> list[Resource]: number, en=pins.p(), done=pins.p(), - data_in=pins.p(adapter.data_in.shape().width), - data_out=pins.p(adapter.data_out.shape().width), + data_in=pins.p(adapter.data_in.shape().size), + data_out=pins.p(adapter.data_out.shape().size), ) ] diff --git a/test/common/_test/__init__.py b/coreblocks/cache/__init__.py similarity index 100% rename from test/common/_test/__init__.py rename to coreblocks/cache/__init__.py diff --git 
a/coreblocks/frontend/icache.py b/coreblocks/cache/icache.py similarity index 70% rename from coreblocks/frontend/icache.py rename to coreblocks/cache/icache.py index 16d4462db..09899afb6 100644 --- a/coreblocks/frontend/icache.py +++ b/coreblocks/cache/icache.py @@ -1,20 +1,24 @@ from functools import reduce import operator -from typing import Protocol from amaranth import * -from amaranth.utils import log2_int +from amaranth.lib.data import View +from amaranth.utils import exact_log2 from transactron.core import def_method, Priority, TModule from transactron import Method, Transaction from coreblocks.params import ICacheLayouts, ICacheParameters from transactron.utils import assign, OneHotSwitchDynamic -from transactron.utils._typing import HasElaborate from transactron.lib import * -from coreblocks.peripherals.wishbone import WishboneMaster +from coreblocks.peripherals.bus_adapter import BusMasterInterface +from coreblocks.cache.iface import CacheInterface, CacheRefillerInterface +from transactron.utils.transactron_helpers import make_layout -__all__ = ["ICache", "ICacheBypass", "ICacheInterface", "SimpleWBCacheRefiller"] +__all__ = [ + "ICache", + "ICacheBypass", +] def extract_instr_from_word(m: TModule, params: ICacheParameters, word: Signal, addr: Value): @@ -31,45 +35,10 @@ def extract_instr_from_word(m: TModule, params: ICacheParameters, word: Signal, return instr_out -class ICacheInterface(HasElaborate, Protocol): - """ - Instruction Cache Interface. - - Parameters - ---------- - issue_req : Method - A method that is used to issue a cache lookup request. - accept_res : Method - A method that is used to accept the result of a cache lookup request. - flush : Method - A method that is used to flush the whole cache. - """ - - issue_req: Method - accept_res: Method - flush: Method - - -class CacheRefillerInterface(HasElaborate, Protocol): - """ - Instruction Cache Refiller Interface. 
- - Parameters - ---------- - start_refill : Method - A method that is used to start a refill for a given cache line. - accept_refill : Method - A method that is used to accept one word from the requested cache line. - """ - - start_refill: Method - accept_refill: Method - - -class ICacheBypass(Elaboratable, ICacheInterface): - def __init__(self, layouts: ICacheLayouts, params: ICacheParameters, wb_master: WishboneMaster) -> None: +class ICacheBypass(Elaboratable, CacheInterface): + def __init__(self, layouts: ICacheLayouts, params: ICacheParameters, bus_master: BusMasterInterface) -> None: self.params = params - self.wb_master = wb_master + self.bus_master = bus_master self.issue_req = Method(i=layouts.issue_req) self.accept_res = Method(o=layouts.accept_res) @@ -83,17 +52,15 @@ def elaborate(self, platform): @def_method(m, self.issue_req) def _(addr: Value) -> None: m.d.sync += req_addr.eq(addr) - self.wb_master.request( + self.bus_master.request_read( m, - addr=addr >> log2_int(self.params.word_width_bytes), - data=0, - we=0, - sel=C(1).replicate(self.wb_master.wb_params.data_width // self.wb_master.wb_params.granularity), + addr=addr >> exact_log2(self.params.word_width_bytes), + sel=C(1).replicate(self.bus_master.params.data_width // self.bus_master.params.granularity), ) @def_method(m, self.accept_res) def _(): - res = self.wb_master.result(m) + res = self.bus_master.get_read_response(m) return { "instr": extract_instr_from_word(m, self.params, res.data, req_addr), "error": res.err, @@ -106,7 +73,7 @@ def _() -> None: return m -class ICache(Elaboratable, ICacheInterface): +class ICache(Elaboratable, CacheInterface): """A simple set-associative instruction cache. The replacement policy is a pseudo random scheme. 
Every time a line is trashed, @@ -144,11 +111,20 @@ def __init__(self, layouts: ICacheLayouts, params: ICacheParameters, refiller: C self.flush = Method() self.flush.add_conflict(self.issue_req, Priority.LEFT) - self.addr_layout = [ + self.addr_layout = make_layout( ("offset", self.params.offset_bits), ("index", self.params.index_bits), ("tag", self.params.tag_bits), - ] + ) + + self.perf_loads = HwCounter("frontend.icache.loads", "Number of requests to the L1 Instruction Cache") + self.perf_hits = HwCounter("frontend.icache.hits") + self.perf_misses = HwCounter("frontend.icache.misses") + self.perf_errors = HwCounter("frontend.icache.fetch_errors") + self.perf_flushes = HwCounter("frontend.icache.flushes") + self.req_latency = LatencyMeasurer( + "frontend.icache.req_latency", "Latencies of cache requests", slots_number=2, max_latency=500 + ) def deserialize_addr(self, raw_addr: Value) -> dict[str, Value]: return { @@ -157,12 +133,21 @@ def deserialize_addr(self, raw_addr: Value) -> dict[str, Value]: "tag": raw_addr[-self.params.tag_bits :], } - def serialize_addr(self, addr: Record) -> Value: + def serialize_addr(self, addr: View) -> Value: return Cat(addr.offset, addr.index, addr.tag) def elaborate(self, platform): m = TModule() + m.submodules += [ + self.perf_loads, + self.perf_hits, + self.perf_misses, + self.perf_errors, + self.perf_flushes, + self.req_latency, + ] + m.submodules.mem = self.mem = ICacheMemory(self.params) m.submodules.req_fifo = self.req_fifo = FIFO(layout=self.addr_layout, depth=2) m.submodules.res_fwd = self.res_fwd = Forwarder(layout=self.layouts.accept_res) @@ -170,11 +155,15 @@ def elaborate(self, platform): # State machine logic needs_refill = Signal() refill_finish = Signal() + refill_finish_last = Signal() refill_error = Signal() flush_start = Signal() flush_finish = Signal() + with Transaction().body(m): + self.perf_flushes.incr(m, cond=flush_finish) + with m.FSM(reset="FLUSH") as fsm: with m.State("FLUSH"): with m.If(flush_finish): 
@@ -199,7 +188,7 @@ def elaborate(self, platform): # Fast path - read requests request_valid = self.req_fifo.read.ready - request_addr = Record(self.addr_layout) + request_addr = Signal(self.addr_layout) tag_hit = [tag_data.valid & (tag_data.tag == request_addr.tag) for tag_data in self.mem.tag_rd_data] tag_hit_any = reduce(operator.or_, tag_hit) @@ -208,25 +197,33 @@ def elaborate(self, platform): for i in OneHotSwitchDynamic(m, Cat(tag_hit)): m.d.comb += mem_out.eq(self.mem.data_rd_data[i]) - instr_out = extract_instr_from_word(m, self.params, mem_out, request_addr[:]) + instr_out = extract_instr_from_word(m, self.params, mem_out, Value.cast(request_addr)) refill_error_saved = Signal() m.d.comb += needs_refill.eq(request_valid & ~tag_hit_any & ~refill_error_saved) with Transaction().body(m, request=request_valid & fsm.ongoing("LOOKUP") & (tag_hit_any | refill_error_saved)): + self.perf_errors.incr(m, cond=refill_error_saved) + self.perf_misses.incr(m, cond=refill_finish_last) + self.perf_hits.incr(m, cond=~refill_finish_last) + self.res_fwd.write(m, instr=instr_out, error=refill_error_saved) m.d.sync += refill_error_saved.eq(0) @def_method(m, self.accept_res) def _(): self.req_fifo.read(m) + self.req_latency.stop(m) return self.res_fwd.read(m) - mem_read_addr = Record(self.addr_layout) + mem_read_addr = Signal(self.addr_layout) m.d.comb += assign(mem_read_addr, request_addr) @def_method(m, self.issue_req, ready=accepting_requests) def _(addr: Value) -> None: + self.perf_loads.incr(m) + self.req_latency.start(m) + deserialized = self.deserialize_addr(addr) # Forward read address only if the method is called m.d.comb += assign(mem_read_addr, deserialized) @@ -258,6 +255,8 @@ def _() -> None: aligned_addr = self.serialize_addr(request_addr) & ~((1 << self.params.offset_bits) - 1) self.refiller.start_refill(m, addr=aligned_addr) + m.d.sync += refill_finish_last.eq(0) + with Transaction().body(m): ret = self.refiller.accept_refill(m) deserialized = 
self.deserialize_addr(ret.addr) @@ -270,6 +269,7 @@ def _() -> None: m.d.comb += self.mem.data_wr_en.eq(1) m.d.comb += refill_finish.eq(ret.last) + m.d.sync += refill_finish_last.eq(1) m.d.comb += refill_error.eq(ret.error) m.d.sync += refill_error_saved.eq(ret.error) @@ -306,21 +306,21 @@ class ICacheMemory(Elaboratable): def __init__(self, params: ICacheParameters) -> None: self.params = params - self.tag_data_layout = [("valid", 1), ("tag", self.params.tag_bits)] + self.tag_data_layout = make_layout(("valid", 1), ("tag", self.params.tag_bits)) self.way_wr_en = Signal(self.params.num_of_ways) self.tag_rd_index = Signal(self.params.index_bits) - self.tag_rd_data = Array([Record(self.tag_data_layout) for _ in range(self.params.num_of_ways)]) + self.tag_rd_data = Array([Signal(self.tag_data_layout) for _ in range(self.params.num_of_ways)]) self.tag_wr_index = Signal(self.params.index_bits) self.tag_wr_en = Signal() - self.tag_wr_data = Record(self.tag_data_layout) + self.tag_wr_data = Signal(self.tag_data_layout) - self.data_addr_layout = [("index", self.params.index_bits), ("offset", self.params.offset_bits)] + self.data_addr_layout = make_layout(("index", self.params.index_bits), ("offset", self.params.offset_bits)) - self.data_rd_addr = Record(self.data_addr_layout) + self.data_rd_addr = Signal(self.data_addr_layout) self.data_rd_data = Array([Signal(self.params.word_width) for _ in range(self.params.num_of_ways)]) - self.data_wr_addr = Record(self.data_addr_layout) + self.data_wr_addr = Signal(self.data_addr_layout) self.data_wr_en = Signal() self.data_wr_data = Signal(self.params.word_width) @@ -330,7 +330,7 @@ def elaborate(self, platform): for i in range(self.params.num_of_ways): way_wr = self.way_wr_en[i] - tag_mem = Memory(width=len(self.tag_wr_data), depth=self.params.num_of_sets) + tag_mem = Memory(width=len(Value.cast(self.tag_wr_data)), depth=self.params.num_of_sets) tag_mem_rp = tag_mem.read_port() tag_mem_wp = tag_mem.write_port() 
m.submodules[f"tag_mem_{i}_rp"] = tag_mem_rp @@ -352,7 +352,7 @@ def elaborate(self, platform): # We address the data RAM using machine words, so we have to # discard a few least significant bits from the address. - redundant_offset_bits = log2_int(self.params.word_width_bytes) + redundant_offset_bits = exact_log2(self.params.word_width_bytes) rd_addr = Cat(self.data_rd_addr.offset, self.data_rd_addr.index)[redundant_offset_bits:] wr_addr = Cat(self.data_wr_addr.offset, self.data_wr_addr.index)[redundant_offset_bits:] @@ -365,66 +365,3 @@ def elaborate(self, platform): ] return m - - -class SimpleWBCacheRefiller(Elaboratable, CacheRefillerInterface): - def __init__(self, layouts: ICacheLayouts, params: ICacheParameters, wb_master: WishboneMaster): - self.params = params - self.wb_master = wb_master - - self.start_refill = Method(i=layouts.start_refill) - self.accept_refill = Method(o=layouts.accept_refill) - - def elaborate(self, platform): - m = TModule() - - refill_address = Signal(self.params.word_width - self.params.offset_bits) - refill_active = Signal() - word_counter = Signal(range(self.params.words_in_block)) - - m.submodules.address_fwd = address_fwd = Forwarder( - [("word_counter", word_counter.shape()), ("refill_address", refill_address.shape())] - ) - - with Transaction().body(m): - address = address_fwd.read(m) - self.wb_master.request( - m, - addr=Cat(address["word_counter"], address["refill_address"]), - data=0, - we=0, - sel=C(1).replicate(self.wb_master.wb_params.data_width // self.wb_master.wb_params.granularity), - ) - - @def_method(m, self.start_refill, ready=~refill_active) - def _(addr) -> None: - address = addr[self.params.offset_bits :] - m.d.sync += refill_address.eq(address) - m.d.sync += refill_active.eq(1) - m.d.sync += word_counter.eq(0) - - address_fwd.write(m, word_counter=0, refill_address=address) - - @def_method(m, self.accept_refill, ready=refill_active) - def _(): - fetched = self.wb_master.result(m) - - last = (word_counter == 
(self.params.words_in_block - 1)) | fetched.err - - next_word_counter = Signal.like(word_counter) - m.d.top_comb += next_word_counter.eq(word_counter + 1) - - m.d.sync += word_counter.eq(next_word_counter) - with m.If(last): - m.d.sync += refill_active.eq(0) - with m.Else(): - address_fwd.write(m, word_counter=next_word_counter, refill_address=refill_address) - - return { - "addr": Cat(C(0, log2_int(self.params.word_width_bytes)), word_counter, refill_address), - "data": fetched.data, - "error": fetched.err, - "last": last, - } - - return m diff --git a/coreblocks/cache/iface.py b/coreblocks/cache/iface.py new file mode 100644 index 000000000..c2c54d2ff --- /dev/null +++ b/coreblocks/cache/iface.py @@ -0,0 +1,42 @@ +from typing import Protocol + +from transactron import Method + +from transactron.utils._typing import HasElaborate + +__all__ = ["CacheInterface", "CacheRefillerInterface"] + + +class CacheInterface(HasElaborate, Protocol): + """ + Cache Interface. + + Parameters + ---------- + issue_req : Method + A method that is used to issue a cache lookup request. + accept_res : Method + A method that is used to accept the result of a cache lookup request. + flush : Method + A method that is used to flush the whole cache. + """ + + issue_req: Method + accept_res: Method + flush: Method + + +class CacheRefillerInterface(HasElaborate, Protocol): + """ + Cache Refiller Interface. + + Parameters + ---------- + start_refill : Method + A method that is used to start a refill for a given cache line. + accept_refill : Method + A method that is used to accept one word from the requested cache line. 
+ """ + + start_refill: Method + accept_refill: Method diff --git a/coreblocks/cache/refiller.py b/coreblocks/cache/refiller.py new file mode 100644 index 000000000..e8a261e26 --- /dev/null +++ b/coreblocks/cache/refiller.py @@ -0,0 +1,72 @@ +from amaranth import * +from coreblocks.cache.icache import CacheRefillerInterface +from coreblocks.params import ICacheLayouts, ICacheParameters +from coreblocks.peripherals.bus_adapter import BusMasterInterface +from transactron.core import Transaction +from transactron.lib import Forwarder, Method, TModule, def_method + +from amaranth.utils import exact_log2 + + +__all__ = ["SimpleCommonBusCacheRefiller"] + + +class SimpleCommonBusCacheRefiller(Elaboratable, CacheRefillerInterface): + def __init__(self, layouts: ICacheLayouts, params: ICacheParameters, bus_master: BusMasterInterface): + self.params = params + self.bus_master = bus_master + + self.start_refill = Method(i=layouts.start_refill) + self.accept_refill = Method(o=layouts.accept_refill) + + def elaborate(self, platform): + m = TModule() + + refill_address = Signal(self.params.word_width - self.params.offset_bits) + refill_active = Signal() + word_counter = Signal(range(self.params.words_in_block)) + + m.submodules.address_fwd = address_fwd = Forwarder( + [("word_counter", word_counter.shape()), ("refill_address", refill_address.shape())] + ) + + with Transaction().body(m): + address = address_fwd.read(m) + self.bus_master.request_read( + m, + addr=Cat(address["word_counter"], address["refill_address"]), + sel=C(1).replicate(self.bus_master.params.data_width // self.bus_master.params.granularity), + ) + + @def_method(m, self.start_refill, ready=~refill_active) + def _(addr) -> None: + address = addr[self.params.offset_bits :] + m.d.sync += refill_address.eq(address) + m.d.sync += refill_active.eq(1) + m.d.sync += word_counter.eq(0) + + address_fwd.write(m, word_counter=0, refill_address=address) + + @def_method(m, self.accept_refill, ready=refill_active) + def _(): 
+ fetched = self.bus_master.get_read_response(m) + + last = (word_counter == (self.params.words_in_block - 1)) | fetched.err + + next_word_counter = Signal.like(word_counter) + m.d.top_comb += next_word_counter.eq(word_counter + 1) + + m.d.sync += word_counter.eq(next_word_counter) + with m.If(last): + m.d.sync += refill_active.eq(0) + with m.Else(): + address_fwd.write(m, word_counter=next_word_counter, refill_address=refill_address) + + return { + "addr": Cat(C(0, exact_log2(self.params.word_width_bytes)), word_counter, refill_address), + "data": fetched.data, + "error": fetched.err, + "last": last, + } + + return m diff --git a/coreblocks/core.py b/coreblocks/core.py index 060b2c7fb..a91b2e827 100644 --- a/coreblocks/core.py +++ b/coreblocks/core.py @@ -1,13 +1,20 @@ from amaranth import * +from amaranth.lib.wiring import flipped, connect -from transactron.utils.dependencies import DependencyManager +from transactron.utils.dependencies import DependencyManager, DependencyContext from coreblocks.stages.func_blocks_unifier import FuncBlocksUnifier from coreblocks.structs_common.instr_counter import CoreInstructionCounter from coreblocks.structs_common.interrupt_controller import InterruptController from transactron.core import Transaction, TModule from transactron.lib import FIFO, ConnectTrans from coreblocks.params.layouts import * -from coreblocks.params.keys import BranchResolvedKey, GenericCSRRegistersKey, InstructionPrecommitKey, WishboneDataKey +from coreblocks.params.keys import ( + BranchVerifyKey, + FetchResumeKey, + GenericCSRRegistersKey, + InstructionPrecommitKey, + CommonBusDataKey, +) from coreblocks.params.genparams import GenParams from coreblocks.params.isa import Extension from coreblocks.frontend.decode_stage import DecodeStage @@ -19,24 +26,34 @@ from coreblocks.scheduler.scheduler import Scheduler from coreblocks.stages.backend import ResultAnnouncement from coreblocks.stages.retirement import Retirement -from coreblocks.frontend.icache import 
ICache, SimpleWBCacheRefiller, ICacheBypass -from coreblocks.peripherals.wishbone import WishboneMaster, WishboneBus +from coreblocks.cache.icache import ICache, ICacheBypass +from coreblocks.peripherals.bus_adapter import WishboneMasterAdapter +from coreblocks.peripherals.wishbone import WishboneMaster, WishboneInterface +from coreblocks.cache.refiller import SimpleCommonBusCacheRefiller from coreblocks.frontend.fetch import Fetch, UnalignedFetch from transactron.lib.transformers import MethodMap, MethodProduct from transactron.lib import BasicFifo +from transactron.lib.metrics import HwMetricsEnabledKey __all__ = ["Core"] class Core(Elaboratable): - def __init__(self, *, gen_params: GenParams, wb_instr_bus: WishboneBus, wb_data_bus: WishboneBus): + def __init__(self, *, gen_params: GenParams, wb_instr_bus: WishboneInterface, wb_data_bus: WishboneInterface): self.gen_params = gen_params + dep_manager = DependencyContext.get() + if self.gen_params.debug_signals_enabled: + dep_manager.add_dependency(HwMetricsEnabledKey(), True) + self.wb_instr_bus = wb_instr_bus self.wb_data_bus = wb_data_bus - self.wb_master_instr = WishboneMaster(self.gen_params.wb_params) - self.wb_master_data = WishboneMaster(self.gen_params.wb_params) + self.wb_master_instr = WishboneMaster(self.gen_params.wb_params, "instr") + self.wb_master_data = WishboneMaster(self.gen_params.wb_params, "data") + + self.bus_master_instr_adapter = WishboneMasterAdapter(self.wb_master_instr) + self.bus_master_data_adapter = WishboneMasterAdapter(self.wb_master_data) self.core_counter = CoreInstructionCounter(self.gen_params) @@ -55,27 +72,34 @@ def __init__(self, *, gen_params: GenParams, wb_instr_bus: WishboneBus, wb_data_ cache_layouts = self.gen_params.get(ICacheLayouts) if gen_params.icache_params.enable: - self.icache_refiller = SimpleWBCacheRefiller( - cache_layouts, self.gen_params.icache_params, self.wb_master_instr + self.icache_refiller = SimpleCommonBusCacheRefiller( + cache_layouts, 
self.gen_params.icache_params, self.bus_master_instr_adapter ) self.icache = ICache(cache_layouts, self.gen_params.icache_params, self.icache_refiller) else: - self.icache = ICacheBypass(cache_layouts, gen_params.icache_params, self.wb_master_instr) + self.icache = ICacheBypass(cache_layouts, gen_params.icache_params, self.bus_master_instr_adapter) self.FRAT = FRAT(gen_params=self.gen_params) self.RRAT = RRAT(gen_params=self.gen_params) self.RF = RegisterFile(gen_params=self.gen_params) self.ROB = ReorderBuffer(gen_params=self.gen_params) - connections = gen_params.get(DependencyManager) - connections.add_dependency(WishboneDataKey(), self.wb_master_data) + self.connections = gen_params.get(DependencyManager) + self.connections.add_dependency(CommonBusDataKey(), self.bus_master_data_adapter) + + if Extension.C in self.gen_params.isa.extensions: + self.fetch = UnalignedFetch(self.gen_params, self.icache, self.fetch_continue.method) + else: + self.fetch = Fetch(self.gen_params, self.icache, self.fetch_continue.method) - self.exception_cause_register = ExceptionCauseRegister(self.gen_params, rob_get_indices=self.ROB.get_indices) + self.exception_cause_register = ExceptionCauseRegister( + self.gen_params, rob_get_indices=self.ROB.get_indices, fetch_stall_exception=self.fetch.stall_exception + ) self.func_blocks_unifier = FuncBlocksUnifier( gen_params=gen_params, blocks=gen_params.func_units_config, - extra_methods_required=[InstructionPrecommitKey(), BranchResolvedKey()], + extra_methods_required=[InstructionPrecommitKey(), FetchResumeKey()], ) self.announcement = ResultAnnouncement( @@ -89,17 +113,20 @@ def __init__(self, *, gen_params: GenParams, wb_instr_bus: WishboneBus, wb_data_ self.interrupt_controller = InterruptController(self.gen_params) self.csr_generic = GenericCSRRegisters(self.gen_params) - connections.add_dependency(GenericCSRRegistersKey(), self.csr_generic) + self.connections.add_dependency(GenericCSRRegistersKey(), self.csr_generic) def 
elaborate(self, platform): m = TModule() - m.d.comb += self.wb_master_instr.wbMaster.connect(self.wb_instr_bus) - m.d.comb += self.wb_master_data.wbMaster.connect(self.wb_data_bus) + connect(m, flipped(self.wb_instr_bus), self.wb_master_instr.wb_master) + connect(m, flipped(self.wb_data_bus), self.wb_master_data.wb_master) m.submodules.wb_master_instr = self.wb_master_instr m.submodules.wb_master_data = self.wb_master_data + m.submodules.bus_master_instr_adapter = self.bus_master_instr_adapter + m.submodules.bus_master_data_adapter = self.bus_master_data_adapter + m.submodules.free_rf_fifo = free_rf_fifo = self.free_rf_fifo m.submodules.FRAT = frat = self.FRAT m.submodules.RRAT = rrat = self.RRAT @@ -110,11 +137,8 @@ def elaborate(self, platform): m.submodules.icache_refiller = self.icache_refiller m.submodules.icache = self.icache - if Extension.C in self.gen_params.isa.extensions: - m.submodules.fetch = self.fetch = UnalignedFetch(self.gen_params, self.icache, self.fetch_continue.use(m)) - else: - m.submodules.fetch = self.fetch = Fetch(self.gen_params, self.icache, self.fetch_continue.use(m)) - + m.submodules.fetch_continue = self.fetch_continue + m.submodules.fetch = self.fetch m.submodules.fifo_fetch = self.fifo_fetch m.submodules.core_counter = self.core_counter m.submodules.args_discard_map = self.core_counter_increment_discard_map @@ -137,8 +161,8 @@ def elaborate(self, platform): m.submodules.exception_cause_register = self.exception_cause_register - m.submodules.verify_branch = ConnectTrans( - self.func_blocks_unifier.get_extra_method(BranchResolvedKey()), self.fetch.verify_branch + m.submodules.fetch_resume_connector = ConnectTrans( + self.func_blocks_unifier.get_extra_method(FetchResumeKey()), self.fetch.resume ) m.submodules.announcement = self.announcement @@ -155,8 +179,7 @@ def elaborate(self, platform): exception_cause_get=self.exception_cause_register.get, exception_cause_clear=self.exception_cause_register.clear, frat_rename=frat.rename, - 
fetch_continue=self.fetch.verify_branch, - fetch_stall=self.fetch.stall_exception, + fetch_continue=self.fetch.resume, instr_decrement=self.core_counter.decrement, trap_entry=self.interrupt_controller.entry, ) @@ -171,4 +194,9 @@ def elaborate(self, platform): free_rf_fifo.write(m, free_rf_reg) m.d.sync += free_rf_reg.eq(free_rf_reg + 1) + # TODO: Remove when Branch Predictor implemented + with Transaction(name="DiscardBranchVerify").body(m): + read = self.connections.get_dependency(BranchVerifyKey()) + read(m) # Consume to not block JB Unit + return m diff --git a/coreblocks/frontend/decode_stage.py b/coreblocks/frontend/decode_stage.py index b530208b0..6ba649db7 100644 --- a/coreblocks/frontend/decode_stage.py +++ b/coreblocks/frontend/decode_stage.py @@ -2,6 +2,7 @@ from coreblocks.params.isa import Funct3 from coreblocks.params.optypes import OpType +from transactron.lib.metrics import * from transactron import Method, Transaction, TModule from ..params import GenParams from .instr_decoder import InstrDecoder @@ -34,9 +35,12 @@ def __init__(self, gen_params: GenParams, get_raw: Method, push_decoded: Method) self.get_raw = get_raw self.push_decoded = push_decoded + self.perf_illegal_instr = HwCounter("frontend.decode.illegal_instr") + def elaborate(self, platform): m = TModule() + m.submodules.perf_illegal_instr = self.perf_illegal_instr m.submodules.instr_decoder = instr_decoder = InstrDecoder(self.gen_params) with Transaction().body(m): @@ -61,6 +65,7 @@ def elaborate(self, platform): with m.If(raw.access_fault): m.d.comb += exception_funct.eq(Funct3._EINSTRACCESSFAULT) with m.Elif(instr_decoder.illegal): + self.perf_illegal_instr.incr(m) m.d.comb += exception_funct.eq(Funct3._EILLEGALINSTR) self.push_decoded( diff --git a/coreblocks/frontend/fetch.py b/coreblocks/frontend/fetch.py index 1a179b430..33a1a2129 100644 --- a/coreblocks/frontend/fetch.py +++ b/coreblocks/frontend/fetch.py @@ -1,6 +1,8 @@ from amaranth import * +from transactron.core import Priority 
from transactron.lib import BasicFifo, Semaphore -from coreblocks.frontend.icache import ICacheInterface +from transactron.lib.metrics import * +from coreblocks.cache.iface import CacheInterface from coreblocks.frontend.rvc import InstrDecompress, is_instr_compressed from transactron import def_method, Method, Transaction, TModule from ..params import * @@ -12,14 +14,14 @@ class Fetch(Elaboratable): after each fetch. """ - def __init__(self, gen_params: GenParams, icache: ICacheInterface, cont: Method) -> None: + def __init__(self, gen_params: GenParams, icache: CacheInterface, cont: Method) -> None: """ Parameters ---------- gen_params : GenParams Instance of GenParams with parameters which should be used to generate fetch unit. - icache : ICacheInterface + icache : CacheInterface Instruction Cache cont : Method Method which should be invoked to send fetched data to the next step. @@ -29,11 +31,12 @@ def __init__(self, gen_params: GenParams, icache: ICacheInterface, cont: Method) self.icache = icache self.cont = cont - self.verify_branch = Method(i=self.gen_params.get(FetchLayouts).branch_verify) + self.resume = Method(i=self.gen_params.get(FetchLayouts).resume) self.stall_exception = Method() - - # PC of the last fetched instruction. For now only used in tests. - self.pc = Signal(self.gen_params.isa.xlen) + # Fetch can be resumed to unstall from 'unsafe' instructions, and stalled because + # of exception report, both can happen at any time during normal excecution. + # ExceptionCauseRegister uses separate Transaction for it, so performace is not affected. 
+ self.stall_exception.add_conflict(self.resume, Priority.LEFT) def elaborate(self, platform): m = TModule() @@ -71,9 +74,7 @@ def stall(exception=False): opcode = res.instr[2:7] # whether we have to wait for the retirement of this instruction before we make futher speculation - unsafe_instr = ( - (opcode == Opcode.BRANCH) | (opcode == Opcode.JAL) | (opcode == Opcode.JALR) | (opcode == Opcode.SYSTEM) - ) + unsafe_instr = opcode == Opcode.SYSTEM with m.If(spin == target.spin): instr = Signal(self.gen_params.isa.ilen) @@ -87,14 +88,13 @@ def stall(exception=False): with m.If(unsafe_instr): stall() - m.d.sync += self.pc.eq(target.addr) m.d.comb += instr.eq(res.instr) self.cont(m, instr=instr, pc=target.addr, access_fault=fetch_error, rvc=0) - @def_method(m, self.verify_branch, ready=stalled) - def _(from_pc: Value, next_pc: Value, resume_from_exception: Value): - m.d.sync += speculative_pc.eq(next_pc) + @def_method(m, self.resume, ready=stalled) + def _(pc: Value, resume_from_exception: Value): + m.d.sync += speculative_pc.eq(pc) m.d.sync += stalled_unsafe.eq(0) with m.If(resume_from_exception): m.d.sync += stalled_exception.eq(0) @@ -111,14 +111,14 @@ class UnalignedFetch(Elaboratable): Simple fetch unit that works with unaligned and RVC instructions. """ - def __init__(self, gen_params: GenParams, icache: ICacheInterface, cont: Method) -> None: + def __init__(self, gen_params: GenParams, icache: CacheInterface, cont: Method) -> None: """ Parameters ---------- gen_params : GenParams Instance of GenParams with parameters which should be used to generate fetch unit. - icache : ICacheInterface + icache : CacheInterface Instruction Cache cont : Method Method which should be invoked to send fetched data to the next step. 
@@ -128,15 +128,17 @@ def __init__(self, gen_params: GenParams, icache: ICacheInterface, cont: Method) self.icache = icache self.cont = cont - self.verify_branch = Method(i=self.gen_params.get(FetchLayouts).branch_verify) + self.resume = Method(i=self.gen_params.get(FetchLayouts).resume) self.stall_exception = Method() + self.stall_exception.add_conflict(self.resume, Priority.LEFT) - # PC of the last fetched instruction. For now only used in tests. - self.pc = Signal(self.gen_params.isa.xlen) + self.perf_rvc = HwCounter("frontend.ifu.rvc", "Number of decompressed RVC instructions") def elaborate(self, platform) -> TModule: m = TModule() + m.submodules += [self.perf_rvc] + m.submodules.req_limiter = req_limiter = Semaphore(2) m.submodules.decompress = decompress = InstrDecompress(self.gen_params) @@ -191,9 +193,7 @@ def elaborate(self, platform) -> TModule: opcode = instr[2:7] # whether we have to wait for the retirement of this instruction before we make futher speculation - unsafe_instr = ( - (opcode == Opcode.BRANCH) | (opcode == Opcode.JAL) | (opcode == Opcode.JALR) | (opcode == Opcode.SYSTEM) - ) + unsafe_instr = opcode == Opcode.SYSTEM # Check if we are ready to dispatch an instruction in the current cycle. 
# This can happen in three situations: @@ -224,16 +224,16 @@ def elaborate(self, platform) -> TModule: m.d.sync += stalled_unsafe.eq(1) m.d.sync += flushing.eq(1) - m.d.sync += self.pc.eq(current_pc) with m.If(~cache_resp.error): m.d.sync += current_pc.eq(current_pc + Mux(is_rvc, C(2, 3), C(4, 3))) + self.perf_rvc.incr(m, cond=is_rvc) self.cont(m, instr=instr, pc=current_pc, access_fault=cache_resp.error, rvc=is_rvc) - @def_method(m, self.verify_branch, ready=(stalled & ~flushing)) - def _(from_pc: Value, next_pc: Value, resume_from_exception: Value): - m.d.sync += cache_req_pc.eq(next_pc) - m.d.sync += current_pc.eq(next_pc) + @def_method(m, self.resume, ready=(stalled & ~flushing)) + def _(pc: Value, resume_from_exception: Value): + m.d.sync += cache_req_pc.eq(pc) + m.d.sync += current_pc.eq(pc) m.d.sync += stalled_unsafe.eq(0) with m.If(resume_from_exception): m.d.sync += stalled_exception.eq(0) diff --git a/coreblocks/frontend/instr_description.py b/coreblocks/frontend/instr_description.py index 0e9fe6994..632d436cc 100644 --- a/coreblocks/frontend/instr_description.py +++ b/coreblocks/frontend/instr_description.py @@ -164,20 +164,23 @@ class Encoding: Encoding(Opcode.OP, Funct3.MIN, Funct7.MIN), Encoding(Opcode.OP, Funct3.MINU, Funct7.MIN), Encoding(Opcode.OP, Funct3.ORN, Funct7.ORN), + Encoding(Opcode.OP, Funct3.XNOR, Funct7.XNOR), + ], + OpType.BIT_ROTATION: [ Encoding(Opcode.OP, Funct3.ROL, Funct7.ROL), Encoding(Opcode.OP, Funct3.ROR, Funct7.ROR), Encoding(Opcode.OP_IMM, Funct3.ROR, Funct7.ROR), - Encoding(Opcode.OP, Funct3.XNOR, Funct7.XNOR), ], OpType.UNARY_BIT_MANIPULATION_1: [ - Encoding(Opcode.OP_IMM, Funct3.ORCB, funct12=Funct12.ORCB), Encoding(Opcode.OP_IMM, Funct3.REV8, funct12=Funct12.REV8_32), Encoding(Opcode.OP_IMM, Funct3.SEXTB, funct12=Funct12.SEXTB), Encoding(Opcode.OP, Funct3.ZEXTH, funct12=Funct12.ZEXTH), ], - # Instructions SEXTH, SEXTHB, CPOP, CLZ and CTZ cannot be distiguished by their Funct7 code + # Instructions SEXTH, SEXTHB, CPOP, CLZ 
and CTZ cannot be distiguished by their Funct7 code + # ORCB is here because of optimization to not lookup Funct7 in UNARY_BIT_MANIPULATION_1 OpType.UNARY_BIT_MANIPULATION_2: [ Encoding(Opcode.OP_IMM, Funct3.SEXTH, funct12=Funct12.SEXTH), + Encoding(Opcode.OP_IMM, Funct3.ORCB, funct12=Funct12.ORCB), ], OpType.UNARY_BIT_MANIPULATION_3: [ Encoding(Opcode.OP_IMM, Funct3.CLZ, funct12=Funct12.CLZ), diff --git a/coreblocks/fu/alu.py b/coreblocks/fu/alu.py index bc5bf72b5..114e367ce 100644 --- a/coreblocks/fu/alu.py +++ b/coreblocks/fu/alu.py @@ -82,14 +82,14 @@ def get_instructions(self) -> Sequence[tuple]: (self.Fn.MAXU, OpType.BIT_MANIPULATION, Funct3.MAXU, Funct7.MAX), (self.Fn.MIN, OpType.BIT_MANIPULATION, Funct3.MIN, Funct7.MIN), (self.Fn.MINU, OpType.BIT_MANIPULATION, Funct3.MINU, Funct7.MIN), - (self.Fn.ORCB, OpType.UNARY_BIT_MANIPULATION_1, Funct3.ORCB, Funct7.ORCB), - (self.Fn.REV8, OpType.UNARY_BIT_MANIPULATION_1, Funct3.REV8, Funct7.REV8), - (self.Fn.SEXTB, OpType.UNARY_BIT_MANIPULATION_1, Funct3.SEXTB, Funct7.SEXTB), - (self.Fn.ZEXTH, OpType.UNARY_BIT_MANIPULATION_1, Funct3.ZEXTH, Funct7.ZEXTH), - (self.Fn.CPOP, OpType.UNARY_BIT_MANIPULATION_5, Funct3.CPOP, Funct7.CPOP), - (self.Fn.SEXTH, OpType.UNARY_BIT_MANIPULATION_2, Funct3.SEXTH, Funct7.SEXTH), - (self.Fn.CLZ, OpType.UNARY_BIT_MANIPULATION_3, Funct3.CLZ, Funct7.CLZ), - (self.Fn.CTZ, OpType.UNARY_BIT_MANIPULATION_4, Funct3.CTZ, Funct7.CTZ), + (self.Fn.REV8, OpType.UNARY_BIT_MANIPULATION_1, Funct3.REV8), + (self.Fn.SEXTB, OpType.UNARY_BIT_MANIPULATION_1, Funct3.SEXTB), + (self.Fn.ZEXTH, OpType.UNARY_BIT_MANIPULATION_1, Funct3.ZEXTH), + (self.Fn.ORCB, OpType.UNARY_BIT_MANIPULATION_2, Funct3.ORCB), + (self.Fn.SEXTH, OpType.UNARY_BIT_MANIPULATION_2, Funct3.SEXTH), + (self.Fn.CLZ, OpType.UNARY_BIT_MANIPULATION_3, Funct3.CLZ), + (self.Fn.CTZ, OpType.UNARY_BIT_MANIPULATION_4, Funct3.CTZ), + (self.Fn.CPOP, OpType.UNARY_BIT_MANIPULATION_5, Funct3.CPOP), ] * self.zbb_enable ) diff --git a/coreblocks/fu/div_unit.py 
b/coreblocks/fu/div_unit.py index a4767a0b0..9e3f3dfc6 100644 --- a/coreblocks/fu/div_unit.py +++ b/coreblocks/fu/div_unit.py @@ -3,6 +3,7 @@ from collections.abc import Sequence from amaranth import * +from amaranth.lib import data from coreblocks.params.fu_params import FunctionalComponentParams from coreblocks.params import Funct3, GenParams, FuncUnitLayouts, OpType @@ -33,7 +34,7 @@ def get_instructions(self) -> Sequence[tuple]: ] -def get_input(arg: Record) -> tuple[Value, Value]: +def get_input(arg: data.View) -> tuple[Value, Value]: return arg.s1_val, Mux(arg.imm, arg.imm, arg.s2_val) diff --git a/coreblocks/fu/fu_decoder.py b/coreblocks/fu/fu_decoder.py index 510ee30f0..eeaae8bf1 100644 --- a/coreblocks/fu/fu_decoder.py +++ b/coreblocks/fu/fu_decoder.py @@ -15,13 +15,13 @@ class Decoder(Elaboratable): Attributes ---------- decode_fn: Signal - exec_fn: Record + exec_fn: View """ def __init__(self, gen_params: GenParams, decode_fn: Type[IntFlag], ops: Sequence[tuple], check_optype: bool): layouts = gen_params.get(CommonLayoutFields) - self.exec_fn = Record(layouts.exec_fn_layout) + self.exec_fn = Signal(layouts.exec_fn_layout) self.decode_fn = Signal(decode_fn) self.ops = ops self.check_optype = check_optype diff --git a/coreblocks/fu/jumpbranch.py b/coreblocks/fu/jumpbranch.py index 2b6c8aefa..8b4ba52c9 100644 --- a/coreblocks/fu/jumpbranch.py +++ b/coreblocks/fu/jumpbranch.py @@ -7,10 +7,10 @@ from transactron import * from transactron.core import def_method from transactron.lib import * +from transactron.lib import logging from transactron.utils import DependencyManager - from coreblocks.params import * -from coreblocks.params.keys import AsyncInterruptInsertSignalKey +from coreblocks.params.keys import AsyncInterruptInsertSignalKey, BranchVerifyKey from transactron.utils import OneHotSwitch from coreblocks.utils.protocols import FuncUnit @@ -19,6 +19,9 @@ __all__ = ["JumpBranchFuncUnit", "JumpComponent"] +log = 
logging.HardwareLogger("backend.fu.jumpbranch") + + class JumpBranchFn(DecoderManager): class Fn(IntFlag): JAL = auto() @@ -124,28 +127,34 @@ def __init__(self, gen_params: GenParams, jb_fn=JumpBranchFn()): self.issue = Method(i=layouts.issue) self.accept = Method(o=layouts.accept) - self.branch_result = Method(o=gen_params.get(FetchLayouts).branch_verify) + + self.fifo_branch_resolved = FIFO(self.gen_params.get(JumpBranchLayouts).verify_branch, 2) self.jb_fn = jb_fn self.dm = gen_params.get(DependencyManager) + self.dm.add_dependency(BranchVerifyKey(), self.fifo_branch_resolved.read) + + self.perf_jumps = HwCounter("backend.fu.jumpbranch.jumps", "Number of jump instructions issued") + self.perf_branches = HwCounter("backend.fu.jumpbranch.branches", "Number of branch instructions issued") + self.perf_misaligned = HwCounter( + "backend.fu.jumpbranch.misaligned", "Number of instructions with misaligned target address" + ) def elaborate(self, platform): m = TModule() + m.submodules += [self.perf_jumps, self.perf_branches, self.perf_misaligned] + m.submodules.jb = jb = JumpBranch(self.gen_params, fn=self.jb_fn) m.submodules.fifo_res = fifo_res = FIFO(self.gen_params.get(FuncUnitLayouts).accept, 2) - m.submodules.fifo_branch = fifo_branch = FIFO(self.gen_params.get(FetchLayouts).branch_verify, 2) m.submodules.decoder = decoder = self.jb_fn.get_decoder(self.gen_params) + m.submodules.fifo_branch_resolved = self.fifo_branch_resolved @def_method(m, self.accept) def _(): return fifo_res.read(m) - @def_method(m, self.branch_result) - def _(): - return fifo_branch.read(m) - @def_method(m, self.issue) def _(arg): m.d.top_comb += decoder.exec_fn.eq(arg.exec_fn) @@ -159,8 +168,13 @@ def _(arg): m.d.top_comb += jb.in_rvc.eq(arg.exec_fn.funct7) is_auipc = decoder.decode_fn == JumpBranchFn.Fn.AUIPC + is_jump = (decoder.decode_fn == JumpBranchFn.Fn.JAL) | (decoder.decode_fn == JumpBranchFn.Fn.JALR) + jump_result = Mux(jb.taken, jb.jmp_addr, jb.reg_res) + self.perf_jumps.incr(m, 
cond=is_jump) + self.perf_branches.incr(m, cond=(~is_jump & ~is_auipc)) + exception = Signal() exception_report = self.dm.get_dependency(ExceptionReportKey()) @@ -170,7 +184,12 @@ def _(arg): AsyncInterruptInsertSignalKey() ) + # TODO: Update with branch prediction support. + # Temporarily there is no jump prediction, jumps don't stall fetch and pc+4 is always fetched to pipeline + misprediction = ~is_auipc & jb.taken + with m.If(~is_auipc & jb.taken & jmp_addr_misaligned): + self.perf_misaligned.incr(m) # Spec: "[...] if the target address is not four-byte aligned. This exception is reported on the branch # or jump instruction, not on the target instruction. No instruction-address-misaligned exception is # generated for a conditional branch that is not taken." @@ -183,12 +202,24 @@ def _(arg): # and exception would be lost. m.d.comb += exception.eq(1) exception_report(m, rob_id=arg.rob_id, cause=ExceptionCause._COREBLOCKS_ASYNC_INTERRUPT, pc=jump_result) + with m.Elif(misprediction): + # Async interrupts can have priority, because `jump_result` is handled in the same way. + # No extra misprediction penalty will be introducted at interrupt return to `jump_result` address. 
+ m.d.comb += exception.eq(1) + exception_report(m, rob_id=arg.rob_id, cause=ExceptionCause._COREBLOCKS_MISPREDICTION, pc=jump_result) fifo_res.write(m, rob_id=arg.rob_id, result=jb.reg_res, rp_dst=arg.rp_dst, exception=exception) - # skip writing next branch target for auipc with m.If(~is_auipc): - fifo_branch.write(m, from_pc=jb.in_pc, next_pc=jump_result, resume_from_exception=0) + self.fifo_branch_resolved.write(m, from_pc=jb.in_pc, next_pc=jump_result, misprediction=misprediction) + log.debug( + m, + True, + "jumping from 0x{:08x} to 0x{:08x}; misprediction: {}", + jb.in_pc, + jump_result, + misprediction, + ) return m @@ -199,8 +230,6 @@ def __init__(self): def get_module(self, gen_params: GenParams) -> FuncUnit: unit = JumpBranchFuncUnit(gen_params, self.jb_fn) - connections = gen_params.get(DependencyManager) - connections.add_dependency(BranchResolvedKey(), unit.branch_result) return unit def get_optypes(self) -> set[OpType]: diff --git a/coreblocks/fu/mul_unit.py b/coreblocks/fu/mul_unit.py index b55ff604e..0deba543a 100644 --- a/coreblocks/fu/mul_unit.py +++ b/coreblocks/fu/mul_unit.py @@ -45,13 +45,13 @@ def get_instructions(self) -> Sequence[tuple]: ] -def get_input(arg: Record) -> tuple[Value, Value]: +def get_input(arg: MethodStruct) -> tuple[Value, Value]: """ Operation of getting two input values. Parameters ---------- - arg: Record + arg: MethodStruct Arguments of functional unit issue call. 
Returns diff --git a/coreblocks/fu/priv.py b/coreblocks/fu/priv.py index 1dc31bfcc..1e7d599d5 100644 --- a/coreblocks/fu/priv.py +++ b/coreblocks/fu/priv.py @@ -36,7 +36,7 @@ def __init__(self, gp: GenParams): self.accept = Method(o=layouts.accept) self.precommit = Method(i=gp.get(RetirementLayouts).precommit) - self.branch_resolved_fifo = BasicFifo(self.gp.get(FetchLayouts).branch_verify, 2) + self.fetch_resume_fifo = BasicFifo(self.gp.get(FetchLayouts).resume, 2) def elaborate(self, platform): m = TModule() @@ -52,7 +52,7 @@ def elaborate(self, platform): exception_report = self.dm.get_dependency(ExceptionReportKey()) csr = self.dm.get_dependency(GenericCSRRegistersKey()) - m.submodules.branch_resolved_fifo = self.branch_resolved_fifo + m.submodules.fetch_resume_fifo = self.fetch_resume_fifo @def_method(m, self.issue, ready=~instr_valid) def _(arg): @@ -89,7 +89,7 @@ def _(): exception_report(m, cause=ExceptionCause._COREBLOCKS_ASYNC_INTERRUPT, pc=ret_pc, rob_id=instr_rob) with m.Else(): # Unstall the fetch to return address (MRET is SYSTEM opcode) - self.branch_resolved_fifo.write(m, next_pc=ret_pc, from_pc=0, resume_from_exception=0) + self.fetch_resume_fifo.write(m, pc=ret_pc, resume_from_exception=0) return { "rob_id": instr_rob, @@ -106,7 +106,7 @@ def get_module(self, gp: GenParams) -> FuncUnit: unit = PrivilegedFuncUnit(gp) connections = gp.get(DependencyManager) connections.add_dependency(InstructionPrecommitKey(), unit.precommit) - connections.add_dependency(BranchResolvedKey(), unit.branch_resolved_fifo.read) + connections.add_dependency(FetchResumeKey(), unit.fetch_resume_fifo.read) return unit def get_optypes(self) -> set[OpType]: diff --git a/coreblocks/fu/shift_unit.py b/coreblocks/fu/shift_unit.py index f0ce3dc2d..0df08b73c 100644 --- a/coreblocks/fu/shift_unit.py +++ b/coreblocks/fu/shift_unit.py @@ -34,8 +34,8 @@ def get_instructions(self) -> Sequence[tuple]: (self.Fn.SRL, OpType.SHIFT, Funct3.SR, Funct7.SL), (self.Fn.SRA, OpType.SHIFT, 
Funct3.SR, Funct7.SA), ] + [ - (self.Fn.ROR, OpType.BIT_MANIPULATION, Funct3.ROR, Funct7.ROR), - (self.Fn.ROL, OpType.BIT_MANIPULATION, Funct3.ROL, Funct7.ROL), + (self.Fn.ROR, OpType.BIT_ROTATION, Funct3.ROR), + (self.Fn.ROL, OpType.BIT_ROTATION, Funct3.ROL), ] * self.zbb_enable diff --git a/coreblocks/lsu/dummyLsu.py b/coreblocks/lsu/dummyLsu.py index cff013627..3b8edd4a4 100644 --- a/coreblocks/lsu/dummyLsu.py +++ b/coreblocks/lsu/dummyLsu.py @@ -1,21 +1,23 @@ from amaranth import * +from amaranth.lib.data import View from transactron import Method, def_method, Transaction, TModule from coreblocks.params import * -from coreblocks.peripherals.wishbone import WishboneMaster +from coreblocks.peripherals.bus_adapter import BusMasterInterface from transactron.lib.connectors import Forwarder from transactron.utils import assign, ModuleLike, DependencyManager from coreblocks.utils.protocols import FuncBlock +from transactron.lib.simultaneous import condition from coreblocks.lsu.pma import PMAChecker __all__ = ["LSUDummy", "LSUBlockComponent"] -class LSURequesterWB(Elaboratable): +class LSURequester(Elaboratable): """ - Wishbone request logic for the load/store unit. Its job is to interface - between the LSU and the Wishbone bus. + Bus request logic for the load/store unit. Its job is to interface + between the LSU and the bus. Attributes ---------- @@ -25,14 +27,14 @@ class LSURequesterWB(Elaboratable): Retrieves a result from the bus. """ - def __init__(self, gen_params: GenParams, bus: WishboneMaster) -> None: + def __init__(self, gen_params: GenParams, bus: BusMasterInterface) -> None: """ Parameters ---------- gen_params : GenParams Parameters to be used during processor generation. - bus : WishboneMaster - An instance of the Wishbone master for interfacing with the data bus. + bus : BusMasterInterface + An instance of the bus master for interfacing with the data bus. 
""" self.gen_params = gen_params self.bus = bus @@ -43,7 +45,7 @@ def __init__(self, gen_params: GenParams, bus: WishboneMaster) -> None: self.accept = Method(o=lsu_layouts.accept) def prepare_bytes_mask(self, m: ModuleLike, funct3: Value, addr: Value) -> Signal: - mask_len = self.gen_params.isa.xlen // self.bus.wb_params.granularity + mask_len = self.gen_params.isa.xlen // self.bus.params.granularity mask = Signal(mask_len) with m.Switch(funct3): with m.Case(Funct3.B, Funct3.BU): @@ -71,7 +73,7 @@ def postprocess_load_data(self, m: ModuleLike, funct3: Value, raw_data: Value, a m.d.av_comb += data.eq(tmp.as_signed()) with m.Else(): m.d.av_comb += data.eq(tmp) - with m.Case(): + with m.Default(): m.d.av_comb += data.eq(raw_data) return data @@ -82,7 +84,7 @@ def prepare_data_to_save(self, m: ModuleLike, funct3: Value, raw_data: Value, ad m.d.av_comb += data.eq(raw_data[0:8] << (addr[0:2] << 3)) with m.Case(Funct3.H): m.d.av_comb += data.eq(raw_data[0:16] << (addr[1] << 4)) - with m.Case(): + with m.Default(): m.d.av_comb += data.eq(raw_data) return data @@ -93,7 +95,7 @@ def check_align(self, m: TModule, funct3: Value, addr: Value): m.d.av_comb += aligned.eq(addr[0:2] == 0) with m.Case(Funct3.H, Funct3.HU): m.d.av_comb += aligned.eq(addr[0] == 0) - with m.Case(): + with m.Default(): m.d.av_comb += aligned.eq(1) return aligned @@ -112,10 +114,17 @@ def _(addr: Value, data: Value, funct3: Value, store: Value): aligned = self.check_align(m, funct3, addr) bytes_mask = self.prepare_bytes_mask(m, funct3, addr) - wb_data = self.prepare_data_to_save(m, funct3, data, addr) + bus_data = self.prepare_data_to_save(m, funct3, data, addr) + + with condition(m, nonblocking=False, priority=False) as branch: + with branch(aligned & store): + self.bus.request_write(m, addr=addr >> 2, data=bus_data, sel=bytes_mask) + with branch(aligned & ~store): + self.bus.request_read(m, addr=addr >> 2, sel=bytes_mask) + with branch(~aligned): + pass with m.If(aligned): - self.bus.request(m, 
addr=addr >> 2, we=store, sel=bytes_mask, data=wb_data) m.d.sync += request_sent.eq(1) m.d.sync += addr_reg.eq(addr) m.d.sync += funct3_reg.eq(funct3) @@ -130,15 +139,23 @@ def _(addr: Value, data: Value, funct3: Value, store: Value): @def_method(m, self.accept, request_sent) def _(): + data = Signal(self.gen_params.isa.xlen) exception = Signal() cause = Signal(ExceptionCause) + err = Signal() - fetched = self.bus.result(m) - m.d.sync += request_sent.eq(0) + with condition(m, nonblocking=False, priority=False) as branch: + with branch(store_reg): + fetched = self.bus.get_write_response(m) + err = fetched.err + with branch(~store_reg): + fetched = self.bus.get_read_response(m) + err = fetched.err + data = self.postprocess_load_data(m, funct3_reg, fetched.data, addr_reg) - data = self.postprocess_load_data(m, funct3_reg, fetched.data, addr_reg) + m.d.sync += request_sent.eq(0) - with m.If(fetched.err): + with m.If(err): m.d.av_comb += exception.eq(1) m.d.av_comb += cause.eq( Mux(store_reg, ExceptionCause.STORE_ACCESS_FAULT, ExceptionCause.LOAD_ACCESS_FAULT) @@ -172,14 +189,14 @@ class LSUDummy(FuncBlock, Elaboratable): Used to inform LSU that new instruction is ready to be retired. """ - def __init__(self, gen_params: GenParams, bus: WishboneMaster) -> None: + def __init__(self, gen_params: GenParams, bus: BusMasterInterface) -> None: """ Parameters ---------- gen_params : GenParams Parameters to be used during processor generation. - bus : WishboneMaster - An instance of the Wishbone master for interfacing with the data bus. + bus : BusMasterInterface + An instance of the bus master for interfacing with the data bus. 
""" self.gen_params = gen_params @@ -204,10 +221,10 @@ def elaborate(self, platform): precommiting = Signal() # start execution issued = Signal() # instruction was issued to the bus flush = Signal() # exception handling, requests are not issued - current_instr = Record(self.lsu_layouts.rs.data_layout) + current_instr = Signal(self.lsu_layouts.rs.data_layout) m.submodules.pma_checker = pma_checker = PMAChecker(self.gen_params) - m.submodules.requester = requester = LSURequesterWB(self.gen_params, self.bus) + m.submodules.requester = requester = LSURequester(self.gen_params, self.bus) m.submodules.results = results = self.forwarder = Forwarder(self.lsu_layouts.accept) @@ -232,7 +249,7 @@ def _(): return {"rs_entry_id": 0} @def_method(m, self.insert) - def _(rs_data: Record, rs_entry_id: Value): + def _(rs_data: View, rs_entry_id: Value): m.d.sync += assign(current_instr, rs_data) m.d.sync += valid.eq(1) @@ -303,8 +320,8 @@ def _(rob_id: Value, side_fx: Value): class LSUBlockComponent(BlockComponentParams): def get_module(self, gen_params: GenParams) -> FuncBlock: connections = gen_params.get(DependencyManager) - wb_master = connections.get_dependency(WishboneDataKey()) - unit = LSUDummy(gen_params, wb_master) + bus_master = connections.get_dependency(CommonBusDataKey()) + unit = LSUDummy(gen_params, bus_master) connections.add_dependency(InstructionPrecommitKey(), unit.precommit) return unit diff --git a/coreblocks/lsu/pma.py b/coreblocks/lsu/pma.py index 8e474a6bf..cd91c98f0 100644 --- a/coreblocks/lsu/pma.py +++ b/coreblocks/lsu/pma.py @@ -2,6 +2,7 @@ from functools import reduce from operator import or_ from amaranth import * +from amaranth.lib import data from coreblocks.params import * from transactron.utils import HasElaborate @@ -29,6 +30,11 @@ class PMARegion: mmio: bool = False +class PMALayout(data.StructLayout): + def __init__(self): + super().__init__({"mmio": unsigned(1)}) + + class PMAChecker(Elaboratable): """ Implementation of physical memory 
attributes checker. It may or may not be a part of LSU. @@ -38,21 +44,20 @@ class PMAChecker(Elaboratable): ---------- addr : Signal Memory address, for which PMAs are requested. - result : Record + result : View PMAs for given address. """ def __init__(self, gen_params: GenParams) -> None: # poor man's interval list self.segments = gen_params.pma - self.attr_layout = gen_params.get(PMALayouts).pma_attrs_layout - self.result = Record(self.attr_layout) + self.result = Signal(PMALayout()) self.addr = Signal(gen_params.isa.xlen) def elaborate(self, platform) -> HasElaborate: m = TModule() - outputs = [Record(self.attr_layout) for _ in self.segments] + outputs = [Signal(PMALayout()) for _ in self.segments] # zero output if addr not in region, propagate value if addr in region for i, segment in enumerate(self.segments): @@ -64,6 +69,6 @@ def elaborate(self, platform) -> HasElaborate: m.d.comb += outputs[i].eq(segment.mmio) # OR all outputs - m.d.comb += self.result.eq(reduce(or_, outputs, 0)) + m.d.comb += self.result.eq(reduce(or_, [Value.cast(o) for o in outputs], 0)) return m diff --git a/coreblocks/params/configurations.py b/coreblocks/params/configurations.py index a289c09e5..6d69cd8f9 100644 --- a/coreblocks/params/configurations.py +++ b/coreblocks/params/configurations.py @@ -48,6 +48,8 @@ class CoreConfiguration: Enables 16-bit Compressed Instructions extension. embedded: bool Enables Reduced Integer (E) extension. + debug_signals: bool + Enable debug signals (for example hardware metrics etc). If disabled, none of them will be synthesized. phys_regs_bits: int Size of the Physical Register File is 2**phys_regs_bits. 
rob_entries_bits: int @@ -76,6 +78,8 @@ class CoreConfiguration: compressed: bool = False embedded: bool = False + debug_signals: bool = True + phys_regs_bits: int = 6 rob_entries_bits: int = 7 start_pc: int = 0 diff --git a/coreblocks/params/genparams.py b/coreblocks/params/genparams.py index 916832b49..3691d02ca 100644 --- a/coreblocks/params/genparams.py +++ b/coreblocks/params/genparams.py @@ -1,6 +1,6 @@ from __future__ import annotations -from amaranth.utils import log2_int +from amaranth.utils import exact_log2 from .isa import ISA, gen_isa_string from .icache_params import ICacheParameters @@ -36,7 +36,7 @@ def __init__(self, cfg: CoreConfiguration): bytes_in_word = self.isa.xlen // 8 self.wb_params = WishboneParameters( - data_width=self.isa.xlen, addr_width=self.isa.xlen - log2_int(bytes_in_word) + data_width=self.isa.xlen, addr_width=self.isa.xlen - exact_log2(bytes_in_word) ) self.icache_params = ICacheParameters( @@ -47,6 +47,8 @@ def __init__(self, cfg: CoreConfiguration): block_size_bits=cfg.icache_block_size_bits, ) + self.debug_signals_enabled = cfg.debug_signals + # Verification temporally disabled # if not optypes_required_by_extensions(self.isa.extensions) <= optypes_supported(func_units_config): # raise Exception(f"Functional unit configuration fo not support all extension required by{isa_str}") diff --git a/coreblocks/params/instr.py b/coreblocks/params/instr.py index 7bf830436..efaab82cb 100644 --- a/coreblocks/params/instr.py +++ b/coreblocks/params/instr.py @@ -1,6 +1,6 @@ from abc import abstractmethod, ABC -from amaranth.hdl.ast import ValueCastable +from amaranth.hdl import ValueCastable from amaranth import * from transactron.utils import ValueLike diff --git a/coreblocks/params/isa.py b/coreblocks/params/isa.py index 8ca7466e4..dde829023 100644 --- a/coreblocks/params/isa.py +++ b/coreblocks/params/isa.py @@ -156,6 +156,7 @@ class ExceptionCause(IntEnum, shape=5): LOAD_PAGE_FAULT = 13 STORE_PAGE_FAULT = 15 _COREBLOCKS_ASYNC_INTERRUPT = 
16 + _COREBLOCKS_MISPREDICTION = 17 @unique diff --git a/coreblocks/params/keys.py b/coreblocks/params/keys.py index 88ed1d102..eab1b3985 100644 --- a/coreblocks/params/keys.py +++ b/coreblocks/params/keys.py @@ -4,16 +4,17 @@ from transactron.lib.dependencies import SimpleKey, UnifierKey from transactron import Method from transactron.lib import MethodTryProduct, Collector -from coreblocks.peripherals.wishbone import WishboneMaster +from coreblocks.peripherals.bus_adapter import BusMasterInterface from amaranth import Signal if TYPE_CHECKING: from coreblocks.structs_common.csr_generic import GenericCSRRegisters # noqa: F401 __all__ = [ - "WishboneDataKey", + "CommonBusDataKey", "InstructionPrecommitKey", - "BranchResolvedKey", + "BranchVerifyKey", + "FetchResumeKey", "ExceptionReportKey", "GenericCSRRegistersKey", "AsyncInterruptInsertSignalKey", @@ -22,7 +23,7 @@ @dataclass(frozen=True) -class WishboneDataKey(SimpleKey[WishboneMaster]): +class CommonBusDataKey(SimpleKey[BusMasterInterface]): pass @@ -32,7 +33,12 @@ class InstructionPrecommitKey(UnifierKey, unifier=MethodTryProduct): @dataclass(frozen=True) -class BranchResolvedKey(UnifierKey, unifier=Collector): +class BranchVerifyKey(SimpleKey[Method]): + pass + + +@dataclass(frozen=True) +class FetchResumeKey(UnifierKey, unifier=Collector): pass diff --git a/coreblocks/params/layouts.py b/coreblocks/params/layouts.py index 969c5995d..98f69344c 100644 --- a/coreblocks/params/layouts.py +++ b/coreblocks/params/layouts.py @@ -1,6 +1,8 @@ +from amaranth.lib.data import StructLayout from coreblocks.params import GenParams, OpType, Funct7, Funct3 from coreblocks.params.isa import ExceptionCause from transactron.utils import LayoutList, LayoutListField, layout_subset +from transactron.utils.transactron_helpers import from_method_layout, make_layout __all__ = [ "CommonLayoutFields", @@ -16,9 +18,9 @@ "UnsignedMulUnitLayouts", "RATLayouts", "LSULayouts", - "PMALayouts", "CSRLayouts", "ICacheLayouts", + 
"JumpBranchLayouts", ] @@ -83,7 +85,7 @@ def __init__(self, gen_params: GenParams): self.instr: LayoutListField = ("instr", gen_params.isa.ilen) """RISC V instruction.""" - self.exec_fn_layout: LayoutList = [self.op_type, self.funct3, self.funct7] + self.exec_fn_layout = make_layout(self.op_type, self.funct3, self.funct7) """Decoded instruction, in layout form.""" self.exec_fn: LayoutListField = ("exec_fn", self.exec_fn_layout) @@ -135,48 +137,48 @@ def __init__(self, gen_params: GenParams): ) """Logical register number for the destination operand, before ROB allocation.""" - self.reg_alloc_in: LayoutList = [ + self.reg_alloc_in = make_layout( fields.exec_fn, fields.regs_l, fields.imm, fields.csr, fields.pc, - ] + ) - self.reg_alloc_out: LayoutList = [ + self.reg_alloc_out = make_layout( fields.exec_fn, fields.regs_l, self.regs_p_alloc_out, fields.imm, fields.csr, fields.pc, - ] + ) self.renaming_in = self.reg_alloc_out - self.renaming_out: LayoutList = [ + self.renaming_out = make_layout( fields.exec_fn, self.regs_l_rob_in, fields.regs_p, fields.imm, fields.csr, fields.pc, - ] + ) self.rob_allocate_in = self.renaming_out - self.rob_allocate_out: LayoutList = [ + self.rob_allocate_out = make_layout( fields.exec_fn, fields.regs_p, fields.rob_id, fields.imm, fields.csr, fields.pc, - ] + ) self.rs_select_in = self.rob_allocate_out - self.rs_select_out: LayoutList = [ + self.rs_select_out = make_layout( fields.exec_fn, fields.regs_p, fields.rob_id, @@ -185,11 +187,11 @@ def __init__(self, gen_params: GenParams): fields.imm, fields.csr, fields.pc, - ] + ) self.rs_insert_in = self.rs_select_out - self.free_rf_layout: LayoutList = [fields.reg_id] + self.free_rf_layout = make_layout(fields.reg_id) class RFLayouts: @@ -201,10 +203,10 @@ def __init__(self, gen_params: GenParams): self.valid: LayoutListField = ("valid", 1) """Physical register was assigned a value.""" - self.rf_read_in: LayoutList = [fields.reg_id] - self.rf_free: LayoutList = [fields.reg_id] - 
self.rf_read_out: LayoutList = [fields.reg_val, self.valid] - self.rf_write: LayoutList = [fields.reg_id, fields.reg_val] + self.rf_read_in = make_layout(fields.reg_id) + self.rf_free = make_layout(fields.reg_id) + self.rf_read_out = make_layout(fields.reg_val, self.valid) + self.rf_write = make_layout(fields.reg_id, fields.reg_val) class RATLayouts: @@ -216,19 +218,19 @@ def __init__(self, gen_params: GenParams): self.old_rp_dst: LayoutListField = ("old_rp_dst", gen_params.phys_regs_bits) """Physical register previously associated with the given logical register in RRAT.""" - self.frat_rename_in: LayoutList = [ + self.frat_rename_in = make_layout( fields.rl_s1, fields.rl_s2, fields.rl_dst, fields.rp_dst, - ] - self.frat_rename_out: LayoutList = [fields.rp_s1, fields.rp_s2] + ) + self.frat_rename_out = make_layout(fields.rp_s1, fields.rp_s2) - self.rrat_commit_in: LayoutList = [fields.rl_dst, fields.rp_dst] - self.rrat_commit_out: LayoutList = [self.old_rp_dst] + self.rrat_commit_in = make_layout(fields.rl_dst, fields.rp_dst) + self.rrat_commit_out = make_layout(self.old_rp_dst) - self.rrat_peek_in: LayoutList = [fields.rl_dst] - self.rrat_peek_out: LayoutList = self.rrat_commit_out + self.rrat_peek_in = make_layout(fields.rl_dst) + self.rrat_peek_out = self.rrat_commit_out class ROBLayouts: @@ -237,10 +239,10 @@ class ROBLayouts: def __init__(self, gen_params: GenParams): fields = gen_params.get(CommonLayoutFields) - self.data_layout: LayoutList = [ + self.data_layout = make_layout( fields.rl_dst, fields.rp_dst, - ] + ) self.rob_data: LayoutListField = ("rob_data", self.data_layout) """Data stored in a reorder buffer entry.""" @@ -254,26 +256,26 @@ def __init__(self, gen_params: GenParams): self.end: LayoutListField = ("end", gen_params.rob_entries_bits) """Index of the entry following the last (the latest) entry in the reorder buffer.""" - self.id_layout: LayoutList = [fields.rob_id] + self.id_layout = make_layout(fields.rob_id) - self.internal_layout: LayoutList 
= [ + self.internal_layout = make_layout( self.rob_data, self.done, fields.exception, - ] + ) - self.mark_done_layout: LayoutList = [ + self.mark_done_layout = make_layout( fields.rob_id, fields.exception, - ] + ) - self.peek_layout: LayoutList = [ + self.peek_layout = make_layout( self.rob_data, fields.rob_id, fields.exception, - ] + ) - self.get_indices: LayoutList = [self.start, self.end] + self.get_indices = make_layout(self.start, self.end) class RSLayoutFields: @@ -293,7 +295,7 @@ class RSFullDataLayout: def __init__(self, gen_params: GenParams): fields = gen_params.get(CommonLayoutFields) - self.data_layout: LayoutList = [ + self.data_layout = make_layout( fields.rp_s1, fields.rp_s2, ("rp_s1_reg", gen_params.phys_regs_bits), @@ -306,7 +308,7 @@ def __init__(self, gen_params: GenParams): fields.imm, fields.csr, fields.pc, - ] + ) class RSInterfaceLayouts: @@ -316,13 +318,13 @@ def __init__(self, gen_params: GenParams, *, rs_entries_bits: int, data_layout: fields = gen_params.get(CommonLayoutFields) rs_fields = gen_params.get(RSLayoutFields, rs_entries_bits=rs_entries_bits, data_layout=data_layout) - self.data_layout: LayoutList = data_layout + self.data_layout = from_method_layout(data_layout) - self.select_out: LayoutList = [rs_fields.rs_entry_id] + self.select_out = make_layout(rs_fields.rs_entry_id) - self.insert_in: LayoutList = [rs_fields.rs_data, rs_fields.rs_entry_id] + self.insert_in = make_layout(rs_fields.rs_data, rs_fields.rs_entry_id) - self.update_in: LayoutList = [fields.reg_id, fields.reg_val] + self.update_in = make_layout(fields.reg_id, fields.reg_val) class RetirementLayouts: @@ -331,7 +333,7 @@ class RetirementLayouts: def __init__(self, gen_params: GenParams): fields = gen_params.get(CommonLayoutFields) - self.precommit: LayoutList = [fields.rob_id, fields.side_fx] + self.precommit = make_layout(fields.rob_id, fields.side_fx) self.flushing = ("flushing", 1) """ Core is currently flushed """ @@ -366,7 +368,7 @@ def __init__(self, 
gen_params: GenParams, *, rs_entries_bits: int): self.rs = gen_params.get(RSInterfaceLayouts, rs_entries_bits=rs_entries_bits, data_layout=data_layout) rs_fields = gen_params.get(RSLayoutFields, rs_entries_bits=rs_entries_bits, data_layout=data_layout) - self.take_in: LayoutList = [rs_fields.rs_entry_id] + self.take_in = make_layout(rs_fields.rs_entry_id) self.take_out = layout_subset( data.data_layout, @@ -381,7 +383,7 @@ def __init__(self, gen_params: GenParams, *, rs_entries_bits: int): }, ) - self.get_ready_list_out: LayoutList = [self.ready_list] + self.get_ready_list_out = make_layout(self.ready_list) class ICacheLayouts: @@ -393,23 +395,23 @@ def __init__(self, gen_params: GenParams): self.error: LayoutListField = ("last", 1) """This is the last cache refill result.""" - self.issue_req: LayoutList = [fields.addr] + self.issue_req = make_layout(fields.addr) - self.accept_res: LayoutList = [ + self.accept_res = make_layout( fields.instr, fields.error, - ] + ) - self.start_refill: LayoutList = [ + self.start_refill = make_layout( fields.addr, - ] + ) - self.accept_refill: LayoutList = [ + self.accept_refill = make_layout( fields.addr, fields.data, fields.error, self.error, - ] + ) class FetchLayouts: @@ -424,18 +426,14 @@ def __init__(self, gen_params: GenParams): self.rvc: LayoutListField = ("rvc", 1) """Instruction is a compressed (two-byte) one.""" - self.raw_instr: LayoutList = [ + self.raw_instr = make_layout( fields.instr, fields.pc, self.access_fault, self.rvc, - ] + ) - self.branch_verify: LayoutList = [ - ("from_pc", gen_params.isa.xlen), - ("next_pc", gen_params.isa.xlen), - ("resume_from_exception", 1), - ] + self.resume = make_layout(("pc", gen_params.isa.xlen), ("resume_from_exception", 1)) class DecodeLayouts: @@ -444,13 +442,13 @@ class DecodeLayouts: def __init__(self, gen_params: GenParams): fields = gen_params.get(CommonLayoutFields) - self.decoded_instr: LayoutList = [ + self.decoded_instr = make_layout( fields.exec_fn, fields.regs_l, 
fields.imm, fields.csr, fields.pc, - ] + ) class FuncUnitLayouts: @@ -462,7 +460,7 @@ def __init__(self, gen_params: GenParams): self.result: LayoutListField = ("result", gen_params.isa.xlen) """The result value produced in a functional unit.""" - self.issue: LayoutList = [ + self.issue = make_layout( fields.s1_val, fields.s2_val, fields.rp_dst, @@ -470,39 +468,47 @@ def __init__(self, gen_params: GenParams): fields.exec_fn, fields.imm, fields.pc, - ] + ) - self.accept: LayoutList = [ + self.accept = make_layout( fields.rob_id, self.result, fields.rp_dst, fields.exception, - ] + ) class UnsignedMulUnitLayouts: def __init__(self, gen_params: GenParams): - self.issue: LayoutList = [ + self.issue = make_layout( ("i1", gen_params.isa.xlen), ("i2", gen_params.isa.xlen), - ] + ) - self.accept: LayoutList = [ + self.accept = make_layout( ("o", 2 * gen_params.isa.xlen), - ] + ) class DivUnitLayouts: def __init__(self, gen_params: GenParams): - self.issue: LayoutList = [ + self.issue = make_layout( ("dividend", gen_params.isa.xlen), ("divisor", gen_params.isa.xlen), - ] + ) - self.accept: LayoutList = [ + self.accept = make_layout( ("quotient", gen_params.isa.xlen), ("remainder", gen_params.isa.xlen), - ] + ) + + +class JumpBranchLayouts: + def __init__(self, gen_params: GenParams): + self.verify_branch = make_layout( + ("from_pc", gen_params.isa.xlen), ("next_pc", gen_params.isa.xlen), ("misprediction", 1) + ) + """ Hint for Branch Predictor about branch result """ class LSULayouts: @@ -537,16 +543,11 @@ def __init__(self, gen_params: GenParams): self.store: LayoutListField = ("store", 1) - self.issue: LayoutList = [fields.addr, fields.data, fields.funct3, self.store] - - self.issue_out: LayoutList = [fields.exception, fields.cause] - - self.accept: LayoutList = [fields.data, fields.exception, fields.cause] + self.issue = make_layout(fields.addr, fields.data, fields.funct3, self.store) + self.issue_out = make_layout(fields.exception, fields.cause) -class PMALayouts: - def 
__init__(self, gen_params: GenParams): - self.pma_attrs_layout = [("mmio", 1)] + self.accept = make_layout(fields.data, fields.exception, fields.cause) class CSRLayouts: @@ -558,16 +559,16 @@ def __init__(self, gen_params: GenParams): self.rs_entries_bits = 0 - self.read: LayoutList = [ + self.read = make_layout( fields.data, ("read", 1), ("written", 1), - ] + ) - self.write: LayoutList = [fields.data] + self.write = make_layout(fields.data) - self._fu_read: LayoutList = [fields.data] - self._fu_write: LayoutList = [fields.data] + self._fu_read = make_layout(fields.data) + self._fu_write = make_layout(fields.data) data_layout = layout_subset( data.data_layout, @@ -597,15 +598,15 @@ class ExceptionRegisterLayouts: def __init__(self, gen_params: GenParams): fields = gen_params.get(CommonLayoutFields) - self.valid = ("valid", 1) + self.valid: LayoutListField = ("valid", 1) - self.report: LayoutList = [ + self.report = make_layout( fields.cause, fields.rob_id, fields.pc, - ] + ) - self.get = self.report + [self.valid] + self.get = StructLayout(self.report.members | make_layout(self.valid).members) class CoreInstructionCounterLayouts: diff --git a/coreblocks/params/optypes.py b/coreblocks/params/optypes.py index 72ca461b8..60fd52c19 100644 --- a/coreblocks/params/optypes.py +++ b/coreblocks/params/optypes.py @@ -34,6 +34,7 @@ class OpType(IntEnum): SINGLE_BIT_MANIPULATION = auto() ADDRESS_GENERATION = auto() BIT_MANIPULATION = auto() + BIT_ROTATION = auto() UNARY_BIT_MANIPULATION_1 = auto() UNARY_BIT_MANIPULATION_2 = auto() UNARY_BIT_MANIPULATION_3 = auto() @@ -88,6 +89,7 @@ class OpType(IntEnum): ], Extension.ZBB: [ OpType.BIT_MANIPULATION, + OpType.BIT_ROTATION, OpType.UNARY_BIT_MANIPULATION_1, OpType.UNARY_BIT_MANIPULATION_2, OpType.UNARY_BIT_MANIPULATION_3, diff --git a/coreblocks/peripherals/axi_lite.py b/coreblocks/peripherals/axi_lite.py index 6fd3fac01..268c396ab 100644 --- a/coreblocks/peripherals/axi_lite.py +++ b/coreblocks/peripherals/axi_lite.py @@ -1,10 
+1,12 @@ +from typing import Protocol, TypeAlias, runtime_checkable from amaranth import * -from amaranth.hdl.rec import DIR_FANIN, DIR_FANOUT +from amaranth.lib.wiring import Component, Signature, In, Out from transactron import Method, def_method, TModule from transactron.core import Transaction from transactron.lib.connectors import Forwarder +from transactron.utils._typing import AbstractInterface, AbstractSignature -__all__ = ["AXILiteParameters", "AXILiteMaster"] +__all__ = ["AXILiteParameters", "AXILiteSignature", "AXILiteInterface", "AXILiteMaster"] class AXILiteParameters: @@ -21,67 +23,119 @@ class AXILiteParameters: def __init__(self, *, data_width: int = 64, addr_width: int = 64): self.data_width = data_width self.addr_width = addr_width + self.granularity = 8 -class AXILiteLayout: - """AXI-Lite bus layout generator +class AXILiteSignature(Signature): + """AXI-Lite bus signature Parameters ---------- axil_params: AXILiteParameters - Patameters used to generate AXI-Lite layout - master: Boolean - Whether the layout should be generated for master side - (if false it's generatd for the slave side) - - Attributes - ---------- - axil_layout: Record - Record of a AXI-Lite bus. 
+ Patameters used to generate AXI-Lite signature """ - def __init__(self, axil_params: AXILiteParameters, *, master: bool = True): - write_address = [ - ("valid", 1, DIR_FANOUT if master else DIR_FANIN), - ("rdy", 1, DIR_FANIN if master else DIR_FANOUT), - ("addr", axil_params.addr_width, DIR_FANOUT if master else DIR_FANIN), - ("prot", 3, DIR_FANOUT if master else DIR_FANIN), - ] + def __init__(self, axil_params: AXILiteParameters): + write_address = Signature( + { + "valid": Out(1), + "rdy": In(1), + "addr": Out(axil_params.addr_width), + "prot": Out(3), + } + ) - write_data = [ - ("valid", 1, DIR_FANOUT if master else DIR_FANIN), - ("rdy", 1, DIR_FANIN if master else DIR_FANOUT), - ("data", axil_params.data_width, DIR_FANOUT if master else DIR_FANIN), - ("strb", axil_params.data_width // 8, DIR_FANOUT if master else DIR_FANIN), - ] + write_data = Signature( + { + "valid": Out(1), + "rdy": In(1), + "data": Out(axil_params.data_width), + "strb": Out(axil_params.data_width // 8), + } + ) - write_response = [ - ("valid", 1, DIR_FANIN if master else DIR_FANOUT), - ("rdy", 1, DIR_FANOUT if master else DIR_FANIN), - ("resp", 2, DIR_FANIN if master else DIR_FANOUT), - ] + write_response = Signature( + { + "valid": In(1), + "rdy": Out(1), + "resp": In(2), + } + ) - read_address = [ - ("valid", 1, DIR_FANOUT if master else DIR_FANIN), - ("rdy", 1, DIR_FANIN if master else DIR_FANOUT), - ("addr", axil_params.addr_width, DIR_FANOUT if master else DIR_FANIN), - ("prot", 3, DIR_FANOUT if master else DIR_FANIN), - ] + read_address = Signature( + { + "valid": Out(1), + "rdy": In(1), + "addr": Out(axil_params.addr_width), + "prot": Out(3), + } + ) - read_data = [ - ("valid", 1, DIR_FANIN if master else DIR_FANOUT), - ("rdy", 1, DIR_FANOUT if master else DIR_FANIN), - ("data", axil_params.data_width, DIR_FANIN if master else DIR_FANOUT), - ("resp", 2, DIR_FANIN if master else DIR_FANOUT), - ] + read_data = Signature( + { + "valid": In(1), + "rdy": Out(1), + "data": 
In(axil_params.data_width), + "resp": In(2), + } + ) + + super().__init__( + { + "write_address": Out(write_address), + "write_data": Out(write_data), + "write_response": Out(write_response), + "read_address": Out(read_address), + "read_data": Out(read_data), + } + ) - self.axil_layout = [ - ("write_address", write_address), - ("write_data", write_data), - ("write_response", write_response), - ("read_address", read_address), - ("read_data", read_data), - ] + +class AXILiteWriteAddressInterface(AbstractInterface[AbstractSignature], Protocol): + valid: Signal + rdy: Signal + addr: Signal + prot: Signal + + +class AXILiteWriteDataInterface(AbstractInterface[AbstractSignature], Protocol): + valid: Signal + rdy: Signal + data: Signal + strb: Signal + + +class AXILiteWriteResponseInterface(AbstractInterface[AbstractSignature], Protocol): + valid: Signal + rdy: Signal + resp: Signal + + +class AXILiteReadAddressInterface(AbstractInterface[AbstractSignature], Protocol): + valid: Signal + rdy: Signal + addr: Signal + prot: Signal + + +@runtime_checkable +class AXILiteReadDataInterface(AbstractInterface[AbstractSignature], Protocol): + valid: Signal + rdy: Signal + data: Signal + resp: Signal + + +class AXILiteInterface(AbstractInterface[AbstractSignature], Protocol): + write_address: AXILiteWriteAddressInterface + write_data: AXILiteWriteDataInterface + write_response: AXILiteWriteResponseInterface + read_address: AXILiteReadAddressInterface + read_data: AXILiteReadDataInterface + + +AXILiteOutChannel: TypeAlias = AXILiteWriteAddressInterface | AXILiteWriteDataInterface | AXILiteReadAddressInterface +AXILiteInChannel: TypeAlias = AXILiteWriteResponseInterface | AXILiteReadDataInterface class AXILiteMasterMethodLayouts: @@ -112,18 +166,18 @@ class AXILiteMasterMethodLayouts: def __init__(self, axil_params: AXILiteParameters): self.ra_request_layout = [ - ("addr", axil_params.addr_width, DIR_FANIN), - ("prot", 3, DIR_FANIN), + ("addr", axil_params.addr_width), + ("prot", 3), 
] self.wa_request_layout = [ - ("addr", axil_params.addr_width, DIR_FANIN), - ("prot", 3, DIR_FANIN), + ("addr", axil_params.addr_width), + ("prot", 3), ] self.wd_request_layout = [ - ("data", axil_params.data_width, DIR_FANIN), - ("strb", axil_params.data_width // 8, DIR_FANIN), + ("data", axil_params.data_width), + ("strb", axil_params.data_width // 8), ] self.rd_response_layout = [ @@ -136,7 +190,7 @@ def __init__(self, axil_params: AXILiteParameters): ] -class AXILiteMaster(Elaboratable): +class AXILiteMaster(Component): """AXI-Lite master interface. Parameters @@ -172,7 +226,10 @@ class AXILiteMaster(Elaboratable): Returns response state as 'wr_response_layout'. """ + axil_master: AXILiteInterface + def __init__(self, axil_params: AXILiteParameters): + super().__init__({"axil_master": Out(AXILiteSignature(axil_params))}) self.axil_params = axil_params self.method_layouts = AXILiteMasterMethodLayouts(self.axil_params) @@ -192,7 +249,7 @@ def start_request_transaction(self, m, arg, *, channel, is_address_channel): m.d.sync += channel.strb.eq(arg.strb) m.d.sync += channel.valid.eq(1) - def state_machine_request(self, m: TModule, method: Method, *, channel: Record, request_signal: Signal): + def state_machine_request(self, m: TModule, method: Method, *, channel: AXILiteOutChannel, request_signal: Signal): with m.FSM("Idle"): with m.State("Idle"): m.d.sync += channel.valid.eq(0) @@ -209,11 +266,11 @@ def state_machine_request(self, m: TModule, method: Method, *, channel: Record, with m.Else(): m.d.comb += request_signal.eq(0) - def result_handler(self, m: TModule, forwarder: Forwarder, *, data: bool, channel: Record): + def result_handler(self, m: TModule, forwarder: Forwarder, *, channel: AXILiteInChannel): with m.If(channel.rdy & channel.valid): m.d.sync += channel.rdy.eq(forwarder.read.run) with Transaction().body(m): - if data: + if isinstance(channel, AXILiteReadDataInterface): forwarder.write(m, data=channel.data, resp=channel.resp) else: forwarder.write(m, 
resp=channel.resp) @@ -223,9 +280,6 @@ def result_handler(self, m: TModule, forwarder: Forwarder, *, data: bool, channe def elaborate(self, platform): m = TModule() - self.axil_layout = AXILiteLayout(self.axil_params).axil_layout - self.axil_master = Record(self.axil_layout) - m.submodules.rd_forwarder = rd_forwarder = Forwarder(self.method_layouts.rd_response_layout) m.submodules.wr_forwarder = wr_forwarder = Forwarder(self.method_layouts.wr_response_layout) @@ -245,7 +299,7 @@ def _(arg): self.start_request_transaction(m, arg, channel=self.axil_master.read_address, is_address_channel=True) # read_data - self.result_handler(m, rd_forwarder, data=True, channel=self.axil_master.read_data) + self.result_handler(m, rd_forwarder, channel=self.axil_master.read_data) @def_method(m, self.rd_response) def _(): @@ -276,7 +330,7 @@ def _(arg): self.start_request_transaction(m, arg, channel=self.axil_master.write_data, is_address_channel=False) # write_response - self.result_handler(m, wr_forwarder, data=False, channel=self.axil_master.write_response) + self.result_handler(m, wr_forwarder, channel=self.axil_master.write_response) @def_method(m, self.wr_response) def _(): diff --git a/coreblocks/peripherals/bus_adapter.py b/coreblocks/peripherals/bus_adapter.py new file mode 100644 index 000000000..139c5a1b3 --- /dev/null +++ b/coreblocks/peripherals/bus_adapter.py @@ -0,0 +1,266 @@ +from typing import Protocol + +from amaranth import * + +from coreblocks.peripherals.wishbone import WishboneMaster +from coreblocks.peripherals.axi_lite import AXILiteMaster + +from transactron import Method, def_method, TModule +from transactron.utils import HasElaborate +from transactron.lib import Serializer +from transactron.utils.transactron_helpers import make_layout + +__all__ = ["BusMasterInterface", "WishboneMasterAdapter", "AXILiteMasterAdapter"] + + +class BusParametersInterface(Protocol): + """ + An interface for parameters of a common bus. 
+ + Parameters + ---------- + data_width : int + An integer that describes the data width of a parametrized bus. + addr_width : int + An integer that describes the address width of a parametrized bus. + granularity : int + An integer that describes the granularity of accesses of a parametrized bus. + """ + + data_width: int + addr_width: int + granularity: int + + +class BusMasterInterface(HasElaborate, Protocol): + """ + An interface of a common bus. + + The bus interface is the preferred way to gain access to a specific bus. + It simplifies interchangeability of buses on the core configuration level. + + Parameters + ---------- + params : BusParametersInterface + Parameters of the bus. + request_read : Method + A method that is used to send a read request to a bus. + request_write : Method + A method that is used to send a write request to a bus. + get_read_response : Method + A method that is used to receive a response from a bus for a previously sent read request. + get_write_response : Method + A method that is used to receive a response from a bus for a previously sent write request. + """ + + params: BusParametersInterface + request_read: Method + request_write: Method + get_read_response: Method + get_write_response: Method + + +class CommonBusMasterMethodLayout: + """ + Layouts of methods for a common bus master. + + Parameters + ---------- + bus_params: BusParametersInterface + Parameters used to generate common bus master methods layouts. + + Attributes + ---------- + request_read_layout: Layout + A layout for the `request_read` method of a common bus master. + + request_write_layout: Layout + A layout for the `request_write` method of a common bus master. + + read_response_layout: Layout + A layout for the `get_read_response` method of a common bus master. + + write_response_layout: Layout + A layout for the `get_write_response` method of a common bus master. 
+ """ + + def __init__(self, bus_params: BusParametersInterface): + self.bus_params = bus_params + + self.request_read_layout = make_layout( + ("addr", self.bus_params.addr_width), + ("sel", self.bus_params.data_width // self.bus_params.granularity), + ) + + self.request_write_layout = make_layout( + ("addr", self.bus_params.addr_width), + ("data", self.bus_params.data_width), + ("sel", self.bus_params.data_width // self.bus_params.granularity), + ) + + self.read_response_layout = make_layout(("data", self.bus_params.data_width), ("err", 1)) + + self.write_response_layout = make_layout(("err", 1)) + + +class WishboneMasterAdapter(Elaboratable, BusMasterInterface): + """ + An adapter for Wishbone master. + + The adapter module is for use in places where BusMasterInterface is expected. + + Parameters + ---------- + bus: WishboneMaster + Specific Wishbone master module which is to be adapted. + + Attributes + ---------- + params: BusParametersInterface + Parameters of the bus. + + method_layouts: CommonBusMasterMethodLayout + Layouts of common bus master methods. + + request_read: Method + Transactional method for initiating a read request. + It is ready if the `request` method of the underlying Wishbone master is ready. + Input layout is `request_read_layout`. + + request_write: Method + Transactional method for initiating a write request. + It is ready if the `request` method of the underlying Wishbone master is ready. + Input layout is `request_write_layout`. + + get_read_response: Method + Transactional method for reading a response of a read action. + It is ready if the `result` method of the underlying Wishbone master is ready. + Output layout is `read_response_layout`. + + get_write_response: Method + Transactional method for reading a response of a write action. + It is ready if the `result` method of the underlying Wishbone master is ready. + Output layout is `write_response_layout`. 
+ """ + + def __init__(self, bus: WishboneMaster): + self.bus = bus + self.params = self.bus.wb_params + + self.method_layouts = CommonBusMasterMethodLayout(self.params) + + self.request_read = Method(i=self.method_layouts.request_read_layout) + self.request_write = Method(i=self.method_layouts.request_write_layout) + self.get_read_response = Method(o=self.method_layouts.read_response_layout) + self.get_write_response = Method(o=self.method_layouts.write_response_layout) + + def elaborate(self, platform): + m = TModule() + + bus_serializer = Serializer( + port_count=2, serialized_req_method=self.bus.request, serialized_resp_method=self.bus.result + ) + m.submodules.bus_serializer = bus_serializer + + @def_method(m, self.request_read) + def _(arg): + we = C(0, unsigned(1)) + data = C(0, unsigned(self.params.data_width)) + bus_serializer.serialize_in[0](m, addr=arg.addr, data=data, we=we, sel=arg.sel) + + @def_method(m, self.request_write) + def _(arg): + we = C(1, unsigned(1)) + bus_serializer.serialize_in[1](m, addr=arg.addr, data=arg.data, we=we, sel=arg.sel) + + @def_method(m, self.get_read_response) + def _(): + res = bus_serializer.serialize_out[0](m) + return {"data": res.data, "err": res.err} + + @def_method(m, self.get_write_response) + def _(): + res = bus_serializer.serialize_out[1](m) + return {"err": res.err} + + return m + + +class AXILiteMasterAdapter(Elaboratable, BusMasterInterface): + """ + An adapter for AXI Lite master. + + The adapter module is for use in places where BusMasterInterface is expected. + + Parameters + ---------- + bus: AXILiteMaster + Specific AXI Lite master module which is to be adapted. + + Attributes + ---------- + params: BusParametersInterface + Parameters of the bus. + + method_layouts: CommonBusMasterMethodLayout + Layouts of common bus master methods. + + request_read: Method + Transactional method for initiating a read request. + It is ready if the `ra_request` method of the underlying AXI Lite master is ready. 
+ Input layout is `request_read_layout`. + + request_write: Method + Transactional method for initiating a write request. + It is ready if the 'wa_request' and 'wd_request' methods of the underlying AXI Lite master are ready. + Input layout is `request_write_layout`. + + get_read_response: Method + Transactional method for reading a response of a read action. + It is ready if the `rd_response` method of the underlying AXI Lite master is ready. + Output layout is `read_response_layout`. + + get_write_response: Method + Transactional method for reading a response of a write action. + It is ready if the `wr_response` method of the underlying AXI Lite master is ready. + Output layout is `write_response_layout`. + """ + + def __init__(self, bus: AXILiteMaster): + self.bus = bus + self.params = self.bus.axil_params + + self.method_layouts = CommonBusMasterMethodLayout(self.params) + + self.request_read = Method(i=self.method_layouts.request_read_layout) + self.request_write = Method(i=self.method_layouts.request_write_layout) + self.get_read_response = Method(o=self.method_layouts.read_response_layout) + self.get_write_response = Method(o=self.method_layouts.write_response_layout) + + def elaborate(self, platform): + m = TModule() + + @def_method(m, self.request_read) + def _(arg): + prot = C(0, unsigned(3)) + self.bus.ra_request(m, addr=arg.addr, prot=prot) + + @def_method(m, self.request_write) + def _(arg): + prot = C(0, unsigned(3)) + self.bus.wa_request(m, addr=arg.addr, prot=prot) + self.bus.wd_request(m, data=arg.data, strb=arg.sel) + + @def_method(m, self.get_read_response) + def _(): + res = self.bus.rd_response(m) + err = res.resp != 0 + return {"data": res.data, "err": err} + + @def_method(m, self.get_write_response) + def _(): + res = self.bus.wr_response(m) + err = res.resp != 0 + return {"err": err} + + return m diff --git a/coreblocks/peripherals/wishbone.py b/coreblocks/peripherals/wishbone.py index bbf81eb18..f2dcca253 100644 --- 
a/coreblocks/peripherals/wishbone.py +++ b/coreblocks/peripherals/wishbone.py @@ -1,14 +1,17 @@ from amaranth import * -from amaranth.hdl.rec import DIR_FANIN, DIR_FANOUT +from amaranth.lib.wiring import PureInterface, Signature, In, Out, Component from functools import reduce -from typing import List +from typing import Protocol, cast import operator from transactron import Method, def_method, TModule from transactron.core import Transaction from transactron.lib import AdapterTrans, BasicFifo from transactron.utils import OneHotSwitchDynamic, assign, RoundRobin +from transactron.utils._typing import AbstractInterface, AbstractSignature from transactron.lib.connectors import Forwarder +from transactron.utils.transactron_helpers import make_layout +from transactron.lib import logging class WishboneParameters: @@ -24,134 +27,150 @@ class WishboneParameters: The smallest unit of data transfer that a port is capable of transferring. Defaults to 8 bits """ - def __init__(self, *, data_width=64, addr_width=64, granularity=8): + def __init__(self, *, data_width: int = 64, addr_width: int = 64, granularity: int = 8): self.data_width = data_width self.addr_width = addr_width self.granularity = granularity -class WishboneLayout: - """Wishbone bus Layout generator. +class WishboneSignature(Signature): + def __init__(self, wb_params: WishboneParameters): + super().__init__( + { + "dat_r": In(wb_params.data_width), + "dat_w": Out(wb_params.data_width), + "rst": Out(1), + "ack": In(1), + "adr": Out(wb_params.addr_width), + "cyc": Out(1), + "stall": In(1), + "err": In(1), + "lock": Out(1), + "rty": In(1), + "sel": Out(wb_params.data_width // wb_params.granularity), + "stb": Out(1), + "we": Out(1), + } + ) + + def create(self, *, path: tuple[str | int, ...] 
= (), src_loc_at: int = 0): + """Create a WishboneInterface.""" # workaround for Sphinx problem with Amaranth docstring + return cast(WishboneInterface, PureInterface(self, path=path, src_loc_at=src_loc_at + 1)) + + +class WishboneInterface(AbstractInterface[AbstractSignature], Protocol): + dat_r: Signal + dat_w: Signal + rst: Signal + ack: Signal + adr: Signal + cyc: Signal + stall: Signal + err: Signal + lock: Signal + rty: Signal + sel: Signal + stb: Signal + we: Signal + + +class WishboneMasterMethodLayout: + """Wishbone master layouts for methods Parameters ---------- wb_params: WishboneParameters - Parameters used to generate Wishbone layout - master: Boolean - Whether the layout should be generated for the master side - (otherwise it's generated for the slave side) + Patameters used to generate Wishbone master layouts Attributes ---------- - wb_layout: Record - Record of a Wishbone bus. - """ - - def __init__(self, wb_params: WishboneParameters, master=True): - self.wb_layout = [ - ("dat_r", wb_params.data_width, DIR_FANIN if master else DIR_FANOUT), - ("dat_w", wb_params.data_width, DIR_FANOUT if master else DIR_FANIN), - ("rst", 1, DIR_FANOUT if master else DIR_FANIN), - ("ack", 1, DIR_FANIN if master else DIR_FANOUT), - ("adr", wb_params.addr_width, DIR_FANOUT if master else DIR_FANIN), - ("cyc", 1, DIR_FANOUT if master else DIR_FANIN), - ("stall", 1, DIR_FANIN if master else DIR_FANOUT), - ("err", 1, DIR_FANIN if master else DIR_FANOUT), - ("lock", 1, DIR_FANOUT if master else DIR_FANIN), - ("rty", 1, DIR_FANIN if master else DIR_FANOUT), - ("sel", wb_params.data_width // wb_params.granularity, DIR_FANOUT if master else DIR_FANIN), - ("stb", 1, DIR_FANOUT if master else DIR_FANIN), - ("we", 1, DIR_FANOUT if master else DIR_FANIN), - ] - - -class WishboneBus(Record): - """Wishbone bus. + request_layout: Layout + Layout for request method of WishboneMaster. - Parameters - ---------- - wb_params: WishboneParameters - Parameters for bus generation. 
+ result_layout: Layout + Layout for result method of WishboneMaster. """ - def __init__(self, wb_params: WishboneParameters, **kwargs): - super().__init__(WishboneLayout(wb_params).wb_layout, **kwargs) + def __init__(self, wb_params: WishboneParameters): + self.request_layout = make_layout( + ("addr", wb_params.addr_width), + ("data", wb_params.data_width), + ("we", 1), + ("sel", wb_params.data_width // wb_params.granularity), + ) + + self.result_layout = make_layout(("data", wb_params.data_width), ("err", 1)) -class WishboneMaster(Elaboratable): +class WishboneMaster(Component): """Wishbone bus master interface. Parameters ---------- wb_params: WishboneParameters Parameters for bus generation. + name: str, optional + Name of this bus. Used for logging. Attributes ---------- - wbMaster: Record (like WishboneLayout) + wb_master: WishboneInterface Wishbone bus output. request: Method Transactional method to start a new Wishbone request. Ready when no request is being executed and previous result is read. - Takes `requestLayout` as argument. + Takes `request_layout` as argument. result: Method Transactional method to read previous request result. Becomes ready after Wishbone request is completed. - Returns state of request (error or success) and data (in case of read request) as `resultLayout`. + Returns state of request (error or success) and data (in case of read request) as `result_layout`. 
""" - def __init__(self, wb_params: WishboneParameters): + wb_master: WishboneInterface + + def __init__(self, wb_params: WishboneParameters, name: str = ""): + super().__init__({"wb_master": Out(WishboneSignature(wb_params))}) + self.name = name self.wb_params = wb_params - self.wb_layout = WishboneLayout(wb_params).wb_layout - self.wbMaster = Record(self.wb_layout) - self.generate_layouts(wb_params) - self.request = Method(i=self.requestLayout) - self.result = Method(o=self.resultLayout) + self.method_layouts = WishboneMasterMethodLayout(wb_params) - self.result_data = Record(self.resultLayout) + self.request = Method(i=self.method_layouts.request_layout) + self.result = Method(o=self.method_layouts.result_layout) # latched input signals - self.txn_req = Record(self.requestLayout) - - self.ports = list(self.wbMaster.fields.values()) + self.txn_req = Signal(self.method_layouts.request_layout) - def generate_layouts(self, wb_params: WishboneParameters): - # generate method layouts locally - self.requestLayout = [ - ("addr", wb_params.addr_width, DIR_FANIN), - ("data", wb_params.data_width, DIR_FANIN), - ("we", 1, DIR_FANIN), - ("sel", wb_params.data_width // wb_params.granularity, DIR_FANIN), - ] - - self.resultLayout = [("data", wb_params.data_width), ("err", 1)] + logger_name = "bus.wishbone" + if name != "": + logger_name += f".{name}" + self.log = logging.HardwareLogger(logger_name) def elaborate(self, platform): m = TModule() - m.submodules.result = result = Forwarder(self.resultLayout) + m.submodules.result = result = Forwarder(self.method_layouts.result_layout) request_ready = Signal() def FSMWBCycStart(request): # noqa: N802 # internal FSM function that starts Wishbone cycle - m.d.sync += self.wbMaster.cyc.eq(1) - m.d.sync += self.wbMaster.stb.eq(1) - m.d.sync += self.wbMaster.adr.eq(request.addr) - m.d.sync += self.wbMaster.dat_w.eq(Mux(request.we, request.data, 0)) - m.d.sync += self.wbMaster.we.eq(request.we) - m.d.sync += 
self.wbMaster.sel.eq(request.sel) + m.d.sync += self.wb_master.cyc.eq(1) + m.d.sync += self.wb_master.stb.eq(1) + m.d.sync += self.wb_master.adr.eq(request.addr) + m.d.sync += self.wb_master.dat_w.eq(Mux(request.we, request.data, 0)) + m.d.sync += self.wb_master.we.eq(request.we) + m.d.sync += self.wb_master.sel.eq(request.sel) with m.FSM("Reset"): with m.State("Reset"): - m.d.sync += self.wbMaster.rst.eq(1) + m.d.sync += self.wb_master.rst.eq(1) m.next = "Idle" with m.State("Idle"): # default values for important signals - m.d.sync += self.wbMaster.rst.eq(0) - m.d.sync += self.wbMaster.stb.eq(0) - m.d.sync += self.wbMaster.cyc.eq(0) + m.d.sync += self.wb_master.rst.eq(0) + m.d.sync += self.wb_master.stb.eq(0) + m.d.sync += self.wb_master.cyc.eq(0) m.d.comb += request_ready.eq(1) with m.If(self.request.run): m.next = "WBWaitACK" @@ -161,25 +180,35 @@ def FSMWBCycStart(request): # noqa: N802 m.next = "WBWaitACK" with m.State("WBWaitACK"): - with m.If(self.wbMaster.ack | self.wbMaster.err): + with m.If(self.wb_master.ack | self.wb_master.err): m.d.comb += request_ready.eq(result.read.run) with Transaction().body(m): # will be always ready, as we checked that in Idle - result.write(m, data=Mux(self.txn_req.we, 0, self.wbMaster.dat_r), err=self.wbMaster.err) + result.write(m, data=Mux(self.txn_req.we, 0, self.wb_master.dat_r), err=self.wb_master.err) with m.If(self.request.run): m.next = "WBWaitACK" with m.Else(): - m.d.sync += self.wbMaster.cyc.eq(0) - m.d.sync += self.wbMaster.stb.eq(0) + m.d.sync += self.wb_master.cyc.eq(0) + m.d.sync += self.wb_master.stb.eq(0) m.next = "Idle" - with m.If(self.wbMaster.rty): - m.d.sync += self.wbMaster.cyc.eq(1) - m.d.sync += self.wbMaster.stb.eq(0) + with m.If(self.wb_master.rty): + m.d.sync += self.wb_master.cyc.eq(1) + m.d.sync += self.wb_master.stb.eq(0) m.next = "WBCycStart" @def_method(m, self.result) def _(): - return result.read(m) + ret = result.read(m) + + self.log.debug( + m, + True, + "response data=0x{:x} err={}", + 
ret.data, + ret.err, + ) + + return ret @def_method(m, self.request, ready=request_ready & result.write.ready) def _(arg): @@ -187,13 +216,23 @@ def _(arg): # do WBCycStart state in the same clock cycle FSMWBCycStart(arg) + self.log.debug( + m, + True, + "request addr=0x{:x} data=0x{:x} sel=0x{:x} write={}", + arg.addr, + arg.data, + arg.sel, + arg.we, + ) + result.write.schedule_before(self.request) result.read.schedule_before(self.request) return m -class PipelinedWishboneMaster(Elaboratable): +class PipelinedWishboneMaster(Component): """Pipelined Wishbone bus master interface. Parameters @@ -205,21 +244,24 @@ class PipelinedWishboneMaster(Elaboratable): Attributes ---------- - wb: Record (like WishboneLayout) + wb: WishboneInterface Wishbone bus output. request: Method Transactional method to start a new Wishbone request. Ready if new request can be immediately sent. - Takes `requestLayout` as argument. + Takes `request_layout` as argument. result: Method Transactional method to read results from completed requests sequentially. Ready if buffered results are available. - Returns state of request (error or success) and data (in case of read request) as `resultLayout`. + Returns state of request (error or success) and data (in case of read request) as `result_layout`. 
requests_finished: Signal, out True, if there are no requests waiting for response """ + wb: WishboneInterface + def __init__(self, wb_params: WishboneParameters, *, max_req: int = 8): + super().__init__({"wb": Out(WishboneSignature(wb_params))}) self.wb_params = wb_params self.max_req = max_req @@ -229,9 +271,6 @@ def __init__(self, wb_params: WishboneParameters, *, max_req: int = 8): self.requests_finished = Signal() - self.wb_layout = WishboneLayout(wb_params).wb_layout - self.wb = Record(self.wb_layout) - def generate_method_layouts(self, wb_params: WishboneParameters): # generate method layouts locally self.request_in_layout = [ @@ -292,17 +331,17 @@ def _(arg) -> None: return m -class WishboneMuxer(Elaboratable): +class WishboneMuxer(Component): """Wishbone Muxer. Connects one master to multiple slaves. Parameters ---------- - master_wb: Record (like WishboneLayout) - Record of master inteface. - slaves: List[Record] - List of connected slaves' Wishbone Records (like WishboneLayout). + wb_params: WishboneParameters + Parameters for bus generation. + num_slaves: int + Number of slave devices to multiplex. ssel_tga: Signal Signal that selects the slave to connect. Signal width is the number of slaves and each bit coresponds to a slave. This signal is a Wishbone TGA (address tag), so it needs to be valid every time Wishbone STB @@ -311,15 +350,29 @@ class WishboneMuxer(Elaboratable): different `ssel_tga` value, all pending request have to be finished (and `stall` cleared) and there have to be one cycle delay from previouse request (to deassert the STB signal). Holding new requests should be implemented in block that controlls `ssel_tga` signal, before the Wishbone Master. + + Attributes + ---------- + master_wb: WishboneInterface + Master inteface. + slaves: list of WishboneInterface + List of connected slaves' Wishbone interfaces. 
""" - def __init__(self, master_wb: Record, slaves: List[Record], ssel_tga: Signal): - self.master_wb = master_wb - self.slaves = slaves + master_wb: WishboneInterface + slaves: list[WishboneInterface] + + def __init__(self, wb_params: WishboneParameters, num_slaves: int, ssel_tga: Signal): + super().__init__( + { + "master_wb": Out(WishboneSignature(wb_params)), + "slaves": In(WishboneSignature(wb_params)).array(num_slaves), + } + ) self.sselTGA = ssel_tga select_bits = ssel_tga.shape().width - assert select_bits == len(slaves) + assert select_bits == num_slaves self.txn_sel = Signal(select_bits) self.txn_sel_r = Signal(select_bits) @@ -339,10 +392,9 @@ def elaborate(self, platform): for i in range(len(self.slaves)): # connect all M->S signals except stb - m.d.comb += self.master_wb.connect( - self.slaves[i], - include=["dat_w", "rst", "cyc", "lock", "adr", "we", "sel"], - ) + # workaround for the lack of selective connecting in wiring + for n in ["dat_w", "cyc", "lock", "adr", "we", "sel", "stb"]: + m.d.comb += getattr(self.slaves[i], n).eq(getattr(self.master_wb, n)) # use stb as select m.d.comb += self.slaves[i].stb.eq(self.txn_sel[i] & self.master_wb.stb) @@ -352,12 +404,14 @@ def elaborate(self, platform): m.d.comb += self.master_wb.rty.eq(reduce(operator.or_, [self.slaves[i].rty for i in range(len(self.slaves))])) for i in OneHotSwitchDynamic(m, self.txn_sel): # mux S->M data - m.d.comb += self.master_wb.connect(self.slaves[i], include=["dat_r", "stall"]) + # workaround for the lack of selective connecting in wiring + for n in ["dat_r", "stall"]: + m.d.comb += getattr(self.master_wb, n).eq(getattr(self.slaves[i], n)) return m # connects multiple masters to one slave -class WishboneArbiter(Elaboratable): +class WishboneArbiter(Component): """Wishbone Arbiter. Connects multiple masters to one slave. @@ -365,20 +419,34 @@ class WishboneArbiter(Elaboratable): Parameters ---------- - slave_wb: Record (like WishboneLayout) - Record of slave inteface. 
- masters: List[Record] - List of master interface Records. + wb_params: WishboneParameters + Parameters for bus generation. + num_slaves: int + Number of master devices. + + Attributes + ---------- + slave_wb: WishboneInterface + Slave inteface. + masters: list of WishboneInterface + List of master interfaces. """ - def __init__(self, slave_wb: Record, masters: List[Record]): - self.slave_wb = slave_wb - self.masters = masters + slave_wb: WishboneInterface + masters: list[WishboneInterface] + + def __init__(self, wb_params: WishboneParameters, num_masters: int): + super().__init__( + { + "slave_wb": In(WishboneSignature(wb_params)), + "masters": Out(WishboneSignature(wb_params)).array(num_masters), + } + ) self.prev_cyc = Signal() # Amaranth round robin singals self.arb_enable = Signal() - self.req_signal = Signal(len(masters)) + self.req_signal = Signal(num_masters) def elaborate(self, platform): m = TModule() @@ -402,7 +470,9 @@ def elaborate(self, platform): m.d.comb += self.masters[i].err.eq((m.submodules.rr.grant == i) & self.slave_wb.err) m.d.comb += self.masters[i].rty.eq((m.submodules.rr.grant == i) & self.slave_wb.rty) # remaining S->M signals are shared, master will only accept response if bus termination signal is present - m.d.comb += self.masters[i].connect(self.slave_wb, include=["dat_r", "stall"]) + # workaround for the lack of selective connecting in wiring + for n in ["dat_r", "stall"]: + m.d.comb += getattr(self.masters[i], n).eq(getattr(self.slave_wb, n)) # combine reset singnal m.d.comb += self.slave_wb.rst.eq(reduce(operator.or_, [self.masters[i].rst for i in range(len(self.masters))])) @@ -411,10 +481,9 @@ def elaborate(self, platform): with m.Switch(m.submodules.rr.grant): for i in range(len(self.masters)): with m.Case(i): - m.d.comb += self.masters[i].connect( - self.slave_wb, - include=["dat_w", "cyc", "lock", "adr", "we", "sel", "stb"], - ) + # workaround for the lack of selective connecting in wiring + for n in ["dat_w", "cyc", "lock", 
"adr", "we", "sel", "stb"]: + m.d.comb += getattr(self.slave_wb, n).eq(getattr(self.masters[i], n)) # Disable slave when round robin is not valid at start of new request # This prevents chaning grant and muxes during Wishbone cycle @@ -424,7 +493,7 @@ def elaborate(self, platform): return m -class WishboneMemorySlave(Elaboratable): +class WishboneMemorySlave(Component): """Wishbone slave with memory Wishbone slave interface with addressable memory underneath. @@ -439,11 +508,14 @@ class WishboneMemorySlave(Elaboratable): Attributes ---------- - bus: Record (like WishboneLayout) - Wishbone bus record. + bus: WishboneInterface + Wishbone bus interface. """ + bus: WishboneInterface + def __init__(self, wb_params: WishboneParameters, **kwargs): + super().__init__({"bus": In(WishboneSignature(wb_params))}) if "width" not in kwargs: kwargs["width"] = wb_params.data_width if kwargs["width"] not in (8, 16, 32, 64): @@ -455,7 +527,6 @@ def __init__(self, wb_params: WishboneParameters, **kwargs): raise RuntimeError("Granularity has to be one of: 8, 16, 32, 64") self.mem = Memory(**kwargs) - self.bus = Record(WishboneLayout(wb_params, master=False).wb_layout) def elaborate(self, platform): m = TModule() diff --git a/coreblocks/scheduler/scheduler.py b/coreblocks/scheduler/scheduler.py index 31479c115..6e7e152bd 100644 --- a/coreblocks/scheduler/scheduler.py +++ b/coreblocks/scheduler/scheduler.py @@ -47,7 +47,7 @@ def elaborate(self, platform): m = TModule() free_reg = Signal(self.gen_params.phys_regs_bits) - data_out = Record(self.output_layout) + data_out = Signal(self.output_layout) with Transaction().body(m): instr = self.get_instr(m) @@ -95,7 +95,7 @@ def __init__(self, *, get_instr: Method, push_instr: Method, rename: Method, gen def elaborate(self, platform): m = TModule() - data_out = Record(self.output_layout) + data_out = Signal(self.output_layout) with Transaction().body(m): instr = self.get_instr(m) @@ -152,7 +152,7 @@ def __init__(self, *, get_instr: Method, 
push_instr: Method, rob_put: Method, ge def elaborate(self, platform): m = TModule() - data_out = Record(self.output_layout) + data_out = Signal(self.output_layout) with Transaction().body(m): instr = self.get_instr(m) @@ -239,7 +239,7 @@ def elaborate(self, platform): instr = self.get_instr(m) forwarder.write(m, instr) - data_out = Record(self.output_layout) + data_out = Signal(self.output_layout) for i, (alloc, optypes) in enumerate(self.rs_select): # checks if RS can perform this kind of operation @@ -332,7 +332,7 @@ def elaborate(self, platform): for i, rs_insert in enumerate(self.rs_insert): # connect only matching fields - arg = Record.like(rs_insert.data_in) + arg = Signal.like(rs_insert.data_in) m.d.comb += assign(arg, data, fields=AssignType.COMMON) # this assignment truncates signal width from max rs_entry_bits to target RS specific width m.d.comb += arg.rs_entry_id.eq(instr.rs_entry_id) diff --git a/coreblocks/scheduler/wakeup_select.py b/coreblocks/scheduler/wakeup_select.py index fbe6d40a4..724d6ffe7 100644 --- a/coreblocks/scheduler/wakeup_select.py +++ b/coreblocks/scheduler/wakeup_select.py @@ -41,14 +41,14 @@ def elaborate(self, platform): with Transaction().body(m): ready = self.get_ready(m) - ready_width = len(ready) + ready_width = ready.shape().size last = Signal(range(ready_width)) for i in range(ready_width): - with m.If(ready[i]): + with m.If(ready.ready_list[i]): m.d.comb += last.eq(i) row = self.take_row(m, last) - issue_rec = Record(self.gen_params.get(FuncUnitLayouts).issue) + issue_rec = Signal(self.gen_params.get(FuncUnitLayouts).issue) m.d.comb += assign(issue_rec, row, fields=AssignType.ALL) self.issue(m, issue_rec) diff --git a/coreblocks/stages/retirement.py b/coreblocks/stages/retirement.py index ca03326de..1225e7e57 100644 --- a/coreblocks/stages/retirement.py +++ b/coreblocks/stages/retirement.py @@ -4,6 +4,7 @@ from transactron.core import Method, Transaction, TModule, def_method from transactron.lib.simultaneous import 
condition from transactron.utils.dependencies import DependencyManager +from transactron.lib.metrics import * from coreblocks.params.genparams import GenParams from coreblocks.params.isa import ExceptionCause @@ -27,7 +28,6 @@ def __init__( exception_cause_clear: Method, frat_rename: Method, fetch_continue: Method, - fetch_stall: Method, instr_decrement: Method, trap_entry: Method, ): @@ -43,11 +43,11 @@ def __init__( self.exception_cause_clear = exception_cause_clear self.rename = frat_rename self.fetch_continue = fetch_continue - self.fetch_stall = fetch_stall self.instr_decrement = instr_decrement self.trap_entry = trap_entry self.instret_csr = DoubleCounterCSR(gen_params, CSRAddress.INSTRET, CSRAddress.INSTRETH) + self.perf_instr_ret = HwCounter("backend.retirement.retired_instr", "Number of retired instructions") self.dependency_manager = gen_params.get(DependencyManager) self.core_state = Method(o=self.gen_params.get(RetirementLayouts).core_state, nonexclusive=True) @@ -56,6 +56,8 @@ def __init__( def elaborate(self, platform): m = TModule() + m.submodules += [self.perf_instr_ret] + m_csr = self.dependency_manager.get_dependency(GenericCSRRegistersKey()).m_mode m.submodules.instret_csr = self.instret_csr @@ -83,6 +85,7 @@ def retire_instr(rob_entry): free_phys_reg(rat_out.old_rp_dst) self.instret_csr.increment(m) + self.perf_instr_ret.incr(m) def flush_instr(rob_entry): # get original rp_dst mapped to instruction rl_dst in R-RAT @@ -103,6 +106,8 @@ def flush_instr(rob_entry): ~rob_entry.exception | (rob_entry.exception & ecr_entry.valid & (ecr_entry.rob_id == rob_entry.rob_id)) ) + continue_pc_override = Signal() + continue_pc = Signal(self.gen_params.isa.xlen) core_flushing = Signal() with m.FSM("NORMAL") as fsm: @@ -116,12 +121,12 @@ def flush_instr(rob_entry): commit = Signal() with m.If(rob_entry.exception): - self.fetch_stall(m) - cause_register = self.exception_cause_get(m) cause_entry = Signal(self.gen_params.isa.xlen) + arch_trap = Signal(reset=1) + 
with m.If(cause_register.cause == ExceptionCause._COREBLOCKS_ASYNC_INTERRUPT): # Async interrupts are inserted only by JumpBranchUnit and conditionally by MRET and CSR # The PC field is set to address of instruction to resume from interrupt (e.g. for jumps @@ -132,6 +137,14 @@ def flush_instr(rob_entry): # TODO: set correct interrupt id from InterruptController # Set MSB - the Interrupt bit m.d.av_comb += cause_entry.eq(1 << (self.gen_params.isa.xlen - 1)) + with m.Elif(cause_register.cause == ExceptionCause._COREBLOCKS_MISPREDICTION): + # Branch misprediction - commit jump, flush core and continue from correct pc. + m.d.av_comb += commit.eq(1) + # Do not modify trap related CSRs + m.d.av_comb += arch_trap.eq(0) + + m.d.sync += continue_pc_override.eq(1) + m.d.sync += continue_pc.eq(cause_register.pc) with m.Else(): # RISC-V synchronous exceptions - don't retire instruction that caused exception, # and later resume from it. @@ -140,10 +153,13 @@ def flush_instr(rob_entry): m.d.av_comb += cause_entry.eq(cause_register.cause) - m_csr.mcause.write(m, cause_entry) - m_csr.mepc.write(m, cause_register.pc) - self.trap_entry(m) + with m.If(arch_trap): + # Register RISC-V architectural trap in CSRs + m_csr.mcause.write(m, cause_entry) + m_csr.mepc.write(m, cause_register.pc) + self.trap_entry(m) + # Fetch is already stalled by ExceptionCauseRegister with m.If(core_empty): m.next = "TRAP_RESUME" with m.Else(): @@ -184,9 +200,14 @@ def flush_instr(rob_entry): with Transaction().body(m): # Resume core operation + handler_pc = Signal(self.gen_params.isa.xlen) # mtvec without mode is [mxlen-1:2], mode is two last bits. 
Only direct mode is supported - resume_pc = m_csr.mtvec.read(m) & ~(0b11) - self.fetch_continue(m, from_pc=0, next_pc=resume_pc, resume_from_exception=1) + m.d.av_comb += handler_pc.eq(m_csr.mtvec.read(m).data & ~(0b11)) + + resume_pc = Mux(continue_pc_override, continue_pc, handler_pc) + m.d.sync += continue_pc_override.eq(0) + + self.fetch_continue(m, pc=resume_pc, resume_from_exception=1) # Release pending trap state - allow accepting new reports self.exception_cause_clear(m) diff --git a/coreblocks/structs_common/csr.py b/coreblocks/structs_common/csr.py index 9277e1c85..a01a028fa 100644 --- a/coreblocks/structs_common/csr.py +++ b/coreblocks/structs_common/csr.py @@ -1,4 +1,5 @@ from amaranth import * +from amaranth.lib.data import StructLayout from amaranth.lib.enum import IntEnum from dataclasses import dataclass @@ -11,12 +12,13 @@ from coreblocks.params.isa import Funct3, ExceptionCause from coreblocks.params.keys import ( AsyncInterruptInsertSignalKey, - BranchResolvedKey, + FetchResumeKey, ExceptionReportKey, InstructionPrecommitKey, ) from coreblocks.params.optypes import OpType from coreblocks.utils.protocols import FuncBlock +from transactron.utils.transactron_helpers import from_method_layout class PrivilegeLevel(IntEnum, shape=2): @@ -112,7 +114,7 @@ def __init__(self, csr_number: int, gen_params: GenParams, *, ro_bits: int = 0): self._fu_write = Method(i=csr_layouts._fu_write) self.value = Signal(gen_params.isa.xlen) - self.side_effects = Record([("read", 1), ("write", 1)]) + self.side_effects = Signal(StructLayout({"read": 1, "write": 1})) # append to global CSR list dm = gen_params.get(DependencyManager) @@ -121,9 +123,9 @@ def __init__(self, csr_number: int, gen_params: GenParams, *, ro_bits: int = 0): def elaborate(self, platform): m = TModule() - internal_method_layout = [("data", self.gen_params.isa.xlen), ("active", 1)] - write_internal = Record(internal_method_layout) - fu_write_internal = Record(internal_method_layout) + 
internal_method_layout = from_method_layout([("data", self.gen_params.isa.xlen), ("active", 1)]) + write_internal = Signal(internal_method_layout) + fu_write_internal = Signal(internal_method_layout) m.d.sync += self.side_effects.eq(0) @@ -194,7 +196,7 @@ def __init__(self, gen_params: GenParams): self.gen_params = gen_params self.dependency_manager = gen_params.get(DependencyManager) - self.fetch_continue = Method(o=gen_params.get(FetchLayouts).branch_verify) + self.fetch_resume = Method(o=gen_params.get(FetchLayouts).resume) # Standard RS interface self.csr_layouts = gen_params.get(CSRLayouts) @@ -228,7 +230,7 @@ def elaborate(self, platform): current_result = Signal(self.gen_params.isa.xlen) - instr = Record(self.csr_layouts.rs.data_layout + [("valid", 1)]) + instr = Signal(StructLayout(self.csr_layouts.rs.data_layout.members | {"valid": 1})) m.d.comb += ready_to_process.eq(precommitting & instr.valid & (instr.rp_s1 == 0)) @@ -358,12 +360,11 @@ def _(): "exception": exception | interrupt, } - @def_method(m, self.fetch_continue, accepted) + @def_method(m, self.fetch_resume, accepted) def _(): # CSR instructions are never compressed, PC+4 is always next instruction return { - "from_pc": instr.pc, - "next_pc": instr.pc + self.gen_params.isa.ilen_bytes, + "pc": instr.pc + self.gen_params.isa.ilen_bytes, "resume_from_exception": False, } @@ -381,7 +382,7 @@ class CSRBlockComponent(BlockComponentParams): def get_module(self, gen_params: GenParams) -> FuncBlock: connections = gen_params.get(DependencyManager) unit = CSRUnit(gen_params) - connections.add_dependency(BranchResolvedKey(), unit.fetch_continue) + connections.add_dependency(FetchResumeKey(), unit.fetch_resume) connections.add_dependency(InstructionPrecommitKey(), unit.precommit) return unit diff --git a/coreblocks/structs_common/exception.py b/coreblocks/structs_common/exception.py index 3b44f7ca8..4385b12f6 100644 --- a/coreblocks/structs_common/exception.py +++ b/coreblocks/structs_common/exception.py @@ 
-50,7 +50,7 @@ class ExceptionCauseRegister(Elaboratable): If `exception` bit is set in the ROB, `Retirement` stage fetches exception details from this module. """ - def __init__(self, gen_params: GenParams, rob_get_indices: Method): + def __init__(self, gen_params: GenParams, rob_get_indices: Method, fetch_stall_exception: Method): self.gen_params = gen_params self.cause = Signal(ExceptionCause) @@ -71,6 +71,7 @@ def __init__(self, gen_params: GenParams, rob_get_indices: Method): self.clear = Method() self.rob_get_indices = rob_get_indices + self.fetch_stall_exception = fetch_stall_exception def elaborate(self, platform): m = TModule() @@ -103,6 +104,9 @@ def _(cause, rob_id, pc): m.d.sync += self.valid.eq(1) + # In case of any reported exception, core will need to be flushed. Fetch can be stalled immediately + self.fetch_stall_exception(m) + @def_method(m, self.get) def _(): return {"rob_id": self.rob_id, "cause": self.cause, "pc": self.pc, "valid": self.valid} diff --git a/coreblocks/structs_common/rf.py b/coreblocks/structs_common/rf.py index 461fab8ed..899e99593 100644 --- a/coreblocks/structs_common/rf.py +++ b/coreblocks/structs_common/rf.py @@ -1,6 +1,7 @@ from amaranth import * from transactron import Method, def_method, TModule from coreblocks.params import RFLayouts, GenParams +from transactron.utils.transactron_helpers import make_layout __all__ = ["RegisterFile"] @@ -9,9 +10,9 @@ class RegisterFile(Elaboratable): def __init__(self, *, gen_params: GenParams): self.gen_params = gen_params layouts = gen_params.get(RFLayouts) - self.internal_layout = [("reg_val", gen_params.isa.xlen), ("valid", 1)] + self.internal_layout = make_layout(("reg_val", gen_params.isa.xlen), ("valid", 1)) self.read_layout = layouts.rf_read_out - self.entries = Array(Record(self.internal_layout) for _ in range(2**gen_params.phys_regs_bits)) + self.entries = Array(Signal(self.internal_layout) for _ in range(2**gen_params.phys_regs_bits)) self.read1 = Method(i=layouts.rf_read_in, 
o=layouts.rf_read_out) self.read2 = Method(i=layouts.rf_read_in, o=layouts.rf_read_out) diff --git a/coreblocks/structs_common/rob.py b/coreblocks/structs_common/rob.py index 0f01a6abf..b2b74f9ae 100644 --- a/coreblocks/structs_common/rob.py +++ b/coreblocks/structs_common/rob.py @@ -1,5 +1,6 @@ from amaranth import * from transactron import Method, def_method, TModule +from transactron.lib.metrics import * from ..params import GenParams, ROBLayouts __all__ = ["ReorderBuffer"] @@ -13,12 +14,21 @@ def __init__(self, gen_params: GenParams) -> None: self.mark_done = Method(i=layouts.mark_done_layout) self.peek = Method(o=layouts.peek_layout, nonexclusive=True) self.retire = Method() - self.data = Array(Record(layouts.internal_layout) for _ in range(2**gen_params.rob_entries_bits)) + self.data = Array(Signal(layouts.internal_layout) for _ in range(2**gen_params.rob_entries_bits)) self.get_indices = Method(o=layouts.get_indices, nonexclusive=True) + self.perf_rob_wait_time = LatencyMeasurer( + "backend.rob.wait_time", + description="Distribution of time instructions spend in ROB", + slots_number=(2**gen_params.rob_entries_bits + 1), + max_latency=1000, + ) + def elaborate(self, platform): m = TModule() + m.submodules += [self.perf_rob_wait_time] + start_idx = Signal(self.params.rob_entries_bits) end_idx = Signal(self.params.rob_entries_bits) @@ -35,11 +45,13 @@ def _(): @def_method(m, self.retire, ready=self.data[start_idx].done) def _(): + self.perf_rob_wait_time.stop(m) m.d.sync += start_idx.eq(start_idx + 1) m.d.sync += self.data[start_idx].done.eq(0) @def_method(m, self.put, ready=put_possible) def _(arg): + self.perf_rob_wait_time.start(m) m.d.sync += self.data[end_idx].rob_data.eq(arg) m.d.sync += self.data[end_idx].done.eq(0) m.d.sync += end_idx.eq(end_idx + 1) diff --git a/coreblocks/structs_common/rs.py b/coreblocks/structs_common/rs.py index 255f48a63..fe8d04ba4 100644 --- a/coreblocks/structs_common/rs.py +++ b/coreblocks/structs_common/rs.py @@ -5,6 +5,7 @@ 
from transactron import Method, def_method, TModule from coreblocks.params import RSLayouts, GenParams, OpType from transactron.core import RecordDict +from transactron.utils.transactron_helpers import make_layout __all__ = ["RS"] @@ -18,12 +19,11 @@ def __init__( self.rs_entries = rs_entries self.rs_entries_bits = (rs_entries - 1).bit_length() self.layouts = gen_params.get(RSLayouts, rs_entries_bits=self.rs_entries_bits) - self.internal_layout = [ + self.internal_layout = make_layout( ("rs_data", self.layouts.rs.data_layout), ("rec_full", 1), - ("rec_ready", 1), ("rec_reserved", 1), - ] + ) self.insert = Method(i=self.layouts.rs.insert_in) self.select = Method(o=self.layouts.rs.select_out) @@ -33,22 +33,23 @@ def __init__( self.ready_for = [list(op_list) for op_list in ready_for] self.get_ready_list = [Method(o=self.layouts.get_ready_list_out, nonexclusive=True) for _ in self.ready_for] - self.data = Array(Record(self.internal_layout) for _ in range(self.rs_entries)) + self.data = Array(Signal(self.internal_layout) for _ in range(self.rs_entries)) + self.data_ready = Signal(self.rs_entries) def elaborate(self, platform): m = TModule() m.submodules.enc_select = PriorityEncoder(width=self.rs_entries) - for record in self.data: - m.d.comb += record.rec_ready.eq( + for i, record in enumerate(self.data): + m.d.comb += self.data_ready[i].eq( ~record.rs_data.rp_s1.bool() & ~record.rs_data.rp_s2.bool() & record.rec_full.bool() ) select_vector = Cat(~record.rec_reserved for record in self.data) select_possible = select_vector.any() - take_vector = Cat(record.rec_ready & record.rec_full for record in self.data) + take_vector = Cat(self.data_ready[i] & record.rec_full for i, record in enumerate(self.data)) take_possible = take_vector.any() ready_lists: list[Value] = [] diff --git a/docker/riscv-toolchain.Dockerfile b/docker/riscv-toolchain.Dockerfile index 957141eb0..a998e79e3 100644 --- a/docker/riscv-toolchain.Dockerfile +++ b/docker/riscv-toolchain.Dockerfile @@ -12,8 
+12,8 @@ RUN apt-get update && \ RUN git clone --shallow-since=2023.05.01 https://github.com/riscv/riscv-gnu-toolchain && \ cd riscv-gnu-toolchain && \ - git checkout 2023.05.14 && \ - ./configure --with-multilib-generator="rv32i-ilp32--a*zifence*zicsr;rv32im-ilp32--a*zifence*zicsr;rv32ic-ilp32--a*zifence*zicsr;rv32imc-ilp32--a*zifence*zicsr;rv32imfc-ilp32f--a*zifence;rv32i_zmmul-ilp32--a*zifence*zicsr;rv32ic_zmmul-ilp32--a*zifence*zicsr" && \ + git checkout 2023.12.10 && \ + ./configure --with-multilib-generator="rv32i-ilp32--a*zifence*zicsr;rv32im-ilp32--a*zifence*zicsr;rv32ic-ilp32--a*zifence*zicsr;rv32imc-ilp32--a*zifence*zicsr;rv32imfc-ilp32f--a*zifence;rv32imc_zba_zbb_zbc_zbs-ilp32--a*zifence*zicsr" && \ make -j$(nproc) && \ cd / && rm -rf riscv-gnu-toolchain diff --git a/examples/wishbone_wbm_example.py b/examples/wishbone_wbm_example.py index e28ff26d8..1e6362bcb 100644 --- a/examples/wishbone_wbm_example.py +++ b/examples/wishbone_wbm_example.py @@ -12,6 +12,8 @@ def __init__(self): self.output = Record(WishboneMaster.resultLayout) self.output_btn = Signal() + self.wbm_ports = list(self.wbm.wb_master.fields.values()) + self.m = Module() self.tm = TransactionModule(self.m) self.wbm = WishboneMaster() @@ -24,7 +26,7 @@ def __init__(self): self.output.err, self.output.data, self.output_btn, - ] + self.wbm.ports + ] + self.wbm_ports def elaborate(self, platform): m = self.m diff --git a/pyrightconfig.json b/pyrightconfig.json index aadb361ef..e0eb26f81 100644 --- a/pyrightconfig.json +++ b/pyrightconfig.json @@ -21,6 +21,6 @@ "stubPath": "./stubs", - "pythonVersion": "3.10", + "pythonVersion": "3.11", "pythonPlatform": "Linux" } diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 000000000..970b444e5 --- /dev/null +++ b/pytest.ini @@ -0,0 +1,9 @@ +[pytest] +minversion = 7.2.2 +testpaths = + tests +norecursedirs = '*.egg', '.*', 'build', 'dist', 'venv', '__traces__', '__pycache__' +filterwarnings = + ignore:cannot collect test class 
'TestbenchIO':pytest.PytestCollectionWarning + ignore:No files were found in testpaths:pytest.PytestConfigWarning: +log_cli=true diff --git a/requirements-dev.txt b/requirements-dev.txt index 3e90e45be..1d9530305 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -13,9 +13,10 @@ pyright==1.1.332 Sphinx==5.1.1 sphinx-rtd-theme==1.0.0 sphinxcontrib-mermaid==0.8.1 -cocotb==1.7.2 +cocotb==1.8.1 cocotb-bus==0.2.1 -pytest==7.2.2 +pytest==8.0.0 +pytest-xdist==3.5.0 pyelftools==0.29 -dataclasses-json==0.6.3 tabulate==0.9.0 +filelock==3.13.1 diff --git a/requirements.txt b/requirements.txt index c08979f0e..43714219e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,3 @@ amaranth-yosys==0.35.0.0.post81 -git+https://github.com/amaranth-lang/amaranth@ab6503e352825b36bb29f1a8622b9e98aac9a6c6 +git+https://github.com/amaranth-lang/amaranth@115954b4d957b4ba642ad056ab1670bf5d185fb6 +dataclasses-json==0.6.3 diff --git a/scripts/check_test_results.py b/scripts/check_test_results.py deleted file mode 100755 index c10af9bc2..000000000 --- a/scripts/check_test_results.py +++ /dev/null @@ -1,22 +0,0 @@ -#!/usr/bin/env python3 - -import sys -import os -import pathlib -import xml.etree.ElementTree as eT - -FAILURE_TAG = "failure" -TOP_DIR = pathlib.Path(__file__).parent.parent -TEST_RESULTS_FILE = TOP_DIR.joinpath("test/regression/cocotb/results.xml") - -if not os.path.exists(TEST_RESULTS_FILE): - print("File not found: ", TEST_RESULTS_FILE) - sys.exit(1) - -tree = eT.parse(TEST_RESULTS_FILE) - -if len(list(tree.iter(FAILURE_TAG))) > 0: - print("Some regression tests failed") - sys.exit(1) - -print("All regression tests pass") diff --git a/scripts/core_graph.py b/scripts/core_graph.py index a589c205a..6818f6dd0 100755 --- a/scripts/core_graph.py +++ b/scripts/core_graph.py @@ -17,7 +17,7 @@ from transactron.graph import TracingFragment # noqa: E402 from test.test_core import CoreTestElaboratable # noqa: E402 from coreblocks.params.configurations import 
basic_core_config # noqa: E402 -from transactron.core import TransactionModule # noqa: E402 +from transactron.core import TransactionManagerKey, TransactionModule # noqa: E402 gp = GenParams(basic_core_config) elaboratable = CoreTestElaboratable(gp) @@ -25,10 +25,10 @@ fragment = TracingFragment.get(tm, platform=None).prepare() core = fragment -while not hasattr(core, "transactionManager"): +while not hasattr(core, "manager"): core = core._tracing_original # type: ignore -mgr = core.transactionManager # type: ignore +mgr = core.manager.get_dependency(TransactionManagerKey()) # type: ignore with arg.ofile as fp: graph = mgr.visual_graph(fragment) diff --git a/scripts/gen_verilog.py b/scripts/gen_verilog.py index 964b52654..e9c5b8707 100755 --- a/scripts/gen_verilog.py +++ b/scripts/gen_verilog.py @@ -4,19 +4,21 @@ import sys import argparse +from amaranth import * from amaranth.build import Platform -from amaranth.back import verilog from amaranth import Module, Elaboratable + if __name__ == "__main__": parent = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) sys.path.insert(0, parent) from coreblocks.params.genparams import GenParams -from coreblocks.peripherals.wishbone import WishboneBus +from coreblocks.peripherals.wishbone import WishboneSignature from coreblocks.core import Core from transactron import TransactionModule -from transactron.utils import flatten_signals +from transactron.utils import DependencyManager, DependencyContext +from transactron.utils.gen import generate_verilog from coreblocks.params.configurations import * @@ -31,25 +33,34 @@ class Top(Elaboratable): def __init__(self, gen_params): self.gp: GenParams = gen_params - self.wb_instr = WishboneBus(self.gp.wb_params, name="wb_instr") - self.wb_data = WishboneBus(self.gp.wb_params, name="wb_data") + self.wb_instr = WishboneSignature(self.gp.wb_params).create() + self.wb_data = WishboneSignature(self.gp.wb_params).create() def elaborate(self, platform: Platform): m = Module() - tm = 
TransactionModule(m) + tm = TransactionModule(m, dependency_manager=DependencyContext.get()) m.submodules.c = Core(gen_params=self.gp, wb_instr_bus=self.wb_instr, wb_data_bus=self.wb_data) return tm -def gen_verilog(core_config: CoreConfiguration, output_path): - top = Top(GenParams(core_config)) +def gen_verilog(core_config: CoreConfiguration, output_path: str): + with DependencyContext(DependencyManager()): + gp = GenParams(core_config) + top = Top(gp) + instr_ports: list[Signal] = [getattr(top.wb_instr, name) for name in top.wb_instr.signature.members] + data_ports: list[Signal] = [getattr(top.wb_data, name) for name in top.wb_data.signature.members] + for sig in instr_ports: + sig.name = "wb_instr__" + sig.name + for sig in data_ports: + sig.name = "wb_data__" + sig.name - with open(output_path, "w") as f: - signals = list(flatten_signals(top.wb_instr)) + list(flatten_signals(top.wb_data)) + verilog_text, gen_info = generate_verilog(top, instr_ports + data_ports) - f.write(verilog.convert(top, ports=signals, strip_internal_attrs=True)) + gen_info.encode(f"{output_path}.json") + with open(output_path, "w") as f: + f.write(verilog_text) def main(): @@ -70,6 +81,12 @@ def main(): + f"Available configurations: {', '.join(list(str_to_coreconfig.keys()))}. Default: %(default)s", ) + parser.add_argument( + "--strip-debug", + action="store_true", + help="Remove debugging signals. Default: %(default)s", + ) + parser.add_argument( "-o", "--output", action="store", default="core.v", help="Output file path. 
Default: %(default)s" ) @@ -81,7 +98,11 @@ def main(): if args.config not in str_to_coreconfig: raise KeyError(f"Unknown config '{args.config}'") - gen_verilog(str_to_coreconfig[args.config], args.output) + config = str_to_coreconfig[args.config] + if args.strip_debug: + config = config.replace(debug_signals=False) + + gen_verilog(config, args.output) if __name__ == "__main__": diff --git a/scripts/run_benchmarks.py b/scripts/run_benchmarks.py index 1a0c9a2b4..442cb26ec 100755 --- a/scripts/run_benchmarks.py +++ b/scripts/run_benchmarks.py @@ -8,6 +8,7 @@ import sys import os import subprocess +import tabulate from typing import Literal from pathlib import Path @@ -15,6 +16,7 @@ sys.path.insert(0, str(topdir)) import test.regression.benchmark # noqa: E402 +from test.regression.benchmark import BenchmarkResult # noqa: E402 from test.regression.pysim import PySimulation # noqa: E402 @@ -58,6 +60,12 @@ def run_benchmarks_with_cocotb(benchmarks: list[str], traces: bool) -> bool: test_cases = ",".join(benchmarks) arglist += [f"TESTCASE={test_cases}"] + verilog_code = topdir.joinpath("core.v") + gen_info_path = f"{verilog_code}.json" + + arglist += [f"VERILOG_SOURCES={verilog_code}"] + arglist += [f"_COREBLOCKS_GEN_INFO={gen_info_path}"] + if traces: arglist += ["TRACES=1"] @@ -66,7 +74,7 @@ def run_benchmarks_with_cocotb(benchmarks: list[str], traces: bool) -> bool: return res.returncode == 0 -def run_benchmarks_with_pysim(benchmarks: list[str], traces: bool, verbose: bool) -> bool: +def run_benchmarks_with_pysim(benchmarks: list[str], traces: bool) -> bool: suite = unittest.TestSuite() def _gen_test(test_name: str): @@ -74,9 +82,7 @@ def test_fn(): traces_file = None if traces: traces_file = "benchmark." 
+ test_name - asyncio.run( - test.regression.benchmark.run_benchmark(PySimulation(verbose, traces_file=traces_file), test_name) - ) + asyncio.run(test.regression.benchmark.run_benchmark(PySimulation(traces_file=traces_file), test_name)) test_fn.__name__ = test_name test_fn.__qualname__ = test_name @@ -86,25 +92,59 @@ def test_fn(): for test_name in benchmarks: suite.addTest(unittest.FunctionTestCase(_gen_test(test_name))) - runner = unittest.TextTestRunner(verbosity=(2 if verbose else 1)) + runner = unittest.TextTestRunner(verbosity=2) result = runner.run(suite) return result.wasSuccessful() -def run_benchmarks(benchmarks: list[str], backend: Literal["pysim", "cocotb"], traces: bool, verbose: bool) -> bool: +def run_benchmarks(benchmarks: list[str], backend: Literal["pysim", "cocotb"], traces: bool) -> bool: if backend == "cocotb": return run_benchmarks_with_cocotb(benchmarks, traces) elif backend == "pysim": - return run_benchmarks_with_pysim(benchmarks, traces, verbose) + return run_benchmarks_with_pysim(benchmarks, traces) return False +def build_result_table(results: dict[str, BenchmarkResult]) -> str: + if len(results) == 0: + return "" + + header = ["Testbench name", "Cycles", "Instructions", "IPC"] + + # First fetch all metrics names to build the header + result = next(iter(results.values())) + for metric_name in sorted(result.metric_values.keys()): + regs = result.metric_values[metric_name] + for reg_name in regs: + header.append(f"{metric_name}/{reg_name}") + + columns = [header] + for benchmark_name, result in results.items(): + ipc = result.instr / result.cycles + + column = [benchmark_name, result.cycles, result.instr, ipc] + + for metric_name in sorted(result.metric_values.keys()): + regs = result.metric_values[metric_name] + for reg_name in regs: + column.append(regs[reg_name]) + + columns.append(column) + + # Transpose the table, as the library expects to get a list of rows (and we have a list of columns). 
+ rows = [list(i) for i in zip(*columns)] + + return tabulate.tabulate(rows, headers="firstrow", tablefmt="simple_outline") + + def main(): parser = argparse.ArgumentParser() parser.add_argument("-l", "--list", action="store_true", help="List all benchmarks") parser.add_argument("-t", "--trace", action="store_true", help="Dump waveforms") - parser.add_argument("-v", "--verbose", action="store_true", help="Verbose output") + parser.add_argument("--log-level", default="WARNING", action="store", help="Level of messages to display.") + parser.add_argument("--log-filter", default=".*", action="store", help="Regexp used to filter out logs.") + parser.add_argument("-p", "--profile", action="store_true", help="Write execution profiles") parser.add_argument("-b", "--backend", default="cocotb", choices=["cocotb", "pysim"], help="Simulation backend") parser.add_argument( "-o", @@ -123,6 +163,9 @@ def main(): print(name) return + os.environ["__TRANSACTRON_LOG_LEVEL"] = args.log_level + os.environ["__TRANSACTRON_LOG_FILTER"] = args.log_filter + if args.benchmark_name: pattern = re.compile(args.benchmark_name) benchmarks = [name for name in benchmarks if pattern.search(name)] @@ -131,27 +174,31 @@ def main(): print(f"Could not find benchmark '{args.benchmark_name}'") sys.exit(1) - success = run_benchmarks(benchmarks, args.backend, args.trace, args.verbose) + if args.profile: + os.environ["__TRANSACTRON_PROFILE"] = "1" + + success = run_benchmarks(benchmarks, args.backend, args.trace) if not success: print("Benchmark execution failed") sys.exit(1) - results = [] ipcs = [] + + results: dict[str, BenchmarkResult] = {} + for name in benchmarks: with open(f"{str(test.regression.benchmark.results_dir)}/{name}.json", "r") as f: - res = json.load(f) + result = BenchmarkResult.from_json(f.read()) # type: ignore - ipc = res["instr"] / res["cycle"] - ipcs.append(ipc) + results[name] = result - results.append({"name": name, "unit": "Instructions Per Cycle", "value": ipc}) - 
print(f"Benchmark '{name}': cycles={res['cycle']}, instructions={res['instr']} ipc={ipc:.4f}") + ipc = result.instr / result.cycles + ipcs.append({"name": name, "unit": "Instructions Per Cycle", "value": ipc}) - print(f"Average ipc={sum(ipcs)/len(ipcs):.4f}") + print(build_result_table(results)) with open(args.output, "w") as benchmark_file: - json.dump(results, benchmark_file, indent=4) + json.dump(ipcs, benchmark_file, indent=4) if __name__ == "__main__": diff --git a/scripts/run_signature.py b/scripts/run_signature.py index 7d45bae7f..2e047f9ae 100755 --- a/scripts/run_signature.py +++ b/scripts/run_signature.py @@ -31,6 +31,12 @@ def run_with_cocotb(test_name: str, traces: bool, output: str) -> bool: arglist += [f"TESTNAME={test_name}"] arglist += [f"OUTPUT={output}"] + verilog_code = f"{parent}/core.v" + gen_info_path = f"{verilog_code}.json" + + arglist += [f"VERILOG_SOURCES={verilog_code}"] + arglist += [f"_COREBLOCKS_GEN_INFO={gen_info_path}"] + if traces: arglist += ["TRACES=1"] @@ -39,32 +45,31 @@ def run_with_cocotb(test_name: str, traces: bool, output: str) -> bool: return os.path.isfile(output) # completed successfully if signature file was created -def run_with_pysim(test_name: str, traces: bool, verbose: bool, output: str) -> bool: +def run_with_pysim(test_name: str, traces: bool, output: str) -> bool: traces_file = None if traces: traces_file = os.path.basename(test_name) try: - asyncio.run( - test.regression.signature.run_test(PySimulation(verbose, traces_file=traces_file), test_name, output) - ) + asyncio.run(test.regression.signature.run_test(PySimulation(traces_file=traces_file), test_name, output)) except RuntimeError as e: print("RuntimeError:", e) return False return True -def run_test(test: str, backend: Literal["pysim", "cocotb"], traces: bool, verbose: bool, output: str) -> bool: +def run_test(test: str, backend: Literal["pysim", "cocotb"], traces: bool, output: str) -> bool: if backend == "cocotb": return run_with_cocotb(test, traces, 
output) elif backend == "pysim": - return run_with_pysim(test, traces, verbose, output) + return run_with_pysim(test, traces, output) return False def main(): parser = argparse.ArgumentParser() parser.add_argument("-t", "--trace", action="store_true", help="Dump waveforms") - parser.add_argument("-v", "--verbose", action="store_true", help="Verbose output") + parser.add_argument("--log-level", default="WARNING", action="store", help="Level of messages to display.") + parser.add_argument("--log-filter", default=".*", action="store", help="Regexp used to filter out logs.") parser.add_argument("-b", "--backend", default="pysim", choices=["cocotb", "pysim"], help="Simulation backend") parser.add_argument("-o", "--output", default=None, help="Selects output file to write test signature to") parser.add_argument("path") @@ -73,7 +78,10 @@ def main(): output = args.output if args.output else args.path + ".signature" - success = run_test(args.path, args.backend, args.trace, args.verbose, output) + os.environ["__TRANSACTRON_LOG_LEVEL"] = args.log_level + os.environ["__TRANSACTRON_LOG_FILTER"] = args.log_filter + + success = run_test(args.path, args.backend, args.trace, output) if not success: print(f"{args.path}: Program execution failed") diff --git a/scripts/run_tests.py b/scripts/run_tests.py index 264daa707..2223b8cf1 100755 --- a/scripts/run_tests.py +++ b/scripts/run_tests.py @@ -1,103 +1,15 @@ #!/usr/bin/env python3 -import unittest -import asyncio +import pytest import argparse -import re -import sys import os -import subprocess -from typing import Literal from pathlib import Path topdir = Path(__file__).parent.parent -sys.path.insert(0, str(topdir)) - -import test.regression.test # noqa: E402 -from test.regression.pysim import PySimulation # noqa: E402 - -REGRESSION_TESTS_PREFIX = "test.regression." 
def cd_to_topdir(): - os.chdir(str(topdir)) - - -def load_unit_tests(): - suite = unittest.TestLoader().discover(".") - - tests = {} - - def flatten(suite): - if hasattr(suite, "__iter__"): - for x in suite: - flatten(x) - else: - tests[suite.id()] = suite - - flatten(suite) - - return tests - - -def load_regression_tests() -> list[str]: - all_tests = test.regression.test.get_all_test_names() - if len(all_tests) == 0: - res = subprocess.run(["make", "-C", "test/external/riscv-tests"]) - if res.returncode != 0: - print("Couldn't build regression tests") - sys.exit(1) - - exclude = {"rv32ui-ma_data", "rv32ui-fence_i"} - - return list(all_tests - exclude) - - -def run_regressions_with_cocotb(tests: list[str], traces: bool) -> bool: - cpu_count = len(os.sched_getaffinity(0)) - arglist = ["make", "-C", "test/regression/cocotb", "-f", "test.Makefile", f"-j{cpu_count}"] - - test_cases = ",".join(tests) - arglist += [f"TESTCASE={test_cases}"] - - if traces: - arglist += ["TRACES=1"] - - res = subprocess.run(arglist) - - return res.returncode == 0 - - -def run_regressions_with_pysim(tests: list[str], traces: bool, verbose: bool) -> bool: - suite = unittest.TestSuite() - - def _gen_test(test_name: str): - def test_fn(): - traces_file = None - if traces: - traces_file = REGRESSION_TESTS_PREFIX + test_name - asyncio.run(test.regression.test.run_test(PySimulation(verbose, traces_file=traces_file), test_name)) - - test_fn.__name__ = test_name - test_fn.__qualname__ = test_name - - return test_fn - - for test_name in tests: - suite.addTest(unittest.FunctionTestCase(_gen_test(test_name))) - - runner = unittest.TextTestRunner(verbosity=(2 if verbose else 1)) - result = runner.run(suite) - - return result.wasSuccessful() - - -def run_regression_tests(tests: list[str], backend: Literal["pysim", "cocotb"], traces: bool, verbose: bool) -> bool: - if backend == "cocotb": - return run_regressions_with_cocotb(tests, traces) - elif backend == "pysim": - return 
run_regressions_with_pysim(tests, traces, verbose) - return False + os.chdir(topdir) def main(): @@ -111,46 +23,38 @@ def main(): "-b", "--backend", default="cocotb", choices=["cocotb", "pysim"], help="Simulation backend for regression tests" ) parser.add_argument("-c", "--count", type=int, help="Start `c` first tests which match regexp") + parser.add_argument( + "-j", "--jobs", type=int, default=len(os.sched_getaffinity(0)), help="Start `j` jobs in parallel. Default: all" + ) parser.add_argument("test_name", nargs="?") args = parser.parse_args() - unit_tests = load_unit_tests() - regression_tests = load_regression_tests() if args.all else [] - - if args.list: - for name in list(unit_tests.keys()): - print(name) - for name in regression_tests: - print(REGRESSION_TESTS_PREFIX + name) - return + pytest_arguments = ["--max-worker-restart=1"] if args.trace: - os.environ["__COREBLOCKS_DUMP_TRACES"] = "1" - + pytest_arguments.append("--coreblocks-traces") if args.profile: - os.environ["__TRANSACTRON_PROFILE"] = "1" - + pytest_arguments.append("--coreblocks-profile") if args.test_name: - pattern = re.compile(args.test_name) - unit_tests = {name: test for name, test in unit_tests.items() if pattern.search(name)} - regression_tests = [test for test in regression_tests if pattern.search(REGRESSION_TESTS_PREFIX + test)] - - if not unit_tests and not regression_tests: - print(f"Could not find test matching '{args.test_name}'") - sys.exit(1) - - unit_tests_success = True - if unit_tests: - runner = unittest.TextTestRunner(verbosity=(2 if args.verbose else 1)) - result = runner.run(unittest.TestSuite(list(unit_tests.values())[: args.count])) - unit_tests_success = result.wasSuccessful() - - regression_tests_success = True - if regression_tests: - regression_tests_success = run_regression_tests(regression_tests, args.backend, args.trace, args.verbose) - - sys.exit(not (unit_tests_success and regression_tests_success)) + pytest_arguments += 
[f"--coreblocks-test-name={args.test_name}"] + if args.count: + pytest_arguments += ["--coreblocks-test-count", str(args.count)] + if args.list: + pytest_arguments.append("--coreblocks-list") + if args.jobs and not args.list: + # To list tests we can not use xdist, because it doesn't support forwarding of stdout from workers. + pytest_arguments += ["-n", str(args.jobs)] + if args.all: + pytest_arguments.append("--coreblocks-regression") + if args.verbose: + pytest_arguments.append("--verbose") + if args.backend: + pytest_arguments += [f"--coreblocks-backend={args.backend}"] + + ret = pytest.main(pytest_arguments, []) + + exit(ret) if __name__ == "__main__": diff --git a/scripts/synthesize.py b/scripts/synthesize.py index 5e14d019f..6c5c2f7eb 100755 --- a/scripts/synthesize.py +++ b/scripts/synthesize.py @@ -7,12 +7,14 @@ from amaranth.build import Platform from amaranth import * +from amaranth.lib.wiring import Flow if __name__ == "__main__": parent = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) sys.path.insert(0, parent) +from transactron.utils.dependencies import DependencyContext, DependencyManager from transactron.utils import ModuleConnector from coreblocks.params.genparams import GenParams from coreblocks.params.fu_params import FunctionalComponentParams @@ -25,7 +27,7 @@ from coreblocks.fu.zbs import ZbsComponent from transactron import TransactionModule from transactron.lib import AdapterBase, AdapterTrans -from coreblocks.peripherals.wishbone import WishboneArbiter, WishboneBus +from coreblocks.peripherals.wishbone import WishboneArbiter, WishboneInterface from constants.ecp5_platforms import ( ResourceBuilder, adapter_resources, @@ -44,7 +46,7 @@ class WishboneConnector(Elaboratable): - def __init__(self, wb: WishboneBus, number: int): + def __init__(self, wb: WishboneInterface, number: int): self.wb = wb self.number = number @@ -54,7 +56,12 @@ def elaborate(self, platform: Platform): pins = platform.request("wishbone", self.number) assert 
isinstance(pins, Record) - m.d.comb += self.wb.connect(pins) + for name in self.wb.signature.members: + member = self.wb.signature.members[name] + if member.flow == Flow.In: + m.d.comb += getattr(pins, name).o.eq(getattr(self.wb, name)) + else: + m.d.comb += getattr(self.wb, name).eq(getattr(pins, name).i) return m @@ -92,14 +99,13 @@ def elaborate(self, platform: Platform): def unit_core(gen_params: GenParams): resources = wishbone_resources(gen_params.wb_params) - wb_instr = WishboneBus(gen_params.wb_params) - wb_data = WishboneBus(gen_params.wb_params) + wb_arbiter = WishboneArbiter(gen_params.wb_params, 2) + wb_instr = wb_arbiter.masters[0] + wb_data = wb_arbiter.masters[1] - core = Core(gen_params=gen_params, wb_instr_bus=wb_instr, wb_data_bus=wb_data) + wb_connector = WishboneConnector(wb_arbiter.slave_wb, 0) - wb = WishboneBus(gen_params.wb_params) - wb_arbiter = WishboneArbiter(wb, [wb_instr, wb_data]) - wb_connector = WishboneConnector(wb, 0) + core = Core(gen_params=gen_params, wb_instr_bus=wb_instr, wb_data_bus=wb_data) module = ModuleConnector(core=core, wb_arbiter=wb_arbiter, wb_connector=wb_connector) @@ -117,7 +123,7 @@ def unit(gen_params: GenParams): module = ModuleConnector(fu=fu, issue_connector=issue_connector, accept_connector=accept_connector) - return resources, TransactionModule(module) + return resources, TransactionModule(module, dependency_manager=DependencyContext.get()) return unit @@ -138,11 +144,12 @@ def unit(gen_params: GenParams): def synthesize(core_config: CoreConfiguration, platform: str, core: UnitCore): - gen_params = GenParams(core_config) - resource_builder, module = core(gen_params) + with DependencyContext(DependencyManager()): + gen_params = GenParams(core_config) + resource_builder, module = core(gen_params) - if platform == "ecp5": - make_ecp5_platform(resource_builder)().build(module) + if platform == "ecp5": + make_ecp5_platform(resource_builder)().build(module) def main(): @@ -170,6 +177,12 @@ def main(): 
help="Select core unit." + f"Available units: {', '.join(core_units.keys())}. Default: %(default)s", ) + parser.add_argument( + "--strip-debug", + action="store_true", + help="Remove debugging signals. Default: %(default)s", + ) + parser.add_argument( "-v", "--verbose", @@ -187,7 +200,11 @@ def main(): if args.unit not in core_units: raise KeyError(f"Unknown core unit '{args.unit}'") - synthesize(str_to_coreconfig[args.config], args.platform, core_units[args.unit]) + config = str_to_coreconfig[args.config] + if args.strip_debug: + config = config.replace(debug_signals=False) + + synthesize(config, args.platform, core_units[args.unit]) if __name__ == "__main__": diff --git a/stubs/amaranth/_toolchain/yosys.pyi b/stubs/amaranth/_toolchain/yosys.pyi new file mode 100644 index 000000000..46cff1055 --- /dev/null +++ b/stubs/amaranth/_toolchain/yosys.pyi @@ -0,0 +1,144 @@ +""" +This type stub file was generated by pyright. +""" + +__all__ = ["YosysError", "YosysBinary", "find_yosys"] +from typing import Optional +from pathlib import Path + + +class YosysError(Exception): + ... + + +class YosysWarning(Warning): + ... + + +class YosysBinary: + @classmethod + def available(cls) -> bool: + """Check for Yosys availability. + + Returns + ------- + available : bool + ``True`` if Yosys is installed, ``False`` otherwise. Installed binary may still not + be runnable, or might be too old to be useful. + """ + ... + + @classmethod + def version(cls) -> Optional[tuple[int, int, int]]: + """Get Yosys version. + + Returns + ------- + ``None`` if version number could not be determined, or a 3-tuple ``(major, minor, distance)`` if it could. + + major : int + Major version. + minor : int + Minor version. + distance : int + Distance to last tag per ``git describe``. May not be exact for system Yosys. + """ + ... + + @classmethod + def data_dir(cls) -> pathlib.Path: + """Get Yosys data directory. 
+ + Returns + ------- + data_dir : pathlib.Path + Yosys data directory (also known as "datdir"). + """ + ... + + @classmethod + def run(cls, args: list[str], stdin: str=...) -> str: + """Run Yosys process. + + Parameters + ---------- + args : list of str + Arguments, not including the program name. + stdin : str + Standard input. + + Returns + ------- + stdout : str + Standard output. + + Exceptions + ---------- + YosysError + Raised if Yosys returns a non-zero code. The exception message is the standard error + output. + """ + ... + + + +class _BuiltinYosys(YosysBinary): + YOSYS_PACKAGE = ... + @classmethod + def available(cls): # -> bool: + ... + + @classmethod + def version(cls): # -> tuple[int, int, int]: + ... + + @classmethod + def data_dir(cls): # -> Traversable: + ... + + @classmethod + def run(cls, args, stdin=..., *, ignore_warnings=..., src_loc_at=...): + ... + + + +class _SystemYosys(YosysBinary): + YOSYS_BINARY = ... + @classmethod + def available(cls): # -> bool: + ... + + @classmethod + def version(cls): # -> tuple[int, int, int] | None: + ... + + @classmethod + def data_dir(cls): # -> Path: + ... + + @classmethod + def run(cls, args, stdin=..., *, ignore_warnings=..., src_loc_at=...): + ... + + + +def find_yosys(requirement): + """Find an available Yosys executable of required version. + + Parameters + ---------- + requirement : function + Version check. Should return ``True`` if the version is acceptable, ``False`` otherwise. + + Returns + ------- + yosys_binary : subclass of YosysBinary + Proxy for running the requested version of Yosys. + + Exceptions + ---------- + YosysError + Raised if required Yosys version is not found. + """ + ... 
+ diff --git a/stubs/amaranth/_unused.pyi b/stubs/amaranth/_unused.pyi new file mode 100644 index 000000000..9af961a4a --- /dev/null +++ b/stubs/amaranth/_unused.pyi @@ -0,0 +1,13 @@ +import sys +import warnings + +__all__ = ["UnusedMustUse", "MustUse"] + + +class UnusedMustUse(Warning): + pass + + +class MustUse: + _MustUse__silence : bool + _MustUse__warning : UnusedMustUse diff --git a/stubs/amaranth/back/__init__.pyi b/stubs/amaranth/back/__init__.pyi new file mode 100644 index 000000000..006bc2749 --- /dev/null +++ b/stubs/amaranth/back/__init__.pyi @@ -0,0 +1,4 @@ +""" +This type stub file was generated by pyright. +""" + diff --git a/stubs/amaranth/back/verilog.pyi b/stubs/amaranth/back/verilog.pyi new file mode 100644 index 000000000..2850050a3 --- /dev/null +++ b/stubs/amaranth/back/verilog.pyi @@ -0,0 +1,14 @@ +""" +This type stub file was generated by pyright. +""" + +from .._toolchain.yosys import * +from ..hdl.ast import SignalDict + +__all__ = ["YosysError", "convert", "convert_fragment"] +def convert_fragment(*args, strip_internal_attrs=..., **kwargs) -> tuple[str, SignalDict]: + ... + +def convert(elaboratable, name=..., platform=..., *, ports=..., emit_src=..., strip_internal_attrs=..., **kwargs) -> str: + ... + diff --git a/stubs/amaranth/build/res.pyi b/stubs/amaranth/build/res.pyi index 41b734ca9..83a09d440 100644 --- a/stubs/amaranth/build/res.pyi +++ b/stubs/amaranth/build/res.pyi @@ -3,7 +3,7 @@ This type stub file was generated by pyright. """ from typing import Any -from ..hdl.ast import * +from ..hdl._ast import * from ..hdl.rec import * from ..lib.io import * from .dsl import * diff --git a/stubs/amaranth/hdl/__init__.pyi b/stubs/amaranth/hdl/__init__.pyi index 78c4f551c..c44fa4755 100644 --- a/stubs/amaranth/hdl/__init__.pyi +++ b/stubs/amaranth/hdl/__init__.pyi @@ -1,13 +1,29 @@ -""" -This type stub file was generated by pyright. 
-""" - -from .ast import Array, C, Cat, ClockSignal, Const, Mux, Repl, ResetSignal, Shape, Signal, Value, signed, unsigned -from .dsl import Module -from .cd import ClockDomain -from .ir import Elaboratable, Fragment, Instance -from .mem import Memory +from ._ast import Shape, unsigned, signed, ShapeCastable, ShapeLike +from ._ast import Value, ValueCastable, ValueLike +from ._ast import Const, C, Mux, Cat, Array, Signal, ClockSignal, ResetSignal +from ._dsl import SyntaxError, SyntaxWarning, Module +from ._cd import DomainError, ClockDomain +from ._ir import UnusedElaboratable, Elaboratable, DriverConflict, Fragment, Instance +from ._mem import Memory, ReadPort, WritePort, DummyPort from .rec import Record -from .xfrm import DomainRenamer, EnableInserter, ResetInserter +from ._xfrm import DomainRenamer, ResetInserter, EnableInserter + -__all__ = ["Shape", "unsigned", "signed", "Value", "Const", "C", "Mux", "Cat", "Repl", "Array", "Signal", "ClockSignal", "ResetSignal", "Module", "ClockDomain", "Elaboratable", "Fragment", "Instance", "Memory", "Record", "DomainRenamer", "ResetInserter", "EnableInserter"] +__all__ = [ + # _ast + "Shape", "unsigned", "signed", "ShapeCastable", "ShapeLike", + "Value", "ValueCastable", "ValueLike", + "Const", "C", "Mux", "Cat", "Array", "Signal", "ClockSignal", "ResetSignal", + # _dsl + "SyntaxError", "SyntaxWarning", "Module", + # _cd + "DomainError", "ClockDomain", + # _ir + "UnusedElaboratable", "Elaboratable", "DriverConflict", "Fragment", "Instance", + # _mem + "Memory", "ReadPort", "WritePort", "DummyPort", + # _rec + "Record", + # _xfrm + "DomainRenamer", "ResetInserter", "EnableInserter", +] diff --git a/stubs/amaranth/hdl/ast.pyi b/stubs/amaranth/hdl/_ast.pyi similarity index 95% rename from stubs/amaranth/hdl/ast.pyi rename to stubs/amaranth/hdl/_ast.pyi index fa115b316..8892c6f6e 100644 --- a/stubs/amaranth/hdl/ast.pyi +++ b/stubs/amaranth/hdl/_ast.pyi @@ -7,12 +7,14 @@ from collections.abc import Callable, MutableMapping, 
MutableSequence, MutableSe from typing import Any, Generic, Iterable, Iterator, Mapping, NoReturn, Optional, Sequence, TypeVar, final, overload from enum import Enum from transactron.utils import ValueLike, ShapeLike, StatementLike +from amaranth.lib.data import View __all__ = ["Shape", "ShapeCastable", "signed", "unsigned", "Value", "Const", "C", "AnyConst", "AnySeq", "Operator", "Mux", "Part", "Slice", "Cat", "Repl", "Array", "ArrayProxy", "Signal", "ClockSignal", "ResetSignal", "ValueCastable", "Sample", "Past", "Stable", "Rose", "Fell", "Initial", "Statement", "Switch", "Property", "Assign", "Assert", "Assume", "Cover", "ValueKey", "ValueDict", "ValueSet", "SignalKey", "SignalDict", "SignalSet", "ValueLike", "ShapeLike", "StatementLike", "SwitchKey"] T = TypeVar("T") U = TypeVar("U") +_T_ShapeCastable = TypeVar("_T_ShapeCastable", bound=ShapeCastable, covariant=True) Flattenable = T | Iterable[Flattenable[T]] SwitchKey = str | int | Enum @@ -408,14 +410,14 @@ class Repl(Value): class _SignalMeta(ABCMeta): @overload - def __call__(cls, shape: ShapeCastable[T], src_loc_at = ..., **kwargs) -> T: + def __call__(cls, shape: ShapeCastable[T], src_loc_at: int = ..., **kwargs) -> T: ... @overload - def __call__(cls, shape = ..., src_loc_at = ..., **kwargs) -> Signal: + def __call__(cls, shape: ShapeLike = ..., src_loc_at: int = ..., **kwargs) -> Signal: ... - def __call__(cls, shape = ..., src_loc_at = ..., **kwargs): + def __call__(cls, shape: ShapeLike = ..., src_loc_at: int = ..., **kwargs): ... @@ -425,9 +427,19 @@ class Signal(Value, DUID, metaclass=_SignalMeta): Pa""" def __init__(self, shape: Optional[ShapeLike] = ..., *, name: Optional[str] = ..., reset: int | Enum = ..., reset_less: bool = ..., attrs: dict = ..., decoder: type[Enum] | Callable[[int], str] = ..., src_loc_at=...) -> None: ... 
- + + @overload + @staticmethod + def like(other: View[_T_ShapeCastable], *, name: Optional[str] = ..., name_suffix: Optional[str] =..., src_loc_at=..., **kwargs) -> View[_T_ShapeCastable]: + ... + + @overload @staticmethod def like(other: ValueLike, *, name: Optional[str] = ..., name_suffix: Optional[str] =..., src_loc_at=..., **kwargs) -> Signal: + ... + + @staticmethod + def like(other: ValueLike, *, name: Optional[str] = ..., name_suffix: Optional[str] =..., src_loc_at=..., **kwargs): """Create Signal based on another. """ ... @@ -438,6 +450,7 @@ class Signal(Value, DUID, metaclass=_SignalMeta): def __repr__(self) -> str: ... + name: str decoder: Any diff --git a/stubs/amaranth/hdl/cd.pyi b/stubs/amaranth/hdl/_cd.pyi similarity index 100% rename from stubs/amaranth/hdl/cd.pyi rename to stubs/amaranth/hdl/_cd.pyi diff --git a/stubs/amaranth/hdl/dsl.pyi b/stubs/amaranth/hdl/_dsl.pyi similarity index 98% rename from stubs/amaranth/hdl/dsl.pyi rename to stubs/amaranth/hdl/_dsl.pyi index 9d0c07f76..4658a197e 100644 --- a/stubs/amaranth/hdl/dsl.pyi +++ b/stubs/amaranth/hdl/_dsl.pyi @@ -3,8 +3,7 @@ This type stub file was generated by pyright. """ from contextlib import _GeneratorContextManager, contextmanager -from typing import Callable, ContextManager, Iterator, NoReturn, OrderedDict, ParamSpec, TypeVar, Optional -from typing_extensions import Self +from typing import Callable, ContextManager, Iterator, NoReturn, OrderedDict, ParamSpec, TypeVar, Optional, Self from transactron.utils import HasElaborate from .ast import * from .ast import Flattenable diff --git a/stubs/amaranth/hdl/ir.pyi b/stubs/amaranth/hdl/_ir.pyi similarity index 88% rename from stubs/amaranth/hdl/ir.pyi rename to stubs/amaranth/hdl/_ir.pyi index 63acd1e3c..556a0f679 100644 --- a/stubs/amaranth/hdl/ir.pyi +++ b/stubs/amaranth/hdl/_ir.pyi @@ -3,13 +3,17 @@ This type stub file was generated by pyright. 
""" from abc import abstractmethod -from .ast import * -from .cd import * +from ._ast import * +from ._cd import * +from .. import _unused from transactron.utils import HasElaborate -__all__ = ["Elaboratable", "DriverConflict", "Fragment", "Instance"] +__all__ = ["UnusedElaboratable", "Elaboratable", "DriverConflict", "Fragment", "Instance"] +class UnusedElaboratable(_unused.UnusedMustUse): + ... + class Elaboratable(): @abstractmethod def elaborate(self, platform) -> HasElaborate: diff --git a/stubs/amaranth/hdl/mem.pyi b/stubs/amaranth/hdl/_mem.pyi similarity index 60% rename from stubs/amaranth/hdl/mem.pyi rename to stubs/amaranth/hdl/_mem.pyi index e7ed09e43..ddd629e06 100644 --- a/stubs/amaranth/hdl/mem.pyi +++ b/stubs/amaranth/hdl/_mem.pyi @@ -4,13 +4,33 @@ This type stub file was generated by pyright. from typing import Optional from .ast import * -from .ir import Elaboratable +from .ir import Elaboratable, Fragment -__all__ = ["Memory", "ReadPort", "WritePort", "DummyPort"] +__all__ = ["Memory", "ReadPort", "WritePort", "DummyPort", "MemoryInstance"] class Memory: """A word addressable storage. - - """ + Parameters + ---------- + width : int + Access granularity. Each storage element of this memory is ``width`` bits in size. + depth : int + Word count. This memory contains ``depth`` storage elements. + init : list of int + Initial values. At power on, each storage element in this memory is initialized to + the corresponding element of ``init``, if any, or to zero otherwise. + Uninitialized memories are not currently supported. + name : str + Name hint for this memory. If ``None`` (default) the name is inferred from the variable + name this ``Signal`` is assigned to. + attrs : dict + Dictionary of synthesis attributes. + Attributes + ---------- + width : int + depth : int + init : list of int + attrs : dict + """ width: int depth: int attrs: dict @@ -90,4 +110,10 @@ class DummyPort: ... 
- +class MemoryInstance(Fragment): + memory: Memory + read_ports: list[ReadPort] + write_ports: list[WritePort] + attrs: dict + def __init__(self, memory: Memory, read_ports: list[ReadPort], write_ports: list[WritePort]) -> None: + ... diff --git a/stubs/amaranth/hdl/xfrm.pyi b/stubs/amaranth/hdl/_xfrm.pyi similarity index 100% rename from stubs/amaranth/hdl/xfrm.pyi rename to stubs/amaranth/hdl/_xfrm.pyi diff --git a/stubs/amaranth/lib/data.pyi b/stubs/amaranth/lib/data.pyi index 46d13be39..52c5c0180 100644 --- a/stubs/amaranth/lib/data.pyi +++ b/stubs/amaranth/lib/data.pyi @@ -4,10 +4,9 @@ This type stub file was generated by pyright. from abc import ABCMeta, abstractmethod from collections.abc import Iterator, Mapping -from typing import TypeVar, Generic -from typing_extensions import Self +from typing import TypeVar, Generic, Self from amaranth.hdl import * -from amaranth.hdl.ast import Assign, ShapeCastable, ValueCastable +from amaranth.hdl._ast import Assign, ShapeCastable, ValueCastable from transactron.utils._typing import ShapeLike, ValueLike __all__ = ["Field", "Layout", "StructLayout", "UnionLayout", "ArrayLayout", "FlexibleLayout", "View", "Struct", "Union"] diff --git a/stubs/amaranth/lib/enum.pyi b/stubs/amaranth/lib/enum.pyi index 9c2d71e73..4ef5262f2 100644 --- a/stubs/amaranth/lib/enum.pyi +++ b/stubs/amaranth/lib/enum.pyi @@ -3,10 +3,9 @@ This type stub file was generated by pyright. 
""" import enum as py_enum -from typing import Generic, Optional, TypeVar, overload -from typing_extensions import Self +from typing import Generic, Optional, TypeVar, Self, overload from amaranth import * -from ..hdl.ast import Assign, ValueCastable, ShapeCastable, ValueLike +from ..hdl._ast import Assign, ValueCastable, ShapeCastable, ValueLike __all__ = ['EnumMeta', 'Enum', 'IntEnum', 'Flag', 'IntFlag', 'EnumView', 'FlagView', 'auto', 'unique'] diff --git a/stubs/amaranth/lib/fifo.pyi b/stubs/amaranth/lib/fifo.pyi index a64c5e8e5..e799e7223 100644 --- a/stubs/amaranth/lib/fifo.pyi +++ b/stubs/amaranth/lib/fifo.pyi @@ -11,7 +11,6 @@ __all__ = ["FIFOInterface", "SyncFIFO", "SyncFIFOBuffered", "AsyncFIFO", "AsyncF class FIFOInterface: width: int depth: int - fwft: bool w_data: Signal w_rdy: Signal w_en: Signal @@ -20,13 +19,13 @@ class FIFOInterface: r_rdy: Signal r_en: Signal r_level: Signal - def __init__(self, *, width: int, depth: int, fwft: bool) -> None: + def __init__(self, *, width: int, depth: int) -> None: ... class SyncFIFO(Elaboratable, FIFOInterface): - def __init__(self, *, width: int, depth: int, fwft: bool = ...) -> None: + def __init__(self, *, width: int, depth: int) -> None: ... def elaborate(self, platform) -> HasElaborate: diff --git a/stubs/amaranth/lib/wiring.pyi b/stubs/amaranth/lib/wiring.pyi new file mode 100644 index 000000000..9565301f9 --- /dev/null +++ b/stubs/amaranth/lib/wiring.pyi @@ -0,0 +1,1143 @@ +""" +This type stub file was generated by pyright. 
+""" + +import enum +from collections.abc import Mapping, Iterator +from typing import NoReturn, Literal, TypeVar, Generic, Any, Self, Optional, overload +from ..hdl._ir import Elaboratable +from .._utils import final +from transactron.utils._typing import ShapeLike, ValueLike, AbstractInterface, AbstractSignature, ModuleLike + +__all__ = ["In", "Out", "Signature", "PureInterface", "connect", "flipped", "Component"] + +_T_Signature = TypeVar("_T_Signature", bound=AbstractSignature) +_T_SignatureMembers = TypeVar("_T_SignatureMembers", bound=SignatureMembers) +_T_Interface = TypeVar("_T_Interface", bound=AbstractInterface) +_T = TypeVar("_T") + +class Flow(enum.Enum): + """Direction of data flow. This enumeration has two values, :attr:`Out` and :attr:`In`, + the meaning of which depends on the context in which they are used. + """ + Out = "out" + In = "in" + def flip(self) -> Flow: + """Flip the direction of data flow. + + Returns + ------- + :class:`Flow` + :attr:`In` if called as :pc:`Out.flip()`; :attr:`Out` if called as :pc:`In.flip()`. + """ + ... + + def __call__(self, description: Signature | ShapeLike, *, reset=...) -> Member: + """Create a :class:`Member` with this data flow and the provided description and + reset value. + + Returns + ------- + :class:`Member` + :pc:`Member(self, description, reset=reset)` + """ + ... + + def __repr__(self) -> Literal['Out', 'In']: + ... + + def __str__(self) -> str: + ... + + + +Out = Flow.Out +In = Flow.In + +@final +class Member: + """Description of a signature member. + + This class is a discriminated union: its instances describe either a `port member` or + a `signature member`, and accessing properties for the wrong kind of member raises + an :exc:`AttributeError`. + + The class is created from a `description`: a :class:`Signature` instance (in which case + the :class:`Member` is created as a signature member), or a :ref:`shape-like ` + object (in which case it is created as a port member). 
After creation the :class:`Member` + instance cannot be modified. + + When a :class:`Signal` is created from a description of a port member, the signal's reset value + is taken from the member description. If this signal is never explicitly assigned a value, it + will equal ``reset``. + + Although instances can be created directly, most often they will be created through + :data:`In` and :data:`Out`, e.g. :pc:`In(unsigned(1))` or :pc:`Out(stream.Signature(RGBPixel))`. + """ + def __init__(self, flow: Flow, description: Signature | ShapeLike, *, reset=..., _dimensions=...) -> None: + ... + + def flip(self) -> Member: + """Flip the data flow of this member. + + Returns + ------- + :class:`Member` + A new :pc:`member` with :pc:`member.flow` equal to :pc:`self.flow.flip()`, and identical + to :pc:`self` other than that. + """ + ... + + def array(self, *dimensions) -> Member: + """Add array dimensions to this member. + + The dimensions passed to this method are `prepended` to the existing dimensions. + For example, :pc:`Out(1).array(2)` describes an array of 2 elements, whereas both + :pc:`Out(1).array(2, 3)` and :pc:`Out(1).array(3).array(2)` both describe a two dimensional + array of 2 by 3 elements. + + Dimensions are passed to :meth:`array` in the order in which they would be indexed. + That is, :pc:`.array(x, y)` creates a member that can be indexed up to :pc:`[x-1][y-1]`. + + The :meth:`array` method is composable: calling :pc:`member.array(x)` describes an array of + :pc:`x` members even if :pc:`member` was already an array. + + Returns + ------- + :class:`Member` + A new :pc:`member` with :pc:`member.dimensions` extended by :pc:`dimensions`, and + identical to :pc:`self` other than that. + """ + ... + + @property + def flow(self) -> Flow: + """Data flow of this member. + + Returns + ------- + :class:`Flow` + """ + ... + + @property + def is_port(self) -> bool: + """Whether this is a description of a port member. 
+ + Returns + ------- + :class:`bool` + :pc:`True` if this is a description of a port member, + :pc:`False` if this is a description of a signature member. + """ + ... + + @property + def is_signature(self) -> bool: + """Whether this is a description of a signature member. + + Returns + ------- + :class:`bool` + :pc:`True` if this is a description of a signature member, + :pc:`False` if this is a description of a port member. + """ + ... + + @property + def shape(self) -> ShapeLike: + """Shape of a port member. + + Returns + ------- + :ref:`shape-like object ` + The shape that was provided when constructing this :class:`Member`. + + Raises + ------ + :exc:`AttributeError` + If :pc:`self` describes a signature member. + """ + ... + + @property + def reset(self): # -> None: + """Reset value of a port member. + + Returns + ------- + :ref:`const-castable object ` + The reset value that was provided when constructing this :class:`Member`. + + Raises + ------ + :exc:`AttributeError` + If :pc:`self` describes a signature member. + """ + ... + + @property + def signature(self) -> Signature: + """Signature of a signature member. + + Returns + ------- + :class:`Signature` + The signature that was provided when constructing this :class:`Member`. + + Raises + ------ + :exc:`AttributeError` + If :pc:`self` describes a port member. + """ + ... + + @property + def dimensions(self) -> tuple[int, ...]: + """Array dimensions. + + A member will usually have no dimensions; in this case it does not describe an array. + A single dimension describes one-dimensional array, and so on. + + Returns + ------- + :class:`tuple` of :class:`int` + Dimensions, if any, of this member, from most to least major. + """ + ... + + def __eq__(self, other) -> bool: + ... + + def __repr__(self) -> str: + ... 
+ + + +@final +class SignatureError(Exception): + """ + This exception is raised when an invalid operation specific to signature manipulation is + performed with :class:`SignatureMembers`, such as adding a member to a frozen signature. + Other exceptions, such as :exc:`TypeError` or :exc:`NameError`, will still be raised where + appropriate. + """ + ... + + +@final +class SignatureMembers(Mapping[str, Member]): + """Mapping of signature member names to their descriptions. + + This container, a :class:`collections.abc.Mapping`, is used to implement the :pc:`members` + attribute of signature objects. + + The keys in this container must be valid Python attribute names that are public (do not begin + with an underscore. The values must be instances of :class:`Member`. The container is mutable + in a restricted manner: new keys may be added, but existing keys may not be modified or removed. + In addition, the container can be `frozen`, which disallows addition of new keys. Freezing + a container recursively freezes the members of any signatures inside. + + In addition to the use of the superscript operator, multiple members can be added at once with + the :pc:`+=` opreator. + + The :meth:`create` method converts this mapping into a mapping of names to signature members + (signals and interface objects) by creating them from their descriptions. The created mapping + can be used to populate an interface object. + """ + def __init__(self, members: Mapping[str, Member]=...) -> None: + ... + + def flip(self) -> FlippedSignatureMembers[Self]: + """Flip the data flow of the members in this mapping. + + Returns + ------- + :class:`FlippedSignatureMembers` + Proxy collection :pc:`FlippedSignatureMembers(self)` that flips the data flow of + the members that are accessed using it. + """ + ... + + def __eq__(self, other) -> bool: + """Compare the members in this and another mapping. 
+ + Returns + ------- + :class:`bool` + :pc:`True` if the mappings contain the same key-value pairs, :pc:`False` otherwise. + """ + ... + + def __contains__(self, name: str) -> bool: + """Check whether a member with a given name exists. + + Returns + ------- + :class:`bool` + """ + ... + + def __getitem__(self, name: str) -> Member: + """Retrieves the description of a member with a given name. + + Returns + ------- + :class:`Member` + + Raises + ------ + :exc:`TypeError` + If :pc:`name` is not a string. + :exc:`NameError` + If :pc:`name` is not a valid, public Python attribute name. + :exc:`SignatureError` + If a member called :pc:`name` does not exist in the collection. + """ + ... + + def __setitem__(self, name: str, member: Member) -> NoReturn: + """Stub that forbids addition of members to the collection. + + Raises + ------ + :exc:`SignatureError` + Always. + """ + ... + + def __delitem__(self, name: str) -> NoReturn: + """Stub that forbids removal of members from the collection. + + Raises + ------ + :exc:`SignatureError` + Always. + """ + ... + + def __iter__(self) -> Iterator[str]: + """Iterate through the names of members in the collection. + + Returns + ------- + iterator of :class:`str` + Names of members, in the order of insertion. + """ + ... + + def __len__(self) -> int: + ... + + def flatten(self, *, path: tuple[str | int, ...]=...) -> Iterator[tuple[tuple[str | int, ...], Member]]: + """Recursively iterate through this collection. + + .. note:: + + The :ref:`paths ` returned by this method and by :meth:`Signature.flatten` + differ. This method yields a single result for each :class:`Member` in the collection, + disregarding their dimensions: + + .. doctest:: + + >>> sig = wiring.Signature({ + ... "items": In(1).array(2) + ... }) + >>> list(sig.members.flatten()) + [(('items',), In(1).array(2))] + + The :meth:`Signature.flatten` method yields multiple results for such a member; see + the documentation for that method for an example. 
+ + Returns + ------- + iterator of (:class:`tuple` of :class:`str`, :class:`Member`) + Pairs of :ref:`paths ` and the corresponding members. A path yielded by + this method is a tuple of strings where each item is a key through which the item may + be reached. + """ + ... + + def create(self, *, path: tuple[str | int, ...] =..., src_loc_at: int =...) -> dict[str, Any]: + """Create members from their descriptions. + + For each port member, this function creates a :class:`Signal` with the shape and reset + value taken from the member description, and the name constructed from + the :ref:`paths ` to the member (by concatenating path items with a double + underscore, ``__``). + + For each signature member, this function calls :meth:`Signature.create` for that signature. + The resulting object can have any type if a :class:`Signature` subclass overrides + the :class:`create` method. + + If the member description includes dimensions, in each case, instead of a single member, + a :class:`list` of members is created for each dimension. (That is, for a single dimension + a list of members is returned, for two dimensions a list of lists is returned, and so on.) + + Returns + ------- + dict of :class:`str` to :ref:`value-like ` or interface object or a potentially nested list of these + Mapping of names to actual signature members. + """ + ... + + def __repr__(self) -> str: + ... + + + +@final +class FlippedSignatureMembers(Mapping[str, Member], Generic[_T_SignatureMembers]): + """Mapping of signature member names to their descriptions, with the directions flipped. + + Although an instance of :class:`FlippedSignatureMembers` could be created directly, it will + be usually created by a call to :meth:`SignatureMembers.flip`. + + This container is a wrapper around :class:`SignatureMembers` that contains the same members + as the inner mapping, but flips their data flow when they are accessed. For example: + + .. 
testcode:: + + members = wiring.SignatureMembers({"foo": Out(1)}) + + flipped_members = members.flip() + assert flipped_members["foo"].flow == In + + This class implements the same methods, with the same functionality (other than the flipping of + the data flow), as the :class:`SignatureMembers` class; see the documentation for that class + for details. + """ + def __init__(self, unflipped: _T_SignatureMembers) -> None: + ... + + def flip(self) -> _T_SignatureMembers: + """ + Flips this mapping back to the original one. + + Returns + ------- + :class:`SignatureMembers` + :pc:`unflipped` + """ + ... + + def __eq__(self, other) -> bool: + """Compare the members in this and another mapping. + + Returns + ------- + :class:`bool` + :pc:`True` if the mappings contain the same key-value pairs, :pc:`False` otherwise. + """ + ... + + def __contains__(self, name: str) -> bool: + ... + + def __getitem__(self, name: str) -> Member: + ... + + def __setitem__(self, name: str, member: Member) -> NoReturn: + ... + + def __delitem__(self, name: str) -> NoReturn: + ... + + def __iter__(self) -> Iterator[str]: + ... + + def __len__(self) -> int: + ... + + def flatten(self, *, path: tuple[str | int, ...] = ...) -> Iterator[tuple[tuple[str | int, ...], Member]]: + """Recursively iterate through this collection. + + .. note:: + + The :ref:`paths ` returned by this method and by :meth:`Signature.flatten` + differ. This method yields a single result for each :class:`Member` in the collection, + disregarding their dimensions: + + .. doctest:: + + >>> sig = wiring.Signature({ + ... "items": In(1).array(2) + ... }) + >>> list(sig.members.flatten()) + [(('items',), In(1).array(2))] + + The :meth:`Signature.flatten` method yields multiple results for such a member; see + the documentation for that method for an example. + + Returns + ------- + iterator of (:class:`tuple` of :class:`str`, :class:`Member`) + Pairs of :ref:`paths ` and the corresponding members. 
A path yielded by + this method is a tuple of strings where each item is a key through which the item may + be reached. + """ + ... + + def create(self, *, path: tuple[str | int, ...] =..., src_loc_at: int =...) -> dict[str, Any]: + """Create members from their descriptions. + + For each port member, this function creates a :class:`Signal` with the shape and reset + value taken from the member description, and the name constructed from + the :ref:`paths ` to the member (by concatenating path items with a double + underscore, ``__``). + + For each signature member, this function calls :meth:`Signature.create` for that signature. + The resulting object can have any type if a :class:`Signature` subclass overrides + the :class:`create` method. + + If the member description includes dimensions, in each case, instead of a single member, + a :class:`list` of members is created for each dimension. (That is, for a single dimension + a list of members is returned, for two dimensions a list of lists is returned, and so on.) + + Returns + ------- + dict of :class:`str` to :ref:`value-like ` or interface object or a potentially nested list of these + Mapping of names to actual signature members. + """ + ... + + def __repr__(self) -> str: + ... + + + +class SignatureMeta(type): + """Metaclass for :class:`Signature` that makes :class:`FlippedSignature` its + 'virtual subclass'. + + The object returned by :meth:`Signature.flip` is an instance of :class:`FlippedSignature`. + It implements all of the methods :class:`Signature` has, and for subclasses of + :class:`Signature`, it implements all of the methods defined on the subclass as well. + This makes it effectively a subtype of :class:`Signature` (or a derived class of it), but this + relationship is not captured by the Python type system: :class:`FlippedSignature` only has + :class:`object` as its base class. 
+ + This metaclass extends :func:`issubclass` and :func:`isinstance` so that they take into + account the subtyping relationship between :class:`Signature` and :class:`FlippedSignature`, + described below. + """ + def __subclasscheck__(cls, subclass) -> bool: + """ + Override of :pc:`issubclass(cls, Signature)`. + + In addition to the standard behavior of :func:`issubclass`, this override makes + :class:`FlippedSignature` a subclass of :class:`Signature` or any of its subclasses. + """ + ... + + def __instancecheck__(cls, instance) -> bool: + """ + Override of :pc:`isinstance(obj, Signature)`. + + In addition to the standard behavior of :func:`isinstance`, this override makes + :pc:`isinstance(obj, cls)` act as :pc:`isinstance(obj.flip(), cls)` where + :pc:`obj` is an instance of :class:`FlippedSignature`. + """ + ... + + + +class Signature(metaclass=SignatureMeta): + """Description of an interface object. + + An interface object is a Python object that has a :pc:`signature` attribute containing + a :class:`Signature` object, as well as an attribute for every member of its signature. + Signatures and interface objects are tightly linked: an interface object can be created out + of a signature, and the signature is used when :func:`connect` ing two interface objects + together. See the :ref:`introduction to interfaces ` for a more detailed + explanation of why this is useful. + + :class:`Signature` can be used as a base class to define :ref:`customized ` + signatures and interface objects. + + .. important:: + + :class:`Signature` objects are immutable. Classes inheriting from :class:`Signature` must + ensure this remains the case when additional functionality is added. + """ + def __init__(self, members: Mapping[str, Member]) -> None: + ... + + def flip(self) -> FlippedSignature[Self]: + """Flip the data flow of the members in this signature. 
+ + Returns + ------- + :class:`FlippedSignature` + Proxy object :pc:`FlippedSignature(self)` that flips the data flow of the attributes + corresponding to the members that are accessed using it. + + See the documentation for the :class:`FlippedSignature` class for a detailed discussion + of how this proxy object works. + """ + ... + + @property + def members(self) -> SignatureMembers: + """Members in this signature. + + Returns + ------- + :class:`SignatureMembers` + """ + ... + + def __eq__(self, other) -> bool: + """Compare this signature with another. + + The behavior of this operator depends on the types of the arguments. If both :pc:`self` + and :pc:`other` are instances of the base :class:`Signature` class, they are compared + structurally (the result is :pc:`self.members == other.members`); otherwise they are + compared by identity (the result is :pc:`self is other`). + + Subclasses of :class:`Signature` are expected to override this method to take into account + the specifics of the domain. If the subclass has additional properties that do not influence + the :attr:`members` dictionary but nevertheless make its instance incompatible with other + instances (for example, whether the feedback is combinational or registered), + the overridden method must take that into account. + + Returns + ------- + :class:`bool` + """ + ... + + def flatten(self, obj) -> Iterator[tuple[tuple[str | int, ...], Flow, ValueLike]]: + """Recursively iterate through this signature, retrieving member values from an interface + object. + + .. note:: + + The :ref:`paths ` returned by this method and by + :meth:`SignatureMembers.flatten` differ. This method yield several results for each + :class:`Member` in the collection that has a dimension: + + .. doctest:: + :options: +NORMALIZE_WHITESPACE + + >>> sig = wiring.Signature({ + ... "items": In(1).array(2) + ... 
}) + >>> obj = sig.create() + >>> list(sig.flatten(obj)) + [(('items', 0), In(1), (sig obj__items__0)), + (('items', 1), In(1), (sig obj__items__1))] + + The :meth:`SignatureMembers.flatten` method yields one result for such a member; see + the documentation for that method for an example. + + Returns + ------- + iterator of (:class:`tuple` of :class:`str` or :class:`int`, :class:`Flow`, :ref:`value-like `) + Tuples of :ref:`paths `, flow, and the corresponding member values. A path + yielded by this method is a tuple of strings or integers where each item is an attribute + name or index (correspondingly) using which the member value was retrieved. + """ + ... + + def is_compliant(self, obj, *, reasons: Optional[list[str]] =..., path: tuple[str, ...] =...) -> bool: + """Check whether an object matches the description in this signature. + + This module places few restrictions on what an interface object may be; it does not + prescribe a specific base class or a specific way of constructing the object, only + the values that its attributes should have. This method ensures consistency between + the signature and the interface object, checking every aspect of the provided interface + object for compliance with the signature. 
+ + It verifies that: + + * :pc:`obj` has a :pc:`signature` attribute whose value a :class:`Signature` instance + such that ``self == obj.signature``; + * for each member, :pc:`obj` has an attribute with the same name, whose value: + + * for members with :meth:`dimensions ` specified, contains a list or + a tuple (or several levels of nested lists or tuples, for multiple dimensions) + satisfying the requirements below; + * for port members, is a :ref:`value-like ` object casting to + a :class:`Signal` or a :class:`Const` whose width and signedness is the same as that + of the member, and (in case of a :class:`Signal`) which is not reset-less and whose + reset value is that of the member; + * for signature members, matches the description in the signature as verified by + :meth:`Signature.is_compliant`. + + If the verification fails, this method reports the reason(s) by filling the :pc:`reasons` + container. These reasons are intended to be human-readable: more than one reason may be + reported but only in cases where this is helpful (e.g. the same error message will not + repeat 10 times for each of the 10 ports in a list). + + Arguments + --------- + reasons : :class:`list` or :pc:`None` + If provided, a container that receives diagnostic messages. + path : :class:`tuple` of :class:`str` + The :ref:`path ` to :pc:`obj`. Could be set to improve diagnostic + messages if :pc:`obj` is nested within another object, or for clarity. + + Returns + ------- + :class:`bool` + :pc:`True` if :pc:`obj` matches the description in this signature, :pc:`False` + otherwise. If :pc:`False` and :pc:`reasons` was not :pc:`None`, it will contain + a detailed explanation why. + """ + ... + + def create(self, *, path: tuple[str | int, ...]=..., src_loc_at: int =...) -> AbstractInterface[Self]: + """Create an interface object from this signature. + + The default :meth:`Signature.create` implementation consists of one line: + + .. 
code:: + + def create(self, *, path=None, src_loc_at=0): + return PureInterface(self, path=path, src_loc_at=1 + src_loc_at) + + This implementation creates an interface object from this signature that serves purely + as a container for the attributes corresponding to the signature members, and implements + no behavior. Such an implementation is sufficient for signatures created ad-hoc using + the :pc:`Signature({ ... })` constructor as well as simple signature subclasses. + + When defining a :class:`Signature` subclass that needs to customize the behavior of + the created interface objects, override this method with a similar implementation + that references the class of your custom interface object: + + .. testcode:: + + class CustomSignature(wiring.Signature): + def create(self, *, path=None, src_loc_at=0): + return CustomInterface(self, path=path, src_loc_at=1 + src_loc_at) + + class CustomInterface(wiring.PureInterface): + @property + def my_property(self): + ... + + The :pc:`path` and :pc:`src_loc_at` arguments are necessary to ensure the generated signals + have informative names and accurate source location information. + + The custom :meth:`create` method may take positional or keyword arguments in addition to + the two listed above. Such arguments must have a default value, because + the :meth:`SignatureMembers.create` method will call the :meth:`Signature.create` member + without these additional arguments when this signature is a member of another signature. + """ + ... + + def __repr__(self) -> str: + ... + + + +@final +class FlippedSignature(Generic[_T_Signature]): + """Description of an interface object, with the members' directions flipped. + + Although an instance of :class:`FlippedSignature` could be created directly, it will be usually + created by a call to :meth:`Signature.flip`. 
+ + This proxy is a wrapper around :class:`Signature` that contains the same description as + the inner mapping, but flips the members' data flow when they are accessed. It is useful + because :class:`Signature` objects are mutable and may include custom behavior, and if one was + copied (rather than wrapped) by :meth:`Signature.flip`, the wrong object would be mutated, and + custom behavior would be unavailable. + + For example: + + .. testcode:: + + sig = wiring.Signature({"foo": Out(1)}) + + flipped_sig = sig.flip() + assert flipped_sig.members["foo"].flow == In + + sig.attr = 1 + assert flipped_sig.attr == 1 + flipped_sig.attr += 1 + assert sig.attr == flipped_sig.attr == 2 + + This class implements the same methods, with the same functionality (other than the flipping of + the members' data flow), as the :class:`Signature` class; see the documentation for that class + for details. + + It is not possible to inherit from :class:`FlippedSignature` and :meth:`Signature.flip` must not + be overridden. If a :class:`Signature` subclass defines a method and this method is called on + a flipped instance of the subclass, it receives the flipped instance as its :pc:`self` argument. + To distinguish being called on the flipped instance from being called on the unflipped one, use + :pc:`isinstance(self, FlippedSignature)`: + + .. testcode:: + + class SignatureKnowsWhenFlipped(wiring.Signature): + @property + def is_flipped(self): + return isinstance(self, wiring.FlippedSignature) + + sig = SignatureKnowsWhenFlipped({}) + assert sig.is_flipped == False + assert sig.flip().is_flipped == True + """ + def __init__(self, signature: _T_Signature) -> None: + ... + + def flip(self) -> _T_Signature: + """ + Flips this signature back to the original one. + + Returns + ------- + :class:`Signature` + :pc:`unflipped` + """ + ... + + @property + def members(self) -> FlippedSignatureMembers: + ... + + def __eq__(self, other) -> bool: + ... 
+ + def flatten(self, obj) -> Iterator[tuple[tuple[str | int, ...], Flow, ValueLike]]: + ... + + def is_compliant(self, obj, *, reasons: Optional[list[str]] =..., path: tuple[str, ...] =...) -> bool: + ... + + def __getattr__(self, name) -> Any: + """Retrieves attribute or method :pc:`name` of the unflipped signature. + + Performs :pc:`getattr(unflipped, name)`, ensuring that, if :pc:`name` refers to a property + getter or a method, its :pc:`self` argument receives the *flipped* signature. A class + method's :pc:`cls` argument receives the class of the *unflipped* signature, as usual. + """ + ... + + def __setattr__(self, name, value) -> None: + """Assigns attribute :pc:`name` of the unflipped signature to ``value``. + + Performs :pc:`setattr(unflipped, name, value)`, ensuring that, if :pc:`name` refers to + a property setter, its :pc:`self` argument receives the flipped signature. + """ + ... + + def __delattr__(self, name) -> None: + """Removes attribute :pc:`name` of the unflipped signature. + + Performs :pc:`delattr(unflipped, name)`, ensuring that, if :pc:`name` refers to a property + deleter, its :pc:`self` argument receives the flipped signature. + """ + ... + + def create(self, *args, path: tuple[str | int, ...] =..., src_loc_at: int =..., **kwargs) -> FlippedInterface: + ... + + def __repr__(self) -> str: + ... + + + +class PureInterface(Generic[_T_Signature]): + """A helper for constructing ad-hoc interfaces. + + The :class:`PureInterface` helper primarily exists to be used by the default implementation of + :meth:`Signature.create`, but it can also be used in any other context where an interface + object needs to be created without the overhead of defining a class for it. + + .. important:: + + Any object can be an interface object; it only needs a :pc:`signature` property containing + a compliant signature. 
It is **not** necessary to use :class:`PureInterface` in order to + create an interface object, but it may be used either directly or as a base class whenever + it is convenient to do so. + """ + signature: _T_Signature + + def __init__(self, signature: _T_Signature, *, path: tuple[str | int, ...]=..., src_loc_at: int =...) -> None: + """Create attributes from a signature. + + The sole method defined by this helper is its constructor, which only defines + the :pc:`self.signature` attribute as well as the attributes created from the signature + members: + + .. code:: + + def __init__(self, signature, *, path): + self.__dict__.update({ + "signature": signature, + **signature.members.create(path=path) + }) + + .. note:: + + This implementation can be copied and reused in interface objects that *do* include + custom behavior, if the signature serves as the source of truth for attributes + corresponding to its members. Although it is less repetitive, this approach can confuse + IDEs and type checkers. + """ + ... + + def __repr__(self) -> str: + ... + + + +@final +class FlippedInterface(Generic[_T_Signature, _T_Interface]): + """An interface object, with its members' directions flipped. + + An instance of :class:`FlippedInterface` should only be created by calling :func:`flipped`, + which ensures that a :pc:`FlippedInterface(FlippedInterface(...))` object is never created. + + This proxy wraps any interface object and forwards attribute and method access to the wrapped + interface object while flipping its signature and the values of any attributes corresponding to + interface members. It is useful because interface objects may be mutable or include custom + behavior, and explicitly keeping track of whether the interface object is flipped would be very + burdensome. + + For example: + + .. 
testcode:: + + intf = wiring.PureInterface(wiring.Signature({"foo": Out(1)}), path=()) + + flipped_intf = wiring.flipped(intf) + assert flipped_intf.signature.members["foo"].flow == In + + intf.attr = 1 + assert flipped_intf.attr == 1 + flipped_intf.attr += 1 + assert intf.attr == flipped_intf.attr == 2 + + It is not possible to inherit from :class:`FlippedInterface`. If an interface object class + defines a method or a property and it is called on the flipped interface object, the method + receives the flipped interface object as its :pc:`self` argument. To distinguish being called + on the flipped interface object from being called on the unflipped one, use + :pc:`isinstance(self, FlippedInterface)`: + + .. testcode:: + + class InterfaceKnowsWhenFlipped: + signature = wiring.Signature({}) + + @property + def is_flipped(self): + return isinstance(self, wiring.FlippedInterface) + + intf = InterfaceKnowsWhenFlipped() + assert intf.is_flipped == False + assert wiring.flipped(intf).is_flipped == True + """ + def __init__(self, interface: _T_Interface) -> None: + ... + + # not true -- this is a property -- but required for clean typing + signature: _T_Signature +# @property +# def signature(self) -> _T_Signature: +# """Signature of the flipped interface. +# +# Returns +# ------- +# Signature +# :pc:`unflipped.signature.flip()` +# """ +# ... + + def __eq__(self, other) -> bool: + """Compare this flipped interface with another. + + Returns + ------- + bool + :pc:`True` if :pc:`other` is an instance :pc:`FlippedInterface(other_unflipped)` where + :pc:`unflipped == other_unflipped`, :pc:`False` otherwise. + """ + ... + + def __getattr__(self, name) -> Any: + """Retrieves attribute or method :pc:`name` of the unflipped interface. + + Performs :pc:`getattr(unflipped, name)`, with the following caveats: + + 1. If :pc:`name` refers to a signature member, the returned interface object is flipped. + 2. 
If :pc:`name` refers to a property getter or a method, its :pc:`self` argument receives + the *flipped* interface. A class method's :pc:`cls` argument receives the class of + the *unflipped* interface, as usual. + """ + ... + + def __setattr__(self, name, value) -> None: + """Assigns attribute :pc:`name` of the unflipped interface to ``value``. + + Performs :pc:`setattr(unflipped, name, value)`, with the following caveats: + + 1. If :pc:`name` refers to a signature member, the assigned interface object is flipped. + 2. If :pc:`name` refers to a property setter, its :pc:`self` argument receives the flipped + interface. + """ + ... + + def __delattr__(self, name) -> None: + """Removes attribute :pc:`name` of the unflipped interface. + + Performs :pc:`delattr(unflipped, name)`, ensuring that, if :pc:`name` refers to a property + deleter, its :pc:`self` argument receives the flipped interface. + """ + ... + + def __repr__(self) -> str: + ... + +@overload +def flipped(interface: FlippedInterface[_T_Signature, _T_Interface]) -> _T_Interface: + ... + +# Can't be typed nicer for now. +@overload +def flipped(interface: _T_Interface) -> FlippedInterface[Any, _T_Interface]: + ... + +def flipped(interface: _T_Interface) -> _T_Interface | FlippedInterface[Any, _T_Interface]: + """ + Flip the data flow of the members of the interface object :pc:`interface`. + + If an interface object is flipped twice, returns the original object: + :pc:`flipped(flipped(interface)) is interface`. Otherwise, wraps :pc:`interface` in + a :class:`FlippedInterface` proxy object that flips the directions of its members. + + See the documentation for the :class:`FlippedInterface` class for a detailed discussion of how + this proxy object works. + """ + ... + +@final +class ConnectionError(Exception): + """Exception raised when the :func:`connect` function is requested to perform an impossible, + meaningless, or forbidden connection.""" + ... 
+ + +def connect(m: ModuleLike, *args: AbstractInterface, **kwargs: AbstractInterface) -> None: + """Connect interface objects to each other. + + This function creates connections between ports of several interface objects. (Any number of + interface objects may be provided; in most cases it is two.) + + The connections can be made only if all of the objects satisfy a number of requirements: + + * Every interface object must have the same set of port members, and they must have the same + :meth:`dimensions `. + * For each path, the port members of every interface object must have the same width and reset + value (for port members corresponding to signals) or constant value (for port members + corresponding to constants). Signedness may differ. + * For each path, at most one interface object must have the corresponding port member be + an output. + * For a given path, if any of the interface objects has an input port member corresponding + to a constant value, then the rest of the interface objects must have output port members + corresponding to the same constant value. + + For example, if :pc:`obj1` is being connected to :pc:`obj2` and :pc:`obj3`, and :pc:`obj1.a.b` + is an output, then :pc:`obj2.a.b` and :pc:`obj2.a.b` must exist and be inputs. If :pc:`obj2.c` + is an input and its value is :pc:`Const(1)`, then :pc:`obj1.c` and :pc:`obj3.c` must be outputs + whose value is also :pc:`Const(1)`. If no ports besides :pc:`obj1.a.b` and :pc:`obj1.c` exist, + then no ports except for those two must exist on :pc:`obj2` and :pc:`obj3` either. + + Once it is determined that the interface objects can be connected, this function performs + an equivalent of: + + .. code:: + + m.d.comb += [ + in1.eq(out1), + in2.eq(out1), + ... + ] + + Where :pc:`out1` is an output and :pc:`in1`, :pc:`in2`, ... are the inputs that have the same + path. (If no interface object has an output for a given path, **no connection at all** is made.) 
+ + The positions (within :pc:`args`) or names (within :pc:`kwargs`) of the arguments do not affect + the connections that are made. There is no difference in behavior between :pc:`connect(m, a, b)` + and :pc:`connect(m, b, a)` or :pc:`connect(m, arbiter=a, decoder=b)`. The names of the keyword + arguments serve only a documentation purpose: they clarify the diagnostic messages when + a connection cannot be made. + """ + ... + +class Component(Elaboratable, Generic[_T_Signature]): + """Base class for elaboratable interface objects. + + A component is an :class:`Elaboratable` whose interaction with other parts of the design is + defined by its signature. Most if not all elaboratables in idiomatic Amaranth code should be + components, as the signature clarifies the direction of data flow at their boundary. See + the :ref:`introduction to interfaces ` section for a practical guide to defining + and using components. + + There are two ways to define a component. If all instances of a component have the same + signature, it can be defined using :term:`variable annotations `: + + .. testcode:: + + class FixedComponent(wiring.Component): + en: In(1) + data: Out(8) + + The variable annotations are collected by the constructor :meth:`Component.__init__`. Only + public (not starting with ``_``) annotations with :class:`In ` or :class:`Out ` + objects are considered; all other annotations are ignored under the assumption that they are + interpreted by some other tool. + + It is possible to use inheritance to extend a component: the component's signature is composed + from the variable annotations in the class that is being constructed as well as all of its + base classes. It is an error to have more than one variable annotation for the same attribute. + + If different instances of a component may need to have different signatures, variable + annotations cannot be used. 
In this case, the constructor should be overridden, and + the computed signature members should be provided to the superclass constructor: + + .. testcode:: + + class ParametricComponent(wiring.Component): + def __init__(self, data_width): + super().__init__({ + "en": In(1), + "data": Out(data_width) + }) + + It is also possible to pass a :class:`Signature` instance to the superclass constructor. + + Aside from initializing the :attr:`signature` attribute, the :meth:`Component.__init__` + constructor creates attributes corresponding to all of the members defined in the signature. + If an attribute with the same name as that of a member already exists, an error is raied. + + Raises + ------ + :exc:`TypeError` + If the :pc:`signature` object is neither a :class:`Signature` nor a :class:`dict`. + If neither variable annotations nor the :pc:`signature` argument are present, or if + both are present. + :exc:`NameError` + If a name conflict is detected between two variable annotations, or between a member + and an existing attribute. + """ + def __init__(self, signature: Optional[_T_Signature | dict[str, Member]] = None) -> None: + ... + + @property + def signature(self) -> _T_Signature: + """The signature of the component. + + .. important:: + + Do not override this property. Once a component is constructed, its :attr:`signature` + property must always return the same :class:`Signature` instance. The constructor + can be used to customize a component's signature. + """ + ... + + + diff --git a/stubs/amaranth/sim/core.pyi b/stubs/amaranth/sim/core.pyi index 23fef3472..cee8e0f2e 100644 --- a/stubs/amaranth/sim/core.pyi +++ b/stubs/amaranth/sim/core.pyi @@ -3,8 +3,8 @@ This type stub file was generated by pyright. 
""" from .._utils import deprecated -from ..hdl.cd import * -from ..hdl.ir import * +from ..hdl._cd import * +from ..hdl._ir import * __all__ = ["Settle", "Delay", "Tick", "Passive", "Active", "Simulator"] class Command: diff --git a/stubs/amaranth/utils.pyi b/stubs/amaranth/utils.pyi index 0da04bf79..6ca424a08 100644 --- a/stubs/amaranth/utils.pyi +++ b/stubs/amaranth/utils.pyi @@ -2,10 +2,22 @@ This type stub file was generated by pyright. """ -__all__ = ["log2_int", "bits_for"] -def log2_int(n:int, need_pow2:bool=...) -> int: - ... +__all__ = ["ceil_log2", "exact_log2", "bits_for"] def bits_for(n:int, require_sign_bit:bool=...) -> int: ... + +def ceil_log2(n : int) -> int: + """Returns the integer log2 of the smallest power-of-2 greater than or equal to ``n``. + + Raises a ``ValueError`` for negative inputs. + """ + ... + +def exact_log2(n : int) -> int: + """Returns the integer log2 of ``n``, which must be an exact power of two. + + Raises a ``ValueError`` if ``n`` is not a power of two. + """ + ... diff --git a/stubs/amaranth/vendor/__init__.pyi b/stubs/amaranth/vendor/__init__.pyi new file mode 100644 index 000000000..d9c2463ca --- /dev/null +++ b/stubs/amaranth/vendor/__init__.pyi @@ -0,0 +1,8 @@ +""" +This type stub file was generated by pyright. +""" + +from ._lattice_ecp5 import LatticeECP5Platform +from ._lattice_ice40 import LatticeICE40Platform + +__all__ = ["LatticeECP5Platform", "LatticeICE40Platform"] diff --git a/stubs/amaranth/vendor/_lattice_ecp5.pyi b/stubs/amaranth/vendor/_lattice_ecp5.pyi new file mode 100644 index 000000000..b043637c0 --- /dev/null +++ b/stubs/amaranth/vendor/_lattice_ecp5.pyi @@ -0,0 +1,129 @@ +""" +This type stub file was generated by pyright. +""" + +from ..hdl import * +from ..build import * + +class LatticeECP5Platform(TemplatedPlatform): + """ + .. 
rubric:: Trellis toolchain + + Required tools: + * ``yosys`` + * ``nextpnr-ecp5`` + * ``ecppack`` + + The environment is populated by running the script specified in the environment variable + ``AMARANTH_ENV_TRELLIS``, if present. + + Available overrides: + * ``verbose``: enables logging of informational messages to standard error. + * ``read_verilog_opts``: adds options for ``read_verilog`` Yosys command. + * ``synth_opts``: adds options for ``synth_ecp5`` Yosys command. + * ``script_after_read``: inserts commands after ``read_ilang`` in Yosys script. + * ``script_after_synth``: inserts commands after ``synth_ecp5`` in Yosys script. + * ``yosys_opts``: adds extra options for ``yosys``. + * ``nextpnr_opts``: adds extra options for ``nextpnr-ecp5``. + * ``ecppack_opts``: adds extra options for ``ecppack``. + * ``add_preferences``: inserts commands at the end of the LPF file. + + Build products: + * ``{{name}}.rpt``: Yosys log. + * ``{{name}}.json``: synthesized RTL. + * ``{{name}}.tim``: nextpnr log. + * ``{{name}}.config``: ASCII bitstream. + * ``{{name}}.bit``: binary bitstream. + * ``{{name}}.svf``: JTAG programming vector. + + .. rubric:: Diamond toolchain + + Required tools: + * ``pnmainc`` + * ``ddtcmd`` + + The environment is populated by running the script specified in the environment variable + ``AMARANTH_ENV_DIAMOND``, if present. On Linux, diamond_env as provided by Diamond + itself is a good candidate. On Windows, the following script (named ``diamond_env.bat``, + for instance) is known to work:: + + @echo off + set PATH=C:\\lscc\\diamond\\%DIAMOND_VERSION%\\bin\\nt64;%PATH% + + Available overrides: + * ``script_project``: inserts commands before ``prj_project save`` in Tcl script. + * ``script_after_export``: inserts commands after ``prj_run Export`` in Tcl script. + * ``add_preferences``: inserts commands at the end of the LPF file. + * ``add_constraints``: inserts commands at the end of the XDC file. 
+ + Build products: + * ``{{name}}_impl/{{name}}_impl.htm``: consolidated log. + * ``{{name}}.bit``: binary bitstream. + * ``{{name}}.svf``: JTAG programming vector. + """ + toolchain = ... + device = ... + package = ... + speed = ... + grade = ... + _nextpnr_device_options = ... + _nextpnr_package_options = ... + _trellis_required_tools = ... + _trellis_file_templates = ... + _trellis_command_templates = ... + _diamond_required_tools = ... + _diamond_file_templates = ... + _diamond_command_templates = ... + def __init__(self, *, toolchain=...) -> None: + ... + + @property + def required_tools(self): # -> list[str]: + ... + + @property + def file_templates(self): # -> dict[str, str]: + ... + + @property + def command_templates(self): # -> list[str]: + ... + + @property + def default_clk_constraint(self): # -> Clock: + ... + + def create_missing_domain(self, name): # -> Module | None: + ... + + _single_ended_io_types = ... + _differential_io_types = ... + def should_skip_port_component(self, port, attrs, component): # -> bool: + ... + + def get_input(self, pin, port, attrs, invert): # -> Module: + ... + + def get_output(self, pin, port, attrs, invert): # -> Module: + ... + + def get_tristate(self, pin, port, attrs, invert): # -> Module: + ... + + def get_input_output(self, pin, port, attrs, invert): # -> Module: + ... + + def get_diff_input(self, pin, port, attrs, invert): # -> Module: + ... + + def get_diff_output(self, pin, port, attrs, invert): # -> Module: + ... + + def get_diff_tristate(self, pin, port, attrs, invert): # -> Module: + ... + + def get_diff_input_output(self, pin, port, attrs, invert): # -> Module: + ... + + + diff --git a/stubs/amaranth/vendor/_lattice_ice40.pyi b/stubs/amaranth/vendor/_lattice_ice40.pyi new file mode 100644 index 000000000..2f72ba554 --- /dev/null +++ b/stubs/amaranth/vendor/_lattice_ice40.pyi @@ -0,0 +1,125 @@ +""" +This type stub file was generated by pyright. 
+""" + +from ..hdl import * +from ..build import * + +class LatticeICE40Platform(TemplatedPlatform): + """ + .. rubric:: IceStorm toolchain + + Required tools: + * ``yosys`` + * ``nextpnr-ice40`` + * ``icepack`` + + The environment is populated by running the script specified in the environment variable + ``AMARANTH_ENV_ICESTORM``, if present. + + Available overrides: + * ``verbose``: enables logging of informational messages to standard error. + * ``read_verilog_opts``: adds options for ``read_verilog`` Yosys command. + * ``synth_opts``: adds options for ``synth_ice40`` Yosys command. + * ``script_after_read``: inserts commands after ``read_ilang`` in Yosys script. + * ``script_after_synth``: inserts commands after ``synth_ice40`` in Yosys script. + * ``yosys_opts``: adds extra options for ``yosys``. + * ``nextpnr_opts``: adds extra options for ``nextpnr-ice40``. + * ``add_pre_pack``: inserts commands at the end in pre-pack Python script. + * ``add_constraints``: inserts commands at the end in the PCF file. + + Build products: + * ``{{name}}.rpt``: Yosys log. + * ``{{name}}.json``: synthesized RTL. + * ``{{name}}.tim``: nextpnr log. + * ``{{name}}.asc``: ASCII bitstream. + * ``{{name}}.bin``: binary bitstream. + + .. rubric:: iCECube2 toolchain + + This toolchain comes in two variants: ``LSE-iCECube2`` and ``Synplify-iCECube2``. + + Required tools: + * iCECube2 toolchain + * ``tclsh`` + + The environment is populated by setting the necessary environment variables based on + ``AMARANTH_ENV_ICECUBE2``, which must point to the root of the iCECube2 installation, and + is required. + + Available overrides: + * ``verbose``: enables logging of informational messages to standard error. + * ``lse_opts``: adds options for LSE. + * ``script_after_add``: inserts commands after ``add_file`` in Synplify Tcl script. + * ``script_after_options``: inserts commands after ``set_option`` in Synplify Tcl script. + * ``add_constraints``: inserts commands in SDC file. 
+ * ``script_after_flow``: inserts commands after ``run_sbt_backend_auto`` in SBT + Tcl script. + + Build products: + * ``{{name}}_lse.log`` (LSE) or ``{{name}}_design/{{name}}.htm`` (Synplify): synthesis log. + * ``sbt/outputs/router/{{name}}_timing.rpt``: timing report. + * ``{{name}}.edf``: EDIF netlist. + * ``{{name}}.bin``: binary bitstream. + """ + toolchain = ... + device = ... + package = ... + _nextpnr_device_options = ... + _nextpnr_package_options = ... + _icestorm_required_tools = ... + _icestorm_file_templates = ... + _icestorm_command_templates = ... + _icecube2_required_tools = ... + _icecube2_file_templates = ... + _lse_icecube2_command_templates = ... + _synplify_icecube2_command_templates = ... + def __init__(self, *, toolchain=...) -> None: + ... + + @property + def family(self): # -> Literal['iCE40', 'iCE5']: + ... + + @property + def required_tools(self): # -> list[str]: + ... + + @property + def file_templates(self): # -> dict[str, str]: + ... + + @property + def command_templates(self): # -> list[str]: + ... + + @property + def default_clk_constraint(self): # -> Clock: + ... + + def create_missing_domain(self, name): # -> Module | None: + ... + + def should_skip_port_component(self, port, attrs, component): # -> bool: + ... + + def get_input(self, pin, port, attrs, invert): # -> Module: + ... + + def get_output(self, pin, port, attrs, invert): # -> Module: + ... + + def get_tristate(self, pin, port, attrs, invert): # -> Module: + ... + + def get_input_output(self, pin, port, attrs, invert): # -> Module: + ... + + def get_diff_input(self, pin, port, attrs, invert): # -> Module: + ... + + def get_diff_output(self, pin, port, attrs, invert): # -> Module: + ... 
+ + + diff --git a/test/asm/csr.asm b/test/asm/csr.asm index 3c44657df..efa674453 100644 --- a/test/asm/csr.asm +++ b/test/asm/csr.asm @@ -3,3 +3,5 @@ rdinstret x1 nop nop rdinstret x2 +end: +j end diff --git a/test/cache/__init__.py b/test/cache/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/test/frontend/test_icache.py b/test/cache/test_icache.py similarity index 94% rename from test/frontend/test_icache.py rename to test/cache/test_icache.py index e52b73ba8..2afeff6db 100644 --- a/test/frontend/test_icache.py +++ b/test/cache/test_icache.py @@ -4,19 +4,21 @@ from amaranth import Elaboratable, Module from amaranth.sim import Passive, Settle -from amaranth.utils import log2_int +from amaranth.utils import exact_log2 from transactron.lib import AdapterTrans, Adapter -from coreblocks.frontend.icache import SimpleWBCacheRefiller, ICache, ICacheBypass, CacheRefillerInterface +from coreblocks.cache.icache import ICache, ICacheBypass, CacheRefillerInterface from coreblocks.params import GenParams, ICacheLayouts from coreblocks.peripherals.wishbone import WishboneMaster, WishboneParameters +from coreblocks.peripherals.bus_adapter import WishboneMasterAdapter from coreblocks.params.configurations import test_core_config +from coreblocks.cache.refiller import SimpleCommonBusCacheRefiller -from ..common import TestCaseWithSimulator, TestbenchIO, def_method_mock, RecordIntDictRet +from transactron.testing import TestCaseWithSimulator, TestbenchIO, def_method_mock, RecordIntDictRet from ..peripherals.test_wishbone import WishboneInterfaceWrapper -class SimpleWBCacheRefillerTestCircuit(Elaboratable): +class SimpleCommonBusCacheRefillerTestCircuit(Elaboratable): def __init__(self, gen_params: GenParams): self.gen_params = gen_params self.cp = self.gen_params.icache_params @@ -29,18 +31,22 @@ def elaborate(self, platform): addr_width=self.gen_params.isa.xlen, ) self.wb_master = WishboneMaster(wb_params) + self.bus_master_adapter = 
WishboneMasterAdapter(self.wb_master) - self.refiller = SimpleWBCacheRefiller(self.gen_params.get(ICacheLayouts), self.cp, self.wb_master) + self.refiller = SimpleCommonBusCacheRefiller( + self.gen_params.get(ICacheLayouts), self.cp, self.bus_master_adapter + ) self.start_refill = TestbenchIO(AdapterTrans(self.refiller.start_refill)) self.accept_refill = TestbenchIO(AdapterTrans(self.refiller.accept_refill)) m.submodules.wb_master = self.wb_master + m.submodules.bus_master_adapter = self.bus_master_adapter m.submodules.refiller = self.refiller m.submodules.start_refill = self.start_refill m.submodules.accept_refill = self.accept_refill - self.wb_ctrl = WishboneInterfaceWrapper(self.wb_master.wbMaster) + self.wb_ctrl = WishboneInterfaceWrapper(self.wb_master.wb_master) return m @@ -54,7 +60,7 @@ def elaborate(self, platform): ("blk_size64B_rv32i", 32, 6), ], ) -class TestSimpleWBCacheRefiller(TestCaseWithSimulator): +class TestSimpleCommonBusCacheRefiller(TestCaseWithSimulator): isa_xlen: int block_size: int @@ -63,7 +69,7 @@ def setUp(self) -> None: test_core_config.replace(xlen=self.isa_xlen, icache_block_size_bits=self.block_size) ) self.cp = self.gen_params.icache_params - self.test_module = SimpleWBCacheRefillerTestCircuit(self.gen_params) + self.test_module = SimpleCommonBusCacheRefillerTestCircuit(self.gen_params) random.seed(42) @@ -92,7 +98,7 @@ def wishbone_slave(self): yield from self.test_module.wb_ctrl.slave_wait() # Wishbone is addressing words, so we need to shift it a bit to get the real address. 
- addr = (yield self.test_module.wb_ctrl.wb.adr) << log2_int(self.cp.word_width_bytes) + addr = (yield self.test_module.wb_ctrl.wb.adr) << exact_log2(self.cp.word_width_bytes) yield while random.random() < 0.5: @@ -150,11 +156,14 @@ def elaborate(self, platform): ) m.submodules.wb_master = self.wb_master = WishboneMaster(wb_params) - m.submodules.bypass = self.bypass = ICacheBypass(self.gen_params.get(ICacheLayouts), self.cp, self.wb_master) + m.submodules.bus_master_adapter = self.bus_master_adapter = WishboneMasterAdapter(self.wb_master) + m.submodules.bypass = self.bypass = ICacheBypass( + self.gen_params.get(ICacheLayouts), self.cp, self.bus_master_adapter + ) m.submodules.issue_req = self.issue_req = TestbenchIO(AdapterTrans(self.bypass.issue_req)) m.submodules.accept_res = self.accept_res = TestbenchIO(AdapterTrans(self.bypass.accept_res)) - self.wb_ctrl = WishboneInterfaceWrapper(self.wb_master.wbMaster) + self.wb_ctrl = WishboneInterfaceWrapper(self.wb_master.wb_master) return m @@ -204,7 +213,7 @@ def wishbone_slave(self): yield from self.m.wb_ctrl.slave_wait() # Wishbone is addressing words, so we need to shift it a bit to get the real address. 
- addr = (yield self.m.wb_ctrl.wb.adr) << log2_int(self.cp.word_width_bytes) + addr = (yield self.m.wb_ctrl.wb.adr) << exact_log2(self.cp.word_width_bytes) while random.random() < 0.5: yield @@ -310,7 +319,7 @@ def init_module(self, ways, sets) -> None: test_core_config.replace( xlen=self.isa_xlen, icache_ways=ways, - icache_sets_bits=log2_int(sets), + icache_sets_bits=exact_log2(sets), icache_block_size_bits=self.block_size, ) ) diff --git a/test/common/functions.py b/test/common/functions.py deleted file mode 100644 index eb7abf886..000000000 --- a/test/common/functions.py +++ /dev/null @@ -1,34 +0,0 @@ -from amaranth import * -from amaranth.hdl.ast import Statement -from amaranth.sim.core import Command -from typing import TypeVar, Any, Generator, TypeAlias, TYPE_CHECKING, Union -from transactron.utils._typing import RecordValueDict, RecordIntDict - - -if TYPE_CHECKING: - from .infrastructure import CoreblocksCommand - - -T = TypeVar("T") -TestGen: TypeAlias = Generator[Union[Command, Value, Statement, "CoreblocksCommand", None], Any, T] - - -def set_inputs(values: RecordValueDict, field: Record) -> TestGen[None]: - for name, value in values.items(): - if isinstance(value, dict): - yield from set_inputs(value, getattr(field, name)) - else: - yield getattr(field, name).eq(value) - - -def get_outputs(field: Record) -> TestGen[RecordIntDict]: - # return dict of all signal values in a record because amaranth's simulator can't read all - # values of a Record in a single yield - it can only read Values (Signals) - result = {} - for name, _, _ in field.layout: - val = getattr(field, name) - if isinstance(val, Signal): - result[name] = yield val - else: # field is a Record - result[name] = yield from get_outputs(val) - return result diff --git a/test/common/profiler.py b/test/common/profiler.py deleted file mode 100644 index 58c236153..000000000 --- a/test/common/profiler.py +++ /dev/null @@ -1,85 +0,0 @@ -import os.path -from amaranth.sim import * -from 
transactron.core import MethodMap, TransactionManager -from transactron.profiler import CycleProfile, Profile, ProfileInfo -from transactron.utils import SrcLoc -from .functions import TestGen - -__all__ = ["profiler_process"] - - -def profiler_process(transaction_manager: TransactionManager, profile: Profile, clk_period: float): - def process() -> TestGen: - method_map = MethodMap(transaction_manager.transactions) - cgr, _, _ = TransactionManager._conflict_graph(method_map) - id_map = dict[int, int]() - id_seq = 0 - - def get_id(obj): - try: - return id_map[id(obj)] - except KeyError: - nonlocal id_seq - id_seq = id_seq + 1 - id_map[id(obj)] = id_seq - return id_seq - - def local_src_loc(src_loc: SrcLoc): - return (os.path.relpath(src_loc[0]), src_loc[1]) - - for transaction in method_map.transactions: - profile.transactions_and_methods[get_id(transaction)] = ProfileInfo( - transaction.owned_name, local_src_loc(transaction.src_loc), True - ) - - for method in method_map.methods: - profile.transactions_and_methods[get_id(method)] = ProfileInfo( - method.owned_name, local_src_loc(method.src_loc), False - ) - - yield Passive() - while True: - yield Delay((1 - 1e-4) * clk_period) # shorter than one clock cycle - - cprof = CycleProfile() - profile.cycles.append(cprof) - - for transaction in method_map.transactions: - request = yield transaction.request - runnable = yield transaction.runnable - grant = yield transaction.grant - - if grant: - cprof.running[get_id(transaction)] = None - elif request and runnable: - for transaction2 in cgr[transaction]: - if (yield transaction2.grant): - cprof.locked[get_id(transaction)] = get_id(transaction2) - - running = set(cprof.running) - for method in method_map.methods: - if (yield method.run): - running.add(get_id(method)) - - locked_methods = set[int]() - for method in method_map.methods: - if get_id(method) not in running: - if any(get_id(transaction) in running for transaction in method_map.transactions_by_method[method]): - 
locked_methods.add(get_id(method)) - - for method in method_map.methods: - if get_id(method) in running: - for t_or_m in method_map.method_parents[method]: - if get_id(t_or_m) in running: - cprof.running[get_id(method)] = get_id(t_or_m) - elif get_id(method) in locked_methods: - caller = next( - get_id(t_or_m) - for t_or_m in method_map.method_parents[method] - if get_id(t_or_m) in running or get_id(t_or_m) in locked_methods - ) - cprof.locked[get_id(method)] = caller - - yield - - return process diff --git a/test/conftest.py b/test/conftest.py new file mode 100644 index 000000000..a291b488d --- /dev/null +++ b/test/conftest.py @@ -0,0 +1,113 @@ +import re +import os +from typing import Optional +import pytest + + +def pytest_addoption(parser: pytest.Parser): + group = parser.getgroup("coreblocks") + group.addoption("--coreblocks-regression", action="store_true", help="Run also regression tests.") + group.addoption( + "--coreblocks-backend", + default="cocotb", + choices=["cocotb", "pysim"], + help="Simulation backend for regression tests", + ) + group.addoption("--coreblocks-traces", action="store_true", help="Generate traces from regression tests") + group.addoption("--coreblocks-profile", action="store_true", help="Write execution profiles") + group.addoption("--coreblocks-list", action="store_true", help="List all tests in flatten format.") + group.addoption( + "--coreblocks-test-name", + action="store", + type=str, + help="Name or regexp in flatten format matching the tests to run.", + ) + group.addoption( + "--coreblocks-test-count", + action="store", + type=int, + help="Number of tests to start. 
If less than number of all selected tests, then starts only subset of them.", + ) + group.addoption("--coreblocks-log-filter", default=".*", action="store", help="Regexp used to filter out logs.") + + +def generate_unittestname(item: pytest.Item) -> str: + full_name = ".".join(map(lambda s: s[:-3] if s[-3:] == ".py" else s, map(lambda x: x.name, item.listchain()))) + return full_name + + +def generate_test_cases_list(session: pytest.Session) -> list[str]: + tests_list = [] + for item in session.items: + full_name = generate_unittestname(item) + tests_list.append(full_name) + return tests_list + + +def pytest_collection_finish(session: pytest.Session): + if session.config.getoption("coreblocks_list"): + full_names = generate_test_cases_list(session) + for i in full_names: + print(i) + + +@pytest.hookimpl(tryfirst=True) +def pytest_runtestloop(session: pytest.Session) -> Optional[bool]: + if session.config.getoption("coreblocks_list"): + return True + return None + + +def deselect_based_on_flatten_name(items: list[pytest.Item], config: pytest.Config) -> None: + coreblocks_test_name = config.getoption("coreblocks_test_name") + if not isinstance(coreblocks_test_name, str): + return + + deselected = [] + remaining = [] + regexp = re.compile(coreblocks_test_name) + for item in items: + full_name = generate_unittestname(item) + match = regexp.search(full_name) + if match is None: + deselected.append(item) + else: + remaining.append(item) + if deselected: + config.hook.pytest_deselected(items=deselected) + items[:] = remaining + + +def deselect_based_on_count(items: list[pytest.Item], config: pytest.Config) -> None: + coreblocks_test_count = config.getoption("coreblocks_test_count") + if not isinstance(coreblocks_test_count, int): + return + + deselected = items[coreblocks_test_count:] + remaining = items[:coreblocks_test_count] + if deselected: + config.hook.pytest_deselected(items=deselected) + items[:] = remaining + + +def pytest_collection_modifyitems(items: 
list[pytest.Item], config: pytest.Config) -> None: + deselect_based_on_flatten_name(items, config) + deselect_based_on_count(items, config) + + +def pytest_runtest_setup(item: pytest.Item): + """ + This function is called to perform the setup phase for every test, so + it is a perfect moment to set environment variables. + """ + if item.config.getoption("--coreblocks-traces", False): # type: ignore + os.environ["__TRANSACTRON_DUMP_TRACES"] = "1" + + if item.config.getoption("--coreblocks-profile", False): # type: ignore + os.environ["__TRANSACTRON_PROFILE"] = "1" + + log_filter = item.config.getoption("--coreblocks-log-filter") + os.environ["__TRANSACTRON_LOG_FILTER"] = ".*" if not isinstance(log_filter, str) else log_filter + + log_level = item.config.getoption("--log-level") + os.environ["__TRANSACTRON_LOG_LEVEL"] = "WARNING" if not isinstance(log_level, str) else log_level diff --git a/test/external/embench/board_config/coreblocks-sim/board.cfg b/test/external/embench/board_config/coreblocks-sim/board.cfg index 96eaae307..b1a885340 100644 --- a/test/external/embench/board_config/coreblocks-sim/board.cfg +++ b/test/external/embench/board_config/coreblocks-sim/board.cfg @@ -1,5 +1,5 @@ cc = 'riscv64-unknown-elf-gcc' -cflags = (['-c', '-fdata-sections', '-march=rv32ic_zmmul_zicsr', '-mabi=ilp32']) -ldflags = (['-Wl,-gc-sections', '-march=rv32ic_zmmul_zicsr', '-mabi=ilp32', '-nostartfiles', '-T../../../common/link.ld']) +cflags = (['-c', '-fdata-sections', '-march=rv32imc_zba_zbb_zbc_zbs_zicsr', '-mabi=ilp32']) +ldflags = (['-Wl,-gc-sections', '-march=rv32imc_zba_zbb_zbc_zbs_zicsr', '-mabi=ilp32', '-nostartfiles', '-T../../../common/link.ld']) user_libs = (['-lm']) cpu_mhz = 0.01 diff --git a/test/external/riscof/coreblocks/coreblocks_isa.yaml b/test/external/riscof/coreblocks/coreblocks_isa.yaml index 483e8b41f..8b9623298 100644 --- a/test/external/riscof/coreblocks/coreblocks_isa.yaml +++ b/test/external/riscof/coreblocks/coreblocks_isa.yaml @@ -1,6 +1,6 @@ 
hart_ids: [0] hart0: - ISA: RV32I + ISA: RV32IMCZba_Zbb_Zbc_Zbs physical_addr_sz: 32 User_Spec_Version: '2.3' diff --git a/test/external/riscof/coreblocks/riscof_coreblocks.py b/test/external/riscof/coreblocks/riscof_coreblocks.py index 549e24934..ae6e63268 100644 --- a/test/external/riscof/coreblocks/riscof_coreblocks.py +++ b/test/external/riscof/coreblocks/riscof_coreblocks.py @@ -97,24 +97,9 @@ def build(self, isa_yaml, platform_yaml): # will be useful in setting integer value in the compiler string (if not already hardcoded); self.xlen = "64" if 64 in ispec["supported_xlen"] else "32" - # for coreblocks start building the '--isa' argument. the self.isa is dut specific and may not be - # useful for all DUTs - self.isa = "rv" + self.xlen - if "I" in ispec["ISA"]: - self.isa += "i" - if "M" in ispec["ISA"]: - self.isa += "m" - if "F" in ispec["ISA"]: - self.isa += "f" - if "D" in ispec["ISA"]: - self.isa += "d" - if "C" in ispec["ISA"]: - self.isa += "c" - if "B" in ispec["ISA"]: - self.isa += "b" - - # TODO: The following assumes you are using the riscv-gcc toolchain. If - # not please change appropriately + self.isa = ispec["ISA"].lower() + + # The following assumes you are using the riscv-gcc toolchain. self.compile_cmd = self.compile_cmd + " -mabi=" + ("lp64 " if 64 in ispec["supported_xlen"] else "ilp32 ") def runTests(self, testList): # noqa: N802 N803 @@ -168,6 +153,11 @@ def runTests(self, testList): # noqa: N802 N803 target_build = "cd {0}; {1};".format(testentry["work_dir"], buildcmd) target_run = "mkdir -p {0}; cd {1}; {2};".format(testentry["work_dir"], self.work_dir, simcmd) + # for some reason C extension enables priv tests. Disable them for now. Not ready yet! + if "privilege" in test_dir: + print("SKIP generating", test_dir, test) + continue + # create a target. 
The makeutil will create a target with the name "TARGET" where num # starts from 0 and increments automatically for each new target that is added make_build.add_target(target_build) diff --git a/test/external/riscof/riscv-arch-test b/test/external/riscof/riscv-arch-test new file mode 160000 index 000000000..8a52b016d --- /dev/null +++ b/test/external/riscof/riscv-arch-test @@ -0,0 +1 @@ +Subproject commit 8a52b016dbe1e2733cc168b9d6e5c93e39059d4d diff --git a/test/external/riscof/spike_simple/riscof_spike_simple.py b/test/external/riscof/spike_simple/riscof_spike_simple.py index 5e06de990..427fb5e37 100644 --- a/test/external/riscof/spike_simple/riscof_spike_simple.py +++ b/test/external/riscof/spike_simple/riscof_spike_simple.py @@ -54,19 +54,14 @@ def initialise(self, suite, work_dir, compliance_env): def build(self, isa_yaml, platform_yaml): ispec = utils.load_yaml(isa_yaml)['hart0'] self.xlen = ('64' if 64 in ispec['supported_xlen'] else '32') - self.isa = 'rv' + self.xlen if "64I" in ispec["ISA"]: self.compile_cmd = self.compile_cmd+' -mabi='+'lp64 ' elif "32I" in ispec["ISA"]: self.compile_cmd = self.compile_cmd+' -mabi='+'ilp32 ' elif "32E" in ispec["ISA"]: self.compile_cmd = self.compile_cmd+' -mabi='+'ilp32e ' - if "I" in ispec["ISA"]: - self.isa += 'i' - if "M" in ispec["ISA"]: - self.isa += 'm' - if "C" in ispec["ISA"]: - self.isa += 'c' + self.isa = ispec["ISA"].lower() + compiler = "riscv64-unknown-elf-gcc".format(self.xlen) if shutil.which(compiler) is None: logger.error(compiler+": executable not found. 
Please check environment setup.") diff --git a/test/external/riscof/spike_simple/spike_simple_isa.yaml b/test/external/riscof/spike_simple/spike_simple_isa.yaml index dad55a4f1..302439c34 100644 --- a/test/external/riscof/spike_simple/spike_simple_isa.yaml +++ b/test/external/riscof/spike_simple/spike_simple_isa.yaml @@ -1,6 +1,6 @@ hart_ids: [0] hart0: - ISA: RV32IMCZicsr_Zifencei + ISA: RV32IMCBZicsr_Zifencei_Zba_Zbb_Zbc_Zbs physical_addr_sz: 32 User_Spec_Version: '2.3' supported_xlen: [32] diff --git a/test/frontend/test_decode_stage.py b/test/frontend/test_decode_stage.py index 1638d9e83..c9c80251a 100644 --- a/test/frontend/test_decode_stage.py +++ b/test/frontend/test_decode_stage.py @@ -1,6 +1,6 @@ from transactron.lib import AdapterTrans, FIFO -from ..common import TestCaseWithSimulator, TestbenchIO, SimpleTestCircuit, ModuleConnector +from transactron.testing import TestCaseWithSimulator, TestbenchIO, SimpleTestCircuit, ModuleConnector from coreblocks.frontend.decode_stage import DecodeStage from coreblocks.params import GenParams, FetchLayouts, DecodeLayouts, OpType, Funct3, Funct7 diff --git a/test/frontend/test_fetch.py b/test/frontend/test_fetch.py index 95c9d97b4..e57392485 100644 --- a/test/frontend/test_fetch.py +++ b/test/frontend/test_fetch.py @@ -8,14 +8,14 @@ from transactron.core import Method from transactron.lib import AdapterTrans, FIFO, Adapter from coreblocks.frontend.fetch import Fetch, UnalignedFetch -from coreblocks.frontend.icache import ICacheInterface +from coreblocks.cache.iface import CacheInterface from coreblocks.params import * from coreblocks.params.configurations import test_core_config from transactron.utils import ModuleConnector -from ..common import TestCaseWithSimulator, TestbenchIO, def_method_mock, SimpleTestCircuit +from transactron.testing import TestCaseWithSimulator, TestbenchIO, def_method_mock, SimpleTestCircuit -class MockedICache(Elaboratable, ICacheInterface): +class MockedICache(Elaboratable, CacheInterface): 
def __init__(self, gen_params: GenParams): layouts = gen_params.get(ICacheLayouts) @@ -43,14 +43,13 @@ def setUp(self) -> None: fifo = FIFO(self.gen_params.get(FetchLayouts).raw_instr, depth=2) self.io_out = TestbenchIO(AdapterTrans(fifo.read)) + self.fetch = SimpleTestCircuit(Fetch(self.gen_params, self.icache, fifo.write)) - self.verify_branch = TestbenchIO(AdapterTrans(self.fetch._dut.verify_branch)) self.m = ModuleConnector( icache=self.icache, fetch=self.fetch, io_out=self.io_out, - verify_branch=self.verify_branch, fifo=fifo, ) @@ -82,6 +81,7 @@ def cache_process(self): # randomize being a branch instruction if is_branch: data |= 0b1100000 + data &= ~0b0010000 # but not system self.output_q.append({"instr": data, "error": 0}) @@ -111,19 +111,32 @@ def accept_res_mock(self): return self.output_q.popleft() def fetch_out_check(self): + discard_mispredict = False + next_pc = 0 for _ in range(self.iterations): + v = yield from self.io_out.call() + if discard_mispredict: + while v["pc"] != next_pc: + v = yield from self.io_out.call() + discard_mispredict = False + while len(self.instr_queue) == 0: yield instr = self.instr_queue.popleft() - if instr["is_branch"]: - yield from self.random_wait(10) - yield from self.verify_branch.call(from_pc=instr["pc"], next_pc=instr["next_pc"]) - - v = yield from self.io_out.call() self.assertEqual(v["pc"], instr["pc"]) self.assertEqual(v["instr"], instr["instr"]) + if instr["is_branch"]: + # branches on mispredict will stall fetch because of exception and then resume with new pc + yield from self.random_wait(5) + yield from self.fetch.stall_exception.call() + yield from self.random_wait(5) + yield from self.fetch.resume.call(pc=instr["next_pc"], resume_from_exception=1) + + discard_mispredict = True + next_pc = instr["next_pc"] + def test(self): with self.run_simulation(self.m) as sim: sim.add_sync_process(self.cache_process) @@ -140,9 +153,10 @@ def setUp(self) -> None: fifo = FIFO(self.gen_params.get(FetchLayouts).raw_instr, 
depth=2) self.io_out = TestbenchIO(AdapterTrans(fifo.read)) fetch = UnalignedFetch(self.gen_params, self.icache, fifo.write) - self.verify_branch = TestbenchIO(AdapterTrans(fetch.verify_branch)) + self.fetch_resume = TestbenchIO(AdapterTrans(fetch.resume)) + self.fetch_stall_exception = TestbenchIO(AdapterTrans(fetch.stall_exception)) - self.m = ModuleConnector(self.icache, fifo, self.io_out, fetch, self.verify_branch) + self.m = ModuleConnector(self.icache, fifo, self.io_out, fetch, self.fetch_resume, self.fetch_stall_exception) self.mem = {} self.memerr = set() @@ -174,6 +188,7 @@ def gen_instr_seq(self): data |= 0b11 # 2 lowest bits must be set in 32-bit long instructions if is_branch: data |= 0b1100000 + data &= ~0b0010000 self.mem[pc] = data & 0xFFFF self.mem[pc + 2] = data >> 16 @@ -224,6 +239,7 @@ def accept_res_mock(self): return self.output_q.popleft() def fetch_out_check(self): + discard_mispredict = False while self.instr_queue: instr = self.instr_queue.popleft() @@ -234,12 +250,22 @@ def fetch_out_check(self): ) + 2 in self.memerr v = yield from self.io_out.call() + if discard_mispredict: + while v["pc"] != instr["pc"]: + v = yield from self.io_out.call() + discard_mispredict = False + self.assertEqual(v["pc"], instr["pc"]) self.assertEqual(v["access_fault"], instr_error) if instr["is_branch"] or instr_error: - yield from self.random_wait(10) - yield from self.verify_branch.call(next_pc=instr["next_pc"]) + yield from self.random_wait(5) + yield from self.fetch_stall_exception.call() + yield from self.random_wait(5) + while (yield from self.fetch_resume.call_try(pc=instr["next_pc"], resume_from_exception=1)) is None: + yield from self.io_out.call_try() # try flushing to unblock or wait + + discard_mispredict = True def test(self): self.gen_instr_seq() diff --git a/test/frontend/test_instr_decoder.py b/test/frontend/test_instr_decoder.py index abd3183cc..4c0a0b4b6 100644 --- a/test/frontend/test_instr_decoder.py +++ b/test/frontend/test_instr_decoder.py 
@@ -1,6 +1,6 @@ from amaranth.sim import * -from ..common import TestCaseWithSimulator +from transactron.testing import TestCaseWithSimulator from coreblocks.params import * from coreblocks.params.configurations import test_core_config @@ -179,92 +179,85 @@ def setUp(self): self.decoder = InstrDecoder(self.gen_params) self.cnt = 1 - def do_test(self, test): + def do_test(self, tests: list[InstrTest]): def process(): - yield self.decoder.instr.eq(test.encoding) - yield Settle() + for test in tests: + yield self.decoder.instr.eq(test.encoding) + yield Settle() - self.assertEqual((yield self.decoder.illegal), test.illegal) - if test.illegal: - return + self.assertEqual((yield self.decoder.illegal), test.illegal) + if test.illegal: + return - self.assertEqual((yield self.decoder.opcode), test.opcode) + self.assertEqual((yield self.decoder.opcode), test.opcode) - if test.funct3 is not None: - self.assertEqual((yield self.decoder.funct3), test.funct3) - self.assertEqual((yield self.decoder.funct3_v), test.funct3 is not None) + if test.funct3 is not None: + self.assertEqual((yield self.decoder.funct3), test.funct3) + self.assertEqual((yield self.decoder.funct3_v), test.funct3 is not None) - if test.funct7 is not None: - self.assertEqual((yield self.decoder.funct7), test.funct7) - self.assertEqual((yield self.decoder.funct7_v), test.funct7 is not None) + if test.funct7 is not None: + self.assertEqual((yield self.decoder.funct7), test.funct7) + self.assertEqual((yield self.decoder.funct7_v), test.funct7 is not None) - if test.funct12 is not None: - self.assertEqual((yield self.decoder.funct12), test.funct12) - self.assertEqual((yield self.decoder.funct12_v), test.funct12 is not None) + if test.funct12 is not None: + self.assertEqual((yield self.decoder.funct12), test.funct12) + self.assertEqual((yield self.decoder.funct12_v), test.funct12 is not None) - if test.rd is not None: - self.assertEqual((yield self.decoder.rd), test.rd) - self.assertEqual((yield self.decoder.rd_v), 
test.rd is not None) + if test.rd is not None: + self.assertEqual((yield self.decoder.rd), test.rd) + self.assertEqual((yield self.decoder.rd_v), test.rd is not None) - if test.rs1 is not None: - self.assertEqual((yield self.decoder.rs1), test.rs1) - self.assertEqual((yield self.decoder.rs1_v), test.rs1 is not None) + if test.rs1 is not None: + self.assertEqual((yield self.decoder.rs1), test.rs1) + self.assertEqual((yield self.decoder.rs1_v), test.rs1 is not None) - if test.rs2 is not None: - self.assertEqual((yield self.decoder.rs2), test.rs2) - self.assertEqual((yield self.decoder.rs2_v), test.rs2 is not None) + if test.rs2 is not None: + self.assertEqual((yield self.decoder.rs2), test.rs2) + self.assertEqual((yield self.decoder.rs2_v), test.rs2 is not None) - if test.imm is not None: - self.assertEqual((yield self.decoder.imm.as_signed()), test.imm) + if test.imm is not None: + self.assertEqual((yield self.decoder.imm.as_signed()), test.imm) - if test.succ is not None: - self.assertEqual((yield self.decoder.succ), test.succ) + if test.succ is not None: + self.assertEqual((yield self.decoder.succ), test.succ) - if test.pred is not None: - self.assertEqual((yield self.decoder.pred), test.pred) + if test.pred is not None: + self.assertEqual((yield self.decoder.pred), test.pred) - if test.fm is not None: - self.assertEqual((yield self.decoder.fm), test.fm) + if test.fm is not None: + self.assertEqual((yield self.decoder.fm), test.fm) - if test.csr is not None: - self.assertEqual((yield self.decoder.csr), test.csr) + if test.csr is not None: + self.assertEqual((yield self.decoder.csr), test.csr) - self.assertEqual((yield self.decoder.optype), test.op) + self.assertEqual((yield self.decoder.optype), test.op) with self.run_simulation(self.decoder) as sim: sim.add_process(process) def test_i(self): - for test in self.DECODER_TESTS_I: - self.do_test(test) + self.do_test(self.DECODER_TESTS_I) def test_zifencei(self): - for test in self.DECODER_TESTS_ZIFENCEI: - 
self.do_test(test) + self.do_test(self.DECODER_TESTS_ZIFENCEI) def test_zicsr(self): - for test in self.DECODER_TESTS_ZICSR: - self.do_test(test) + self.do_test(self.DECODER_TESTS_ZICSR) def test_m(self): - for test in self.DECODER_TESTS_M: - self.do_test(test) + self.do_test(self.DECODER_TESTS_M) def test_illegal(self): - for test in self.DECODER_TESTS_ILLEGAL: - self.do_test(test) + self.do_test(self.DECODER_TESTS_ILLEGAL) def test_xintmachinemode(self): - for test in self.DECODER_TESTS_XINTMACHINEMODE: - self.do_test(test) + self.do_test(self.DECODER_TESTS_XINTMACHINEMODE) def test_xintsupervisor(self): - for test in self.DECODER_TESTS_XINTSUPERVISOR: - self.do_test(test) + self.do_test(self.DECODER_TESTS_XINTSUPERVISOR) def test_zbb(self): - for test in self.DECODER_TESTS_ZBB: - self.do_test(test) + self.do_test(self.DECODER_TESTS_ZBB) class TestDecoderEExtLegal(TestCaseWithSimulator): @@ -372,7 +365,7 @@ def test_decoded_distinguishable(self): Encoding(Opcode.OP_IMM, Funct3.BSET, Funct7.BSET), Encoding(Opcode.OP_IMM, Funct3.BINV, Funct7.BINV), }, - OpType.BIT_MANIPULATION: { + OpType.BIT_ROTATION: { Encoding(Opcode.OP_IMM, Funct3.ROR, Funct7.ROR), }, } diff --git a/test/frontend/test_rvc.py b/test/frontend/test_rvc.py index 668ead899..d31d21e63 100644 --- a/test/frontend/test_rvc.py +++ b/test/frontend/test_rvc.py @@ -8,7 +8,7 @@ from coreblocks.params.configurations import test_core_config from transactron.utils import ValueLike -from ..common import TestCaseWithSimulator +from transactron.testing import TestCaseWithSimulator COMMON_TESTS = [ # Illegal instruction diff --git a/test/fu/functional_common.py b/test/fu/functional_common.py index b930399f4..7d21682cb 100644 --- a/test/fu/functional_common.py +++ b/test/fu/functional_common.py @@ -16,7 +16,7 @@ from coreblocks.params.layouts import ExceptionRegisterLayouts from coreblocks.params.optypes import OpType from transactron.lib import Adapter -from test.common import RecordIntDict, RecordIntDictRet, 
TestbenchIO, TestCaseWithSimulator, SimpleTestCircuit +from transactron.testing import RecordIntDict, RecordIntDictRet, TestbenchIO, TestCaseWithSimulator, SimpleTestCircuit from transactron.utils import ModuleConnector @@ -138,11 +138,12 @@ def setUp(self): cause = None if "exception" in results: cause = results["exception"] + self.exceptions.append({"rob_id": rob_id, "cause": cause, "pc": results.setdefault("exception_pc", pc)}) + results.pop("exception") + results.pop("exception_pc") self.responses.append({"rob_id": rob_id, "rp_dst": rp_dst, "exception": int(cause is not None)} | results) - if cause is not None: - self.exceptions.append({"rob_id": rob_id, "cause": cause, "pc": pc}) def consumer(self): while self.responses: diff --git a/test/fu/test_alu.py b/test/fu/test_alu.py index f350af49c..7e973fc92 100644 --- a/test/fu/test_alu.py +++ b/test/fu/test_alu.py @@ -28,14 +28,14 @@ class AluUnitTest(FunctionalUnitTestCase[AluFn.Fn]): AluFn.Fn.MAXU: ExecFn(OpType.BIT_MANIPULATION, Funct3.MAXU, Funct7.MAX), AluFn.Fn.MIN: ExecFn(OpType.BIT_MANIPULATION, Funct3.MIN, Funct7.MIN), AluFn.Fn.MINU: ExecFn(OpType.BIT_MANIPULATION, Funct3.MINU, Funct7.MIN), - AluFn.Fn.CPOP: ExecFn(OpType.UNARY_BIT_MANIPULATION_5, Funct3.CPOP, Funct7.CPOP), - AluFn.Fn.SEXTB: ExecFn(OpType.UNARY_BIT_MANIPULATION_1, Funct3.SEXTB, Funct7.SEXTB), - AluFn.Fn.ZEXTH: ExecFn(OpType.UNARY_BIT_MANIPULATION_1, Funct3.ZEXTH, Funct7.ZEXTH), - AluFn.Fn.SEXTH: ExecFn(OpType.UNARY_BIT_MANIPULATION_2, Funct3.SEXTH, Funct7.SEXTH), - AluFn.Fn.ORCB: ExecFn(OpType.UNARY_BIT_MANIPULATION_1, Funct3.ORCB, Funct7.ORCB), - AluFn.Fn.REV8: ExecFn(OpType.UNARY_BIT_MANIPULATION_1, Funct3.REV8, Funct7.REV8), - AluFn.Fn.CLZ: ExecFn(OpType.UNARY_BIT_MANIPULATION_3, Funct3.CLZ, Funct7.CLZ), - AluFn.Fn.CTZ: ExecFn(OpType.UNARY_BIT_MANIPULATION_4, Funct3.CTZ, Funct7.CTZ), + AluFn.Fn.SEXTB: ExecFn(OpType.UNARY_BIT_MANIPULATION_1, Funct3.SEXTB), + AluFn.Fn.ZEXTH: ExecFn(OpType.UNARY_BIT_MANIPULATION_1, Funct3.ZEXTH), + 
AluFn.Fn.REV8: ExecFn(OpType.UNARY_BIT_MANIPULATION_1, Funct3.REV8), + AluFn.Fn.SEXTH: ExecFn(OpType.UNARY_BIT_MANIPULATION_2, Funct3.SEXTH), + AluFn.Fn.ORCB: ExecFn(OpType.UNARY_BIT_MANIPULATION_2, Funct3.ORCB), + AluFn.Fn.CLZ: ExecFn(OpType.UNARY_BIT_MANIPULATION_3, Funct3.CLZ), + AluFn.Fn.CTZ: ExecFn(OpType.UNARY_BIT_MANIPULATION_4, Funct3.CTZ), + AluFn.Fn.CPOP: ExecFn(OpType.UNARY_BIT_MANIPULATION_5, Funct3.CPOP), } @staticmethod diff --git a/test/fu/test_fu_decoder.py b/test/fu/test_fu_decoder.py index 7e6ba7d98..965e07e40 100644 --- a/test/fu/test_fu_decoder.py +++ b/test/fu/test_fu_decoder.py @@ -3,7 +3,7 @@ from amaranth import * from amaranth.sim import * -from ..common import SimpleTestCircuit, TestCaseWithSimulator +from transactron.testing import SimpleTestCircuit, TestCaseWithSimulator from coreblocks.fu.fu_decoder import DecoderManager, Decoder from coreblocks.params import OpType, Funct3, Funct7, GenParams diff --git a/test/fu/test_jb_unit.py b/test/fu/test_jb_unit.py index 974f2699c..559062989 100644 --- a/test/fu/test_jb_unit.py +++ b/test/fu/test_jb_unit.py @@ -1,10 +1,11 @@ from amaranth import * +from amaranth.lib.data import StructLayout from parameterized import parameterized_class from coreblocks.params import * from coreblocks.fu.jumpbranch import JumpBranchFuncUnit, JumpBranchFn, JumpComponent from transactron import Method, def_method, TModule -from coreblocks.params.layouts import FuncUnitLayouts, FetchLayouts +from coreblocks.params.layouts import FuncUnitLayouts from coreblocks.utils.protocols import FuncUnit from transactron.utils import signed_to_int @@ -16,7 +17,11 @@ class JumpBranchWrapper(Elaboratable): def __init__(self, gen_params: GenParams): self.jb = JumpBranchFuncUnit(gen_params) self.issue = self.jb.issue - self.accept = Method(o=gen_params.get(FuncUnitLayouts).accept + gen_params.get(FetchLayouts).branch_verify) + self.accept = Method( + o=StructLayout( + gen_params.get(FuncUnitLayouts).accept.members | 
gen_params.get(JumpBranchLayouts).verify_branch.members + ) + ) def elaborate(self, platform): m = TModule() @@ -26,15 +31,15 @@ def elaborate(self, platform): @def_method(m, self.accept) def _(arg): res = self.jb.accept(m) - br = self.jb.branch_result(m) + verify = self.jb.fifo_branch_resolved.read(m) return { - "from_pc": br.from_pc, - "next_pc": br.next_pc, - "resume_from_exception": 0, "result": res.result, "rob_id": res.rob_id, "rp_dst": res.rp_dst, "exception": res.exception, + "next_pc": verify.next_pc, + "from_pc": verify.from_pc, + "misprediction": verify.misprediction, } return m @@ -77,12 +82,18 @@ def compute_result(i1: int, i2: int, i_imm: int, pc: int, fn: JumpBranchFn.Fn, x next_pc &= max_int res &= max_int + misprediction = next_pc != pc + 4 + exception = None + exception_pc = pc if next_pc & 0b11 != 0: exception = ExceptionCause.INSTRUCTION_ADDRESS_MISALIGNED + elif misprediction: + exception = ExceptionCause._COREBLOCKS_MISPREDICTION + exception_pc = next_pc - return {"result": res, "from_pc": pc, "next_pc": next_pc, "resume_from_exception": 0} | ( - {"exception": exception} if exception is not None else {} + return {"result": res, "from_pc": pc, "next_pc": next_pc, "misprediction": misprediction} | ( + {"exception": exception, "exception_pc": exception_pc} if exception is not None else {} ) diff --git a/test/fu/test_shift_unit.py b/test/fu/test_shift_unit.py index ba2de99e4..20eed6d55 100644 --- a/test/fu/test_shift_unit.py +++ b/test/fu/test_shift_unit.py @@ -12,8 +12,8 @@ class ShiftUnitTest(FunctionalUnitTestCase[ShiftUnitFn.Fn]): ShiftUnitFn.Fn.SLL: ExecFn(OpType.SHIFT, Funct3.SLL), ShiftUnitFn.Fn.SRL: ExecFn(OpType.SHIFT, Funct3.SR, Funct7.SL), ShiftUnitFn.Fn.SRA: ExecFn(OpType.SHIFT, Funct3.SR, Funct7.SA), - ShiftUnitFn.Fn.ROL: ExecFn(OpType.BIT_MANIPULATION, Funct3.ROL, Funct7.ROL), - ShiftUnitFn.Fn.ROR: ExecFn(OpType.BIT_MANIPULATION, Funct3.ROR, Funct7.ROR), + ShiftUnitFn.Fn.ROL: ExecFn(OpType.BIT_ROTATION, Funct3.ROL, Funct7.ROL), + 
ShiftUnitFn.Fn.ROR: ExecFn(OpType.BIT_ROTATION, Funct3.ROR, Funct7.ROR), } @staticmethod diff --git a/test/fu/test_unsigned_mul_unit.py b/test/fu/test_unsigned_mul_unit.py index 901f95bd7..56b3657e6 100644 --- a/test/fu/test_unsigned_mul_unit.py +++ b/test/fu/test_unsigned_mul_unit.py @@ -10,7 +10,7 @@ from coreblocks.fu.unsigned_multiplication.sequence import SequentialUnsignedMul from coreblocks.fu.unsigned_multiplication.shift import ShiftUnsignedMul -from test.common import TestCaseWithSimulator, SimpleTestCircuit +from transactron.testing import TestCaseWithSimulator, SimpleTestCircuit from coreblocks.params import GenParams from coreblocks.params.configurations import test_core_config diff --git a/test/lsu/test_dummylsu.py b/test/lsu/test_dummylsu.py index a1328cfe3..61e9a3f29 100644 --- a/test/lsu/test_dummylsu.py +++ b/test/lsu/test_dummylsu.py @@ -14,7 +14,8 @@ from transactron.utils.dependencies import DependencyManager from coreblocks.params.layouts import ExceptionRegisterLayouts from coreblocks.peripherals.wishbone import * -from test.common import TestbenchIO, TestCaseWithSimulator, def_method_mock +from transactron.testing import TestbenchIO, TestCaseWithSimulator, def_method_mock +from coreblocks.peripherals.bus_adapter import WishboneMasterAdapter from test.peripherals.test_wishbone import WishboneInterfaceWrapper @@ -86,6 +87,7 @@ def elaborate(self, platform): ) self.bus = WishboneMaster(wb_params) + self.bus_master_adapter = WishboneMasterAdapter(self.bus) m.submodules.exception_report = self.exception_report = TestbenchIO( Adapter(i=self.gen.get(ExceptionRegisterLayouts).report) @@ -93,14 +95,15 @@ def elaborate(self, platform): self.gen.get(DependencyManager).add_dependency(ExceptionReportKey(), self.exception_report.adapter.iface) - m.submodules.func_unit = func_unit = LSUDummy(self.gen, self.bus) + m.submodules.func_unit = func_unit = LSUDummy(self.gen, self.bus_master_adapter) m.submodules.select_mock = self.select = 
TestbenchIO(AdapterTrans(func_unit.select)) m.submodules.insert_mock = self.insert = TestbenchIO(AdapterTrans(func_unit.insert)) m.submodules.update_mock = self.update = TestbenchIO(AdapterTrans(func_unit.update)) m.submodules.get_result_mock = self.get_result = TestbenchIO(AdapterTrans(func_unit.get_result)) m.submodules.precommit_mock = self.precommit = TestbenchIO(AdapterTrans(func_unit.precommit)) - self.io_in = WishboneInterfaceWrapper(self.bus.wbMaster) + self.io_in = WishboneInterfaceWrapper(self.bus.wb_master) + m.submodules.bus_master_adapter = self.bus_master_adapter m.submodules.bus = self.bus return m diff --git a/test/lsu/test_pma.py b/test/lsu/test_pma.py index d02b218c0..07c36652d 100644 --- a/test/lsu/test_pma.py +++ b/test/lsu/test_pma.py @@ -10,7 +10,8 @@ from transactron.utils.dependencies import DependencyManager from coreblocks.params.layouts import ExceptionRegisterLayouts from coreblocks.peripherals.wishbone import * -from test.common import TestbenchIO, TestCaseWithSimulator, def_method_mock +from transactron.testing import TestbenchIO, TestCaseWithSimulator, def_method_mock +from coreblocks.peripherals.bus_adapter import WishboneMasterAdapter from test.peripherals.test_wishbone import WishboneInterfaceWrapper @@ -55,6 +56,7 @@ def elaborate(self, platform): ) self.bus = WishboneMaster(wb_params) + self.bus_master_adapter = WishboneMasterAdapter(self.bus) m.submodules.exception_report = self.exception_report = TestbenchIO( Adapter(i=self.gen.get(ExceptionRegisterLayouts).report) @@ -62,15 +64,16 @@ def elaborate(self, platform): self.gen.get(DependencyManager).add_dependency(ExceptionReportKey(), self.exception_report.adapter.iface) - m.submodules.func_unit = func_unit = LSUDummy(self.gen, self.bus) + m.submodules.func_unit = func_unit = LSUDummy(self.gen, self.bus_master_adapter) m.submodules.select_mock = self.select = TestbenchIO(AdapterTrans(func_unit.select)) m.submodules.insert_mock = self.insert = 
TestbenchIO(AdapterTrans(func_unit.insert)) m.submodules.update_mock = self.update = TestbenchIO(AdapterTrans(func_unit.update)) m.submodules.get_result_mock = self.get_result = TestbenchIO(AdapterTrans(func_unit.get_result)) m.submodules.precommit_mock = self.precommit = TestbenchIO(AdapterTrans(func_unit.precommit)) - self.io_in = WishboneInterfaceWrapper(self.bus.wbMaster) + self.io_in = WishboneInterfaceWrapper(self.bus.wb_master) m.submodules.bus = self.bus + m.submodules.bus_master_adapter = self.bus_master_adapter return m diff --git a/test/peripherals/test_axi_lite.py b/test/peripherals/test_axi_lite.py index 9d887ce9f..27821c899 100644 --- a/test/peripherals/test_axi_lite.py +++ b/test/peripherals/test_axi_lite.py @@ -2,11 +2,11 @@ from transactron import Method, def_method, TModule from transactron.lib import AdapterTrans -from ..common import * +from transactron.testing import * class AXILiteInterfaceWrapper: - def __init__(self, axi_lite_master: Record): + def __init__(self, axi_lite_master: AXILiteInterface): self.axi_lite = axi_lite_master def slave_ra_ready(self, rdy=1): diff --git a/test/peripherals/test_wishbone.py b/test/peripherals/test_wishbone.py index 9db5b9430..4dd5485ed 100644 --- a/test/peripherals/test_wishbone.py +++ b/test/peripherals/test_wishbone.py @@ -1,16 +1,18 @@ import random from collections import deque +from amaranth.lib.wiring import connect + from coreblocks.peripherals.wishbone import * from transactron.lib import AdapterTrans -from ..common import * +from transactron.testing import * class WishboneInterfaceWrapper: - def __init__(self, wishbone_record): - self.wb = wishbone_record + def __init__(self, wishbone_interface: WishboneInterface): + self.wb = wishbone_interface def master_set(self, addr, data, we): yield self.wb.dat_w.eq(data) @@ -107,7 +109,7 @@ def result_process(): self.assertTrue(resp["err"]) def slave(): - wwb = WishboneInterfaceWrapper(twbm.wbm.wbMaster) + wwb = WishboneInterfaceWrapper(twbm.wbm.wb_master) 
yield from wwb.slave_wait() yield from wwb.slave_verify(2, 0, 0, 1) @@ -142,10 +144,10 @@ def slave(): class TestWishboneMuxer(TestCaseWithSimulator): def test_manual(self): - wb_master = WishboneInterfaceWrapper(Record(WishboneLayout(WishboneParameters()).wb_layout)) num_slaves = 4 - slaves = [WishboneInterfaceWrapper(Record.like(wb_master.wb, name=f"sl{i}")) for i in range(num_slaves)] - mux = WishboneMuxer(wb_master.wb, [s.wb for s in slaves], Signal(num_slaves)) + mux = WishboneMuxer(WishboneParameters(), num_slaves, Signal(num_slaves)) + slaves = [WishboneInterfaceWrapper(slave) for slave in mux.slaves] + wb_master = WishboneInterfaceWrapper(mux.master_wb) def process(): # check full communiaction @@ -183,9 +185,9 @@ def process(): class TestWishboneAribiter(TestCaseWithSimulator): def test_manual(self): - slave = WishboneInterfaceWrapper(Record(WishboneLayout(WishboneParameters()).wb_layout)) - masters = [WishboneInterfaceWrapper(Record.like(slave.wb, name=f"mst{i}")) for i in range(2)] - arb = WishboneArbiter(slave.wb, [m.wb for m in masters]) + arb = WishboneArbiter(WishboneParameters(), 2) + slave = WishboneInterfaceWrapper(arb.slave_wb) + masters = [WishboneInterfaceWrapper(master) for master in arb.masters] def process(): yield from masters[0].master_set(2, 3, 1) @@ -319,7 +321,7 @@ def elaborate(self, platform): m.submodules.request = self.request = TestbenchIO(AdapterTrans(self.mem_master.request)) m.submodules.result = self.result = TestbenchIO(AdapterTrans(self.mem_master.result)) - m.d.comb += self.mem_master.wbMaster.connect(self.mem_slave.bus) + connect(m, self.mem_master.wb_master, self.mem_slave.bus) return m @@ -369,7 +371,7 @@ def result_process(): self.assertEqual(res["data"], mem_state[req["addr"]]) def write_process(): - wwb = WishboneInterfaceWrapper(self.m.mem_master.wbMaster) + wwb = WishboneInterfaceWrapper(self.m.mem_master.wb_master) for _ in range(self.iters): yield from wwb.wait_ack() req = wr_queue.pop() diff --git 
a/test/regression/benchmark.py b/test/regression/benchmark.py index e43b6ee81..5b465d650 100644 --- a/test/regression/benchmark.py +++ b/test/regression/benchmark.py @@ -1,5 +1,6 @@ import os -import json +from dataclasses import dataclass +from dataclasses_json import dataclass_json from pathlib import Path from .memory import * @@ -8,6 +9,27 @@ test_dir = Path(__file__).parent.parent embench_dir = test_dir.joinpath("external/embench/build/src") results_dir = test_dir.joinpath("regression/benchmark_results") +profile_dir = test_dir.joinpath("__profiles__") + + +@dataclass_json +@dataclass +class BenchmarkResult: + """Result of running a single benchmark. + + Attributes + ---------- + cycles: int + A number of cycles the benchmark took. + instr: int + A count of instructions commited during the benchmark. + metric_values: dict[str, dict[str, int]] + Values of the core metrics taken at the end of the simulation. + """ + + cycles: int + instr: int + metric_values: dict[str, dict[str, int]] class MMIO(RandomAccessMemory): @@ -54,16 +76,20 @@ async def run_benchmark(sim_backend: SimulationBackend, benchmark_name: str): mem_model = CoreMemoryModel(mem_segments) - success = await sim_backend.run(mem_model, timeout_cycles=5000000) + result = await sim_backend.run(mem_model, timeout_cycles=2000000) + + if result.profile is not None: + os.makedirs(profile_dir, exist_ok=True) + result.profile.encode(f"{profile_dir}/benchmark.{benchmark_name}.json") - if not success: + if not result.success: raise RuntimeError("Simulation timed out") if mmio.return_code() != 0: raise RuntimeError("The benchmark exited with a non-zero return code: %d" % mmio.return_code()) - results = {"cycle": mmio.cycle_cnt(), "instr": mmio.instr_cnt()} + bench_results = BenchmarkResult(cycles=mmio.cycle_cnt(), instr=mmio.instr_cnt(), metric_values=result.metric_values) os.makedirs(str(results_dir), exist_ok=True) with open(f"{str(results_dir)}/{benchmark_name}.json", "w") as outfile: - json.dump(results, 
outfile) + outfile.write(bench_results.to_json()) # type: ignore diff --git a/test/regression/cocotb.py b/test/regression/cocotb.py index e59bcef03..444360d04 100644 --- a/test/regression/cocotb.py +++ b/test/regression/cocotb.py @@ -1,5 +1,7 @@ from decimal import Decimal import inspect +import re +import os from typing import Any from collections.abc import Coroutine from dataclasses import dataclass @@ -7,11 +9,15 @@ import cocotb from cocotb.clock import Clock, Timer from cocotb.handle import ModifiableObject -from cocotb.triggers import FallingEdge, Event, with_timeout +from cocotb.triggers import FallingEdge, Event, RisingEdge, with_timeout from cocotb_bus.bus import Bus +from cocotb.result import SimTimeoutError from .memory import * -from .common import SimulationBackend +from .common import SimulationBackend, SimulationExecutionResult + +from transactron.profiler import CycleProfile, MethodSamples, Profile, ProfileSamples, TransactionSamples +from transactron.utils.gen import GenerationInfo @dataclass @@ -82,10 +88,6 @@ async def start(self): sig_s = WishboneSlaveSignals() if sig_m.we: - cocotb.logging.debug( - f"Wishbone bus '{self.name}' write request: " - f"addr=0x{addr:x} data=0x{int(sig_m.dat_w):x} sel={sig_m.sel}" - ) resp = self.model.write( WriteRequest( addr=addr, @@ -95,7 +97,6 @@ async def start(self): ) ) else: - cocotb.logging.debug(f"Wishbone bus '{self.name}' read request: addr=0x{addr:x} sel={sig_m.sel}") resp = self.model.read( ReadRequest( addr=addr, @@ -118,11 +119,6 @@ async def start(self): raise ValueError("Bus doesn't support rty") sig_s.rty = 1 - cocotb.logging.debug( - f"Wishbone bus '{self.name}' response: " - f"ack={sig_s.ack} err={sig_s.err} rty={sig_s.rty} data={int(sig_s.dat_r):x}" - ) - for _ in range(self.delay): await clock_edge_event # type: ignore @@ -136,7 +132,96 @@ def __init__(self, dut): self.dut = dut self.finish_event = Event() - async def run(self, mem_model: CoreMemoryModel, timeout_cycles: int = 5000) -> bool: + 
try: + gen_info_path = os.environ["_COREBLOCKS_GEN_INFO"] + except KeyError: + raise RuntimeError("No core generation info provided") + + self.gen_info = GenerationInfo.decode(gen_info_path) + + self.log_level = os.environ["__TRANSACTRON_LOG_LEVEL"] + self.log_filter = os.environ["__TRANSACTRON_LOG_FILTER"] + + cocotb.logging.getLogger().setLevel(self.log_level) + + def get_cocotb_handle(self, path_components: list[str]) -> ModifiableObject: + obj = self.dut + # Skip the first component, as it is already referenced in "self.dut" + for component in path_components[1:]: + try: + # As the component may start with '_' character, we need to use '_id' + # function instead of 'getattr' - this is required by cocotb. + obj = obj._id(component, extended=False) + except AttributeError: + # Try with escaped name + if component[0] != "\\" and component[-1] != " ": + obj = obj._id("\\" + component + " ", extended=False) + else: + raise + + return obj + + async def profile_handler(self, clock, profile: Profile): + clock_edge_event = RisingEdge(clock) + + while True: + samples = ProfileSamples() + + for transaction_id, location in self.gen_info.transaction_signals_location.items(): + request_val = self.get_cocotb_handle(location.request) + runnable_val = self.get_cocotb_handle(location.runnable) + grant_val = self.get_cocotb_handle(location.grant) + samples.transactions[transaction_id] = TransactionSamples( + bool(request_val.value), bool(runnable_val.value), bool(grant_val.value) + ) + + for method_id, location in self.gen_info.method_signals_location.items(): + run_val = self.get_cocotb_handle(location.run) + samples.methods[method_id] = MethodSamples(bool(run_val.value)) + + cprof = CycleProfile.make(samples, self.gen_info.profile_data) + profile.cycles.append(cprof) + + await clock_edge_event # type: ignore + + async def logging_handler(self, clock): + clock_edge_event = FallingEdge(clock) + + log_level = cocotb.logging.getLogger().level + + logs = [ + (rec, 
self.get_cocotb_handle(rec.trigger_location)) + for rec in self.gen_info.logs + if rec.level >= log_level and re.search(self.log_filter, rec.logger_name) + ] + + while True: + for rec, trigger_handle in logs: + if not trigger_handle.value: + continue + + values: list[int] = [] + for field in rec.fields_location: + values.append(int(self.get_cocotb_handle(field).value)) + + formatted_msg = rec.format(*values) + + cocotb_log = cocotb.logging.getLogger(rec.logger_name) + + cocotb_log.log( + rec.level, + "%s:%d] %s", + rec.location[0], + rec.location[1], + formatted_msg, + ) + + if rec.level >= cocotb.logging.ERROR: + assert False, f"Assertion failed at {rec.location[0], rec.location[1]}: {formatted_msg}" + + await clock_edge_event # type: ignore + + async def run(self, mem_model: CoreMemoryModel, timeout_cycles: int = 5000) -> SimulationExecutionResult: clk = Clock(self.dut.clk, 1, "ns") cocotb.start_soon(clk.start()) @@ -150,9 +235,32 @@ async def run(self, mem_model: CoreMemoryModel, timeout_cycles: int = 5000) -> b data_wb = WishboneSlave(self.dut, "wb_data", self.dut.clk, mem_model, is_instr_bus=False) cocotb.start_soon(data_wb.start()) - res = await with_timeout(self.finish_event.wait(), timeout_cycles, "ns") + profile = None + if "__TRANSACTRON_PROFILE" in os.environ: + profile = Profile() + profile.transactions_and_methods = self.gen_info.profile_data.transactions_and_methods + cocotb.start_soon(self.profile_handler(self.dut.clk, profile)) + + cocotb.start_soon(self.logging_handler(self.dut.clk)) + + success = True + try: + await with_timeout(self.finish_event.wait(), timeout_cycles, "ns") + except SimTimeoutError: + success = False + + result = SimulationExecutionResult(success) + + result.profile = profile + + for metric_name, metric_loc in self.gen_info.metrics_location.items(): + result.metric_values[metric_name] = {} + for reg_name, reg_loc in metric_loc.regs.items(): + value = int(self.get_cocotb_handle(reg_loc)) + 
result.metric_values[metric_name][reg_name] = value + cocotb.logging.info(f"Metric {metric_name}/{reg_name}={value}") - return res is not None + return result def stop(self): self.finish_event.set() diff --git a/test/regression/cocotb/benchmark.Makefile b/test/regression/cocotb/benchmark.Makefile index 015b79d6e..5c89d3785 100644 --- a/test/regression/cocotb/benchmark.Makefile +++ b/test/regression/cocotb/benchmark.Makefile @@ -4,9 +4,6 @@ SIM ?= verilator TOPLEVEL_LANG ?= verilog -VERILOG_SOURCES += $(PWD)/../../../core.v -# use VHDL_SOURCES for VHDL files - # TOPLEVEL is the name of the toplevel module in your Verilog or VHDL file TOPLEVEL = top diff --git a/test/regression/cocotb/benchmark_entrypoint.py b/test/regression/cocotb/benchmark_entrypoint.py index fb3fa59c5..d700b3a4e 100644 --- a/test/regression/cocotb/benchmark_entrypoint.py +++ b/test/regression/cocotb/benchmark_entrypoint.py @@ -1,5 +1,4 @@ import sys -import cocotb from pathlib import Path top_dir = Path(__file__).parent.parent.parent.parent @@ -10,7 +9,6 @@ async def _do_benchmark(dut, benchmark_name): - cocotb.logging.getLogger().setLevel(cocotb.logging.INFO) await run_benchmark(CocotbSimulation(dut), benchmark_name) diff --git a/test/regression/cocotb/signature.Makefile b/test/regression/cocotb/signature.Makefile index e7da43e25..74b803083 100644 --- a/test/regression/cocotb/signature.Makefile +++ b/test/regression/cocotb/signature.Makefile @@ -4,9 +4,6 @@ SIM ?= verilator TOPLEVEL_LANG ?= verilog -VERILOG_SOURCES += $(PWD)/../../../core.v -# use VHDL_SOURCES for VHDL files - # TOPLEVEL is the name of the toplevel module in your Verilog or VHDL file TOPLEVEL = top diff --git a/test/regression/cocotb/signature_entrypoint.py b/test/regression/cocotb/signature_entrypoint.py index 1508502fe..4b8a9d212 100644 --- a/test/regression/cocotb/signature_entrypoint.py +++ b/test/regression/cocotb/signature_entrypoint.py @@ -12,8 +12,6 @@ @cocotb.test() async def do_test(dut): - 
cocotb.logging.getLogger().setLevel(cocotb.logging.INFO) - test_name = os.environ["TESTNAME"] if test_name is None: raise RuntimeError("No ELF file provided") diff --git a/test/regression/cocotb/test.Makefile b/test/regression/cocotb/test.Makefile index e81e31804..bda120bc1 100644 --- a/test/regression/cocotb/test.Makefile +++ b/test/regression/cocotb/test.Makefile @@ -4,9 +4,6 @@ SIM ?= verilator TOPLEVEL_LANG ?= verilog -VERILOG_SOURCES += $(PWD)/../../../core.v -# use VHDL_SOURCES for VHDL files - # TOPLEVEL is the name of the toplevel module in your Verilog or VHDL file TOPLEVEL = top diff --git a/test/regression/cocotb/test_entrypoint.py b/test/regression/cocotb/test_entrypoint.py index d5e8fb7a9..71b0ed64f 100644 --- a/test/regression/cocotb/test_entrypoint.py +++ b/test/regression/cocotb/test_entrypoint.py @@ -1,17 +1,21 @@ import sys -import cocotb from pathlib import Path top_dir = Path(__file__).parent.parent.parent.parent sys.path.insert(0, str(top_dir)) from test.regression.cocotb import CocotbSimulation, generate_tests # noqa: E402 -from test.regression.test import run_test, get_all_test_names # noqa: E402 +from test.regression.test_regression import run_test # noqa: E402 +from test.regression.conftest import get_all_test_names # noqa: E402 + +# used to build the Verilator model without starting tests +empty_testcase_name = "SKIP" async def do_test(dut, test_name): - cocotb.logging.getLogger().setLevel(cocotb.logging.INFO) + if test_name == empty_testcase_name: + return await run_test(CocotbSimulation(dut), test_name) -generate_tests(do_test, list(get_all_test_names())) +generate_tests(do_test, list(get_all_test_names()) + [empty_testcase_name]) diff --git a/test/regression/common.py b/test/regression/common.py index bb61e4613..62481c17e 100644 --- a/test/regression/common.py +++ b/test/regression/common.py @@ -1,11 +1,31 @@ from abc import ABC, abstractmethod - +from dataclasses import dataclass, field +from typing import Optional from .memory import 
CoreMemoryModel +from transactron.profiler import Profile + + +@dataclass +class SimulationExecutionResult: + """Information about the result of the simulation. + + Attributes + ---------- + success: bool + Whether the simulation finished successfully, i.e. no timeouts, + no exceptions, no failed assertions etc. + metric_values: dict[str, dict[str, int]] + Values of the core metrics taken at the end of the simulation. + """ + + success: bool + metric_values: dict[str, dict[str, int]] = field(default_factory=dict) + profile: Optional[Profile] = None class SimulationBackend(ABC): @abstractmethod - async def run(self, mem_model: CoreMemoryModel, timeout_cycles: int) -> bool: + async def run(self, mem_model: CoreMemoryModel, timeout_cycles: int) -> SimulationExecutionResult: raise NotImplementedError @abstractmethod diff --git a/test/regression/conftest.py b/test/regression/conftest.py new file mode 100644 index 000000000..bf0f1cc96 --- /dev/null +++ b/test/regression/conftest.py @@ -0,0 +1,41 @@ +from glob import glob +from pathlib import Path +import pytest +import subprocess + +test_dir = Path(__file__).parent.parent +riscv_tests_dir = test_dir.joinpath("external/riscv-tests") +profile_dir = test_dir.joinpath("__profiles__") + + +def get_all_test_names(): + return sorted([name[5:] for name in glob("test-*", root_dir=riscv_tests_dir)]) + + +def load_regression_tests() -> list[str]: + all_tests = set(get_all_test_names()) + if len(all_tests) == 0: + res = subprocess.run(["make", "-C", "test/external/riscv-tests"]) + if res.returncode != 0: + print("Couldn't build regression tests") + all_tests = set(get_all_test_names()) + + exclude = {"rv32ui-ma_data", "rv32ui-fence_i"} + + return sorted(list(all_tests - exclude)) + + +def pytest_generate_tests(metafunc: pytest.Metafunc): + all_tests = ( + load_regression_tests() + ) # The list has to be always in the same order (e.g. 
sorted) to allow for parallel testing + if "test_name" in metafunc.fixturenames: + metafunc.parametrize( + "test_name", + [test_name for test_name in all_tests], + ) + + +def pytest_runtest_setup(item: pytest.Item): + if not item.config.getoption("--coreblocks-regression", default=False): # type: ignore + pytest.skip("need --coreblocks-regression option to run this test") diff --git a/test/regression/memory.py b/test/regression/memory.py index d09ea2c10..70b8a9496 100644 --- a/test/regression/memory.py +++ b/test/regression/memory.py @@ -97,9 +97,10 @@ def write(self, req: WriteRequest) -> WriteReply: class CoreMemoryModel: - def __init__(self, segments: list[MemorySegment], fail_on_undefined=True): + def __init__(self, segments: list[MemorySegment], fail_on_undefined_read=False, fail_on_undefined_write=True): self.segments = segments - self.fail_on_undefined = fail_on_undefined + self.fail_on_undefined_read = fail_on_undefined_read # Core may do undefined reads speculatively + self.fail_on_undefined_write = fail_on_undefined_write def _run_on_range(self, f: Callable[[MemorySegment, TReq], TRep], req: TReq) -> Optional[TRep]: for seg in self.segments: @@ -124,7 +125,7 @@ def read(self, req: ReadRequest) -> ReadReply: rep = self._run_on_range(self._do_read, req) if rep is not None: return rep - if self.fail_on_undefined: + if self.fail_on_undefined_read: raise RuntimeError("Undefined read: %x" % req.addr) else: return ReadReply(status=ReplyStatus.ERROR) @@ -133,7 +134,7 @@ def write(self, req: WriteRequest) -> WriteReply: rep = self._run_on_range(self._do_write, req) if rep is not None: return rep - if self.fail_on_undefined: + if self.fail_on_undefined_write: raise RuntimeError("Undefined write: %x <= %x" % (req.addr, req.data)) else: return WriteReply(status=ReplyStatus.ERROR) diff --git a/test/regression/pysim.py b/test/regression/pysim.py index aedf32f60..804687bba 100644 --- a/test/regression/pysim.py +++ b/test/regression/pysim.py @@ -1,26 +1,46 @@ +import re 
+import os +import logging + from amaranth.sim import Passive, Settle -from amaranth.utils import log2_int +from amaranth.utils import exact_log2 +from amaranth import * -from .memory import * -from .common import SimulationBackend +from transactron.core import TransactionManagerKey -from ..common import SimpleTestCircuit, PysimSimulator +from .memory import * +from .common import SimulationBackend, SimulationExecutionResult + +from transactron.testing import ( + PysimSimulator, + TestGen, + profiler_process, + Profile, + make_logging_process, + parse_logging_level, +) +from transactron.utils.dependencies import DependencyContext, DependencyManager +from transactron.lib.metrics import HardwareMetricsManager from ..peripherals.test_wishbone import WishboneInterfaceWrapper from coreblocks.core import Core from coreblocks.params import GenParams from coreblocks.params.configurations import full_core_config -from coreblocks.peripherals.wishbone import WishboneBus +from coreblocks.peripherals.wishbone import WishboneSignature class PySimulation(SimulationBackend): - def __init__(self, verbose: bool, traces_file: Optional[str] = None): + def __init__(self, traces_file: Optional[str] = None): self.gp = GenParams(full_core_config) self.running = False self.cycle_cnt = 0 - self.verbose = verbose self.traces_file = traces_file + self.log_level = parse_logging_level(os.environ["__TRANSACTRON_LOG_LEVEL"]) + self.log_filter = os.environ["__TRANSACTRON_LOG_FILTER"] + + self.metrics_manager = HardwareMetricsManager() + def _wishbone_slave( self, mem_model: CoreMemoryModel, wb_ctrl: WishboneInterfaceWrapper, is_instr_bus: bool, delay: int = 0 ): @@ -33,23 +53,17 @@ def f(): word_width_bytes = self.gp.isa.xlen // 8 # Wishbone is addressing words, so we need to shift it a bit to get the real address. 
- addr = (yield wb_ctrl.wb.adr) << log2_int(word_width_bytes) + addr = (yield wb_ctrl.wb.adr) << exact_log2(word_width_bytes) sel = yield wb_ctrl.wb.sel dat_w = yield wb_ctrl.wb.dat_w resp_data = 0 - bus_name = "instr" if is_instr_bus else "data" - if (yield wb_ctrl.wb.we): - if self.verbose: - print(f"Wishbone '{bus_name}' bus write request: addr=0x{addr:x} data={dat_w:x} sel={sel:b}") resp = mem_model.write( WriteRequest(addr=addr, data=dat_w, byte_count=word_width_bytes, byte_sel=sel) ) else: - if self.verbose: - print(f"Wishbone '{bus_name}' bus read request: addr=0x{addr:x} sel={sel:b}") resp = mem_model.read( ReadRequest( addr=addr, @@ -60,9 +74,6 @@ def f(): ) resp_data = resp.data - if self.verbose: - print(f"Wishbone '{bus_name}' bus read response: data=0x{resp.data:x}") - ack = err = rty = 0 match resp.status: case ReplyStatus.OK: @@ -81,37 +92,90 @@ def f(): return f - def _waiter(self): + def _waiter(self, on_finish: Callable[[], TestGen[None]]): def f(): while self.running: self.cycle_cnt += 1 yield + yield from on_finish() + return f - async def run(self, mem_model: CoreMemoryModel, timeout_cycles: int = 5000) -> bool: - wb_instr_bus = WishboneBus(self.gp.wb_params) - wb_data_bus = WishboneBus(self.gp.wb_params) - core = Core(gen_params=self.gp, wb_instr_bus=wb_instr_bus, wb_data_bus=wb_data_bus) + def pretty_dump_metrics(self, metric_values: dict[str, dict[str, int]], filter_regexp: str = ".*"): + str = "=== Core metrics dump ===\n" - m = SimpleTestCircuit(core) + put_space_before = True + for metric_name in sorted(metric_values.keys()): + if not re.search(filter_regexp, metric_name): + continue - wb_instr_ctrl = WishboneInterfaceWrapper(wb_instr_bus) - wb_data_ctrl = WishboneInterfaceWrapper(wb_data_bus) + metric = self.metrics_manager.get_metrics()[metric_name] - self.running = True - self.cycle_cnt = 0 + if metric.description != "": + if not put_space_before: + str += "\n" + + str += f"# {metric.description}\n" + + for reg in metric.regs.values(): 
+ reg_value = metric_values[metric_name][reg.name] + + desc = f" # {reg.description} [reg width={reg.width}]" + str += f"{metric_name}/{reg.name} {reg_value}{desc}\n" + + put_space_before = False + if metric.description != "": + str += "\n" + put_space_before = True + + logging.info(str) + + async def run(self, mem_model: CoreMemoryModel, timeout_cycles: int = 5000) -> SimulationExecutionResult: + with DependencyContext(DependencyManager()): + wb_instr_bus = WishboneSignature(self.gp.wb_params).create() + wb_data_bus = WishboneSignature(self.gp.wb_params).create() + core = Core(gen_params=self.gp, wb_instr_bus=wb_instr_bus, wb_data_bus=wb_data_bus) + + wb_instr_ctrl = WishboneInterfaceWrapper(wb_instr_bus) + wb_data_ctrl = WishboneInterfaceWrapper(wb_data_bus) + + self.running = True + self.cycle_cnt = 0 + + sim = PysimSimulator(core, max_cycles=timeout_cycles, traces_file=self.traces_file) + sim.add_sync_process(self._wishbone_slave(mem_model, wb_instr_ctrl, is_instr_bus=True)) + sim.add_sync_process(self._wishbone_slave(mem_model, wb_data_ctrl, is_instr_bus=False)) + + def on_error(): + raise RuntimeError("Simulation finished due to an error") + + sim.add_sync_process(make_logging_process(self.log_level, self.log_filter, on_error)) + + profile = None + if "__TRANSACTRON_PROFILE" in os.environ: + transaction_manager = DependencyContext.get().get_dependency(TransactionManagerKey()) + profile = Profile() + sim.add_sync_process(profiler_process(transaction_manager, profile)) + + metric_values: dict[str, dict[str, int]] = {} + + def on_sim_finish(): + # Collect metric values before we finish the simulation + for metric_name, metric in self.metrics_manager.get_metrics().items(): + metric = self.metrics_manager.get_metrics()[metric_name] + metric_values[metric_name] = {} + for reg_name in metric.regs: + metric_values[metric_name][reg_name] = yield self.metrics_manager.get_register_value( + metric_name, reg_name + ) - sim = PysimSimulator(m, max_cycles=timeout_cycles, 
traces_file=self.traces_file) - sim.add_sync_process(self._wishbone_slave(mem_model, wb_instr_ctrl, is_instr_bus=True)) - sim.add_sync_process(self._wishbone_slave(mem_model, wb_data_ctrl, is_instr_bus=False)) - sim.add_sync_process(self._waiter()) - res = sim.run() + sim.add_sync_process(self._waiter(on_finish=on_sim_finish)) + success = sim.run() - if self.verbose: - print(f"Simulation finished in {self.cycle_cnt} cycles") + self.pretty_dump_metrics(metric_values) - return res + return SimulationExecutionResult(success, metric_values, profile) def stop(self): self.running = False diff --git a/test/regression/signature.py b/test/regression/signature.py index e741d3493..b35ffd9f2 100644 --- a/test/regression/signature.py +++ b/test/regression/signature.py @@ -46,7 +46,7 @@ async def run_test(sim_backend: SimulationBackend, test_path: str, signature_pat mem_segments.append(signature_ram) mem_model = CoreMemoryModel(mem_segments) - success = await sim_backend.run(mem_model, timeout_cycles=200000) + success = await sim_backend.run(mem_model, timeout_cycles=60000) if not success: raise RuntimeError(f"{test_path}: Simulation timed out") diff --git a/test/regression/test.py b/test/regression/test.py deleted file mode 100644 index cbe8067cd..000000000 --- a/test/regression/test.py +++ /dev/null @@ -1,51 +0,0 @@ -from glob import glob -from pathlib import Path - -from .memory import * -from .common import SimulationBackend - -test_dir = Path(__file__).parent.parent -riscv_tests_dir = test_dir.joinpath("external/riscv-tests") - -# disable write protection for specific tests with writes to .text section -exclude_write_protection = ["rv32uc-rvc"] - - -class MMIO(MemorySegment): - def __init__(self, on_finish: Callable[[], None]): - super().__init__(range(0x80000000, 0x80000000 + 4), SegmentFlags.READ | SegmentFlags.WRITE) - self.on_finish = on_finish - self.failed_test = 0 - - def read(self, req: ReadRequest) -> ReadReply: - return ReadReply() - - def write(self, req: 
WriteRequest) -> WriteReply: - self.failed_test = req.data - self.on_finish() - return WriteReply() - - -def get_all_test_names(): - return {name[5:] for name in glob("test-*", root_dir=riscv_tests_dir)} - - -async def run_test(sim_backend: SimulationBackend, test_name: str): - mmio = MMIO(lambda: sim_backend.stop()) - - mem_segments: list[MemorySegment] = [] - mem_segments += load_segments_from_elf( - str(riscv_tests_dir.joinpath("test-" + test_name)), - disable_write_protection=test_name in exclude_write_protection, - ) - mem_segments.append(mmio) - - mem_model = CoreMemoryModel(mem_segments) - - success = await sim_backend.run(mem_model, timeout_cycles=5000) - - if not success: - raise RuntimeError("Simulation timed out") - - if mmio.failed_test: - raise RuntimeError("Failing test: %d" % mmio.failed_test) diff --git a/test/regression/test_regression.py b/test/regression/test_regression.py new file mode 100644 index 000000000..88fc538f1 --- /dev/null +++ b/test/regression/test_regression.py @@ -0,0 +1,143 @@ +from .memory import * +from .common import SimulationBackend +from .conftest import riscv_tests_dir, profile_dir +from test.regression.pysim import PySimulation +import xml.etree.ElementTree as eT +import asyncio +from typing import Literal +import os +import pytest +import subprocess +import json +import tempfile +from filelock import FileLock + +REGRESSION_TESTS_PREFIX = "test.regression." 
+ + +# disable write protection for specific tests with writes to .text section +exclude_write_protection = ["rv32uc-rvc"] + + +class MMIO(MemorySegment): + def __init__(self, on_finish: Callable[[], None]): + super().__init__(range(0x80000000, 0x80000000 + 4), SegmentFlags.READ | SegmentFlags.WRITE) + self.on_finish = on_finish + self.failed_test = 0 + + def read(self, req: ReadRequest) -> ReadReply: + return ReadReply() + + def write(self, req: WriteRequest) -> WriteReply: + self.failed_test = req.data + self.on_finish() + return WriteReply() + + +async def run_test(sim_backend: SimulationBackend, test_name: str): + mmio = MMIO(lambda: sim_backend.stop()) + + mem_segments: list[MemorySegment] = [] + mem_segments += load_segments_from_elf( + str(riscv_tests_dir.joinpath("test-" + test_name)), + disable_write_protection=test_name in exclude_write_protection, + ) + mem_segments.append(mmio) + + mem_model = CoreMemoryModel(mem_segments) + + result = await sim_backend.run(mem_model, timeout_cycles=5000) + + if result.profile is not None: + os.makedirs(profile_dir, exist_ok=True) + result.profile.encode(f"{profile_dir}/test.regression.{test_name}.json") + + if not result.success: + raise RuntimeError("Simulation timed out") + + if mmio.failed_test: + raise RuntimeError("Failing test: %d" % mmio.failed_test) + + +def regression_body_with_cocotb(test_name: str, traces: bool): + arglist = ["make", "-C", "test/regression/cocotb", "-f", "test.Makefile"] + arglist += [f"TESTCASE={test_name}"] + + verilog_code = os.path.join(os.getcwd(), "core.v") + gen_info_path = f"{verilog_code}.json" + arglist += [f"_COREBLOCKS_GEN_INFO={gen_info_path}"] + arglist += [f"VERILOG_SOURCES={verilog_code}"] + tmp_result_file = tempfile.NamedTemporaryFile("r") + arglist += [f"COCOTB_RESULTS_FILE={tmp_result_file.name}"] + + if traces: + arglist += ["TRACES=1"] + + res = subprocess.run(arglist) + + assert res.returncode == 0 + + tree = eT.parse(tmp_result_file.name) + assert 
len(list(tree.iter("failure"))) == 0 + + +def regression_body_with_pysim(test_name: str, traces: bool): + traces_file = None + if traces: + traces_file = REGRESSION_TESTS_PREFIX + test_name + asyncio.run(run_test(PySimulation(traces_file=traces_file), test_name)) + + +@pytest.fixture(scope="session") +def verilate_model(worker_id, request: pytest.FixtureRequest): + """ + Fixture to prevent races on verilating the coreblocks model. It is run only in + distributed, cocotb, mode. It executes a 'SKIP' regression test which verilates the model. + """ + if request.session.config.getoption("coreblocks_backend") != "cocotb" or worker_id == "master": + # pytest expect yield on every path in fixture + yield None + return + + lock_path = "_coreblocks_regression.lock" + counter_path = "_coreblocks_regression.counter" + with FileLock(lock_path): + regression_body_with_cocotb("SKIP", False) + if os.path.exists(counter_path): + with open(counter_path, "r") as counter_file: + c = json.load(counter_file) + else: + c = 0 + with open(counter_path, "w") as counter_file: + json.dump(c + 1, counter_file) + yield + # Session teardown + deferred_remove = False + with FileLock(lock_path): + with open(counter_path, "r") as counter_file: + c = json.load(counter_file) + if c == 1: + deferred_remove = True + else: + with open(counter_path, "w") as counter_file: + json.dump(c - 1, counter_file) + if deferred_remove: + os.remove(lock_path) + os.remove(counter_path) + + +@pytest.fixture +def sim_backend(request: pytest.FixtureRequest): + return request.config.getoption("coreblocks_backend") + + +@pytest.fixture +def traces_enabled(request: pytest.FixtureRequest): + return request.config.getoption("coreblocks_traces") + + +def test_entrypoint(test_name: str, sim_backend: Literal["pysim", "cocotb"], traces_enabled: bool, verilate_model): + if sim_backend == "cocotb": + regression_body_with_cocotb(test_name, traces_enabled) + elif sim_backend == "pysim": + regression_body_with_pysim(test_name, 
traces_enabled) diff --git a/test/scheduler/test_rs_selection.py b/test/scheduler/test_rs_selection.py index 31b84b7bc..322323fb2 100644 --- a/test/scheduler/test_rs_selection.py +++ b/test/scheduler/test_rs_selection.py @@ -8,7 +8,7 @@ from coreblocks.params.configurations import test_core_config from coreblocks.scheduler.scheduler import RSSelection from transactron.lib import FIFO, Adapter, AdapterTrans -from test.common import TestCaseWithSimulator, TestbenchIO +from transactron.testing import TestCaseWithSimulator, TestbenchIO _rs1_optypes = {OpType.ARITHMETIC, OpType.COMPARE} _rs2_optypes = {OpType.LOGIC, OpType.COMPARE} diff --git a/test/scheduler/test_scheduler.py b/test/scheduler/test_scheduler.py index 563338324..a25979c49 100644 --- a/test/scheduler/test_scheduler.py +++ b/test/scheduler/test_scheduler.py @@ -19,7 +19,7 @@ from coreblocks.params.configurations import test_core_config from coreblocks.structs_common.rob import ReorderBuffer from coreblocks.utils.protocols import FuncBlock -from ..common import RecordIntDict, TestCaseWithSimulator, TestGen, TestbenchIO, def_method_mock +from transactron.testing import RecordIntDict, TestCaseWithSimulator, TestGen, TestbenchIO, def_method_mock class SchedulerTestCircuit(Elaboratable): @@ -383,4 +383,3 @@ def core_state_mock(): ) sim.add_sync_process(self.make_queue_process(io=self.m.free_rf_inp, input_queues=[self.free_regs_queue])) sim.add_sync_process(instr_input_process) - sim.add_sync_process(core_state_mock) diff --git a/test/scheduler/test_wakeup_select.py b/test/scheduler/test_wakeup_select.py index a24ae7114..ec0cb158c 100644 --- a/test/scheduler/test_wakeup_select.py +++ b/test/scheduler/test_wakeup_select.py @@ -1,5 +1,6 @@ from typing import Optional, cast from amaranth import * +from amaranth.lib.data import StructLayout from amaranth.sim import Settle from collections import deque @@ -14,7 +15,7 @@ from transactron.lib import Adapter from coreblocks.scheduler.wakeup_select import * -from 
..common import RecordIntDict, TestCaseWithSimulator, TestbenchIO +from transactron.testing import RecordIntDict, TestCaseWithSimulator, TestbenchIO class WakeupTestCircuit(Elaboratable): @@ -49,14 +50,14 @@ def setUp(self): random.seed(42) - def random_entry(self, layout) -> RecordIntDict: + def random_entry(self, layout: StructLayout) -> RecordIntDict: result = {} - for key, width_or_layout in layout: + for key, width_or_layout in layout.members.items(): if isinstance(width_or_layout, int): result[key] = random.randrange(width_or_layout) elif isclass(width_or_layout) and issubclass(width_or_layout, Enum): result[key] = random.choice(list(width_or_layout)) - else: + elif isinstance(width_or_layout, StructLayout): result[key] = self.random_entry(width_or_layout) return result diff --git a/test/stages/test_backend.py b/test/stages/test_backend.py index d7bdb27fa..2dc1695f8 100644 --- a/test/stages/test_backend.py +++ b/test/stages/test_backend.py @@ -8,7 +8,7 @@ from coreblocks.params.layouts import * from coreblocks.params import GenParams from coreblocks.params.configurations import test_core_config -from ..common import TestCaseWithSimulator, TestbenchIO +from transactron.testing import TestCaseWithSimulator, TestbenchIO class BackendTestCircuit(Elaboratable): diff --git a/test/stages/test_retirement.py b/test/stages/test_retirement.py index 315a44b36..1502eb0b9 100644 --- a/test/stages/test_retirement.py +++ b/test/stages/test_retirement.py @@ -7,7 +7,7 @@ from coreblocks.params import ROBLayouts, RFLayouts, GenParams, LSULayouts, SchedulerLayouts from coreblocks.params.configurations import test_core_config -from ..common import * +from transactron.testing import * from collections import deque import random @@ -51,10 +51,7 @@ def elaborate(self, platform): m.submodules.generic_csr = self.generic_csr = GenericCSRRegisters(self.gen_params) self.gen_params.get(DependencyManager).add_dependency(GenericCSRRegistersKey(), self.generic_csr) - 
m.submodules.mock_fetch_stall = self.mock_fetch_stall = TestbenchIO(Adapter()) - m.submodules.mock_fetch_continue = self.mock_fetch_continue = TestbenchIO( - Adapter(i=fetch_layouts.branch_verify) - ) + m.submodules.mock_fetch_continue = self.mock_fetch_continue = TestbenchIO(Adapter(i=fetch_layouts.resume)) m.submodules.mock_instr_decrement = self.mock_instr_decrement = TestbenchIO( Adapter(o=core_instr_counter_layouts.decrement) ) @@ -72,7 +69,6 @@ def elaborate(self, platform): exception_cause_get=self.mock_exception_cause.adapter.iface, exception_cause_clear=self.mock_exception_clear.adapter.iface, frat_rename=self.frat.rename, - fetch_stall=self.mock_fetch_stall.adapter.iface, fetch_continue=self.mock_fetch_continue.adapter.iface, instr_decrement=self.mock_instr_decrement.adapter.iface, trap_entry=self.mock_trap_entry.adapter.iface, @@ -158,10 +154,6 @@ def exception_cause_process(self): def exception_clear_process(self): pass - @def_method_mock(lambda self: self.retc.mock_fetch_stall) - def mock_fetch_stall(self): - pass - @def_method_mock(lambda self: self.retc.mock_instr_decrement) def instr_decrement_process(self): pass diff --git a/test/structs_common/test_csr.py b/test/structs_common/test_csr.py index 25e1619a7..4df317ba8 100644 --- a/test/structs_common/test_csr.py +++ b/test/structs_common/test_csr.py @@ -10,7 +10,7 @@ from transactron.utils.dependencies import DependencyManager from coreblocks.params.optypes import OpType -from ..common import * +from transactron.testing import * import random @@ -37,7 +37,7 @@ def elaborate(self, platform): self.gen_params.get(DependencyManager).add_dependency(ExceptionReportKey(), self.exception_report.adapter.iface) self.gen_params.get(DependencyManager).add_dependency(AsyncInterruptInsertSignalKey(), Signal()) - m.submodules.fetch_continue = self.fetch_continue = TestbenchIO(AdapterTrans(self.dut.fetch_continue)) + m.submodules.fetch_resume = self.fetch_resume = TestbenchIO(AdapterTrans(self.dut.fetch_resume)) 
self.csr = {} @@ -111,7 +111,7 @@ def generate_instruction(self): } def process_test(self): - yield from self.dut.fetch_continue.enable() + yield from self.dut.fetch_resume.enable() yield from self.dut.exception_report.enable() for _ in range(self.cycles): yield from self.random_wait_geom() @@ -132,7 +132,7 @@ def process_test(self): yield from self.random_wait_geom() res = yield from self.dut.accept.call() - self.assertTrue(self.dut.fetch_continue.done()) + self.assertTrue(self.dut.fetch_resume.done()) self.assertEqual(res["rp_dst"], op["exp"]["exp_read"]["rp_dst"]) if op["exp"]["exp_read"]["rp_dst"]: self.assertEqual(res["result"], op["exp"]["exp_read"]["result"]) @@ -158,7 +158,7 @@ def test_randomized(self): ] def process_exception_test(self): - yield from self.dut.fetch_continue.enable() + yield from self.dut.fetch_resume.enable() yield from self.dut.exception_report.enable() for csr in self.exception_csr_numbers: yield from self.random_wait_geom() diff --git a/test/structs_common/test_exception.py b/test/structs_common/test_exception.py index 43b6f4512..1988f5ad3 100644 --- a/test/structs_common/test_exception.py +++ b/test/structs_common/test_exception.py @@ -8,7 +8,7 @@ from transactron.lib import Adapter from transactron.utils import ModuleConnector -from ..common import * +from transactron.testing import * import random @@ -34,14 +34,20 @@ def test_randomized(self): self.cycles = 256 self.rob_idx_mock = TestbenchIO(Adapter(o=self.gen_params.get(ROBLayouts).get_indices)) - self.dut = SimpleTestCircuit(ExceptionCauseRegister(self.gen_params, self.rob_idx_mock.adapter.iface)) - m = ModuleConnector(self.dut, rob_idx_mock=self.rob_idx_mock) + self.fetch_stall_mock = TestbenchIO(Adapter()) + self.dut = SimpleTestCircuit( + ExceptionCauseRegister( + self.gen_params, self.rob_idx_mock.adapter.iface, self.fetch_stall_mock.adapter.iface + ) + ) + m = ModuleConnector(self.dut, rob_idx_mock=self.rob_idx_mock, fetch_stall_mock=self.fetch_stall_mock) self.rob_id = 0 
def process_test(): saved_entry = None + yield from self.fetch_stall_mock.enable() for _ in range(self.cycles): self.rob_id = random.randint(0, self.rob_max) @@ -57,6 +63,8 @@ def process_test(): yield from self.dut.report.call(report_arg) yield # additional FIFO delay + self.assertTrue((yield from self.fetch_stall_mock.done())) + new_state = yield from self.dut.get.call() self.assertDictEqual(new_state, expected | {"valid": 1}) # type: ignore diff --git a/test/structs_common/test_rat.py b/test/structs_common/test_rat.py index 39ca2b100..6fb281761 100644 --- a/test/structs_common/test_rat.py +++ b/test/structs_common/test_rat.py @@ -1,4 +1,4 @@ -from ..common import TestCaseWithSimulator, SimpleTestCircuit +from transactron.testing import TestCaseWithSimulator, SimpleTestCircuit from coreblocks.structs_common.rat import FRAT, RRAT from coreblocks.params import GenParams diff --git a/test/structs_common/test_reorder_buffer.py b/test/structs_common/test_reorder_buffer.py index 51abf2d4a..26731e635 100644 --- a/test/structs_common/test_reorder_buffer.py +++ b/test/structs_common/test_reorder_buffer.py @@ -1,6 +1,6 @@ from amaranth.sim import Passive, Settle -from ..common import TestCaseWithSimulator, SimpleTestCircuit +from transactron.testing import TestCaseWithSimulator, SimpleTestCircuit from coreblocks.structs_common.rob import ReorderBuffer from coreblocks.params import GenParams diff --git a/test/structs_common/test_rs.py b/test/structs_common/test_rs.py index aa5fe67ea..d5b9b4741 100644 --- a/test/structs_common/test_rs.py +++ b/test/structs_common/test_rs.py @@ -1,6 +1,6 @@ from amaranth.sim import Settle -from ..common import TestCaseWithSimulator, get_outputs, SimpleTestCircuit +from transactron.testing import TestCaseWithSimulator, get_outputs, SimpleTestCircuit from coreblocks.structs_common.rs import RS from coreblocks.params import * @@ -8,14 +8,11 @@ def create_check_list(rs_entries_bits: int, insert_list: list[dict]) -> list[dict]: - check_list = [ - 
{"rs_data": None, "rec_ready": 0, "rec_reserved": 0, "rec_full": 0} for _ in range(2**rs_entries_bits) - ] + check_list = [{"rs_data": None, "rec_reserved": 0, "rec_full": 0} for _ in range(2**rs_entries_bits)] for params in insert_list: entry_id = params["rs_entry_id"] check_list[entry_id]["rs_data"] = params["rs_data"] - check_list[entry_id]["rec_ready"] = 1 if params["rs_data"]["rp_s1"] | params["rs_data"]["rp_s2"] == 0 else 0 check_list[entry_id]["rec_full"] = 1 check_list[entry_id]["rec_reserved"] = 1 @@ -111,7 +108,6 @@ def simulation_process(self): # Check if RS state is as expected for expected, record in zip(self.check_list, self.m._dut.data): self.assertEqual((yield record.rec_full), expected["rec_full"]) - self.assertEqual((yield record.rec_ready), expected["rec_ready"]) self.assertEqual((yield record.rec_reserved), expected["rec_reserved"]) # Reserve the last entry, then select ready should be false @@ -174,12 +170,12 @@ def simulation_process(self): # Update second entry first SP, instruction should be not ready value_sp1 = 1010 - self.assertEqual((yield self.m._dut.data[1].rec_ready), 0) + self.assertEqual((yield self.m._dut.data_ready[1]), 0) yield from self.m.update.call(reg_id=2, reg_val=value_sp1) yield Settle() self.assertEqual((yield self.m._dut.data[1].rs_data.rp_s1), 0) self.assertEqual((yield self.m._dut.data[1].rs_data.s1_val), value_sp1) - self.assertEqual((yield self.m._dut.data[1].rec_ready), 0) + self.assertEqual((yield self.m._dut.data_ready[1]), 0) # Update second entry second SP, instruction should be ready value_sp2 = 2020 @@ -187,7 +183,7 @@ def simulation_process(self): yield Settle() self.assertEqual((yield self.m._dut.data[1].rs_data.rp_s2), 0) self.assertEqual((yield self.m._dut.data[1].rs_data.s2_val), value_sp2) - self.assertEqual((yield self.m._dut.data[1].rec_ready), 1) + self.assertEqual((yield self.m._dut.data_ready[1]), 1) # Insert new instruction to entries 0 and 1, check if update of multiple registers works reg_id = 4 
@@ -210,7 +206,7 @@ def simulation_process(self): for index in range(2): yield from self.m.insert.call(rs_entry_id=index, rs_data=data) yield Settle() - self.assertEqual((yield self.m._dut.data[index].rec_ready), 0) + self.assertEqual((yield self.m._dut.data_ready[index]), 0) yield from self.m.update.call(reg_id=reg_id, reg_val=value_spx) yield Settle() @@ -219,7 +215,7 @@ def simulation_process(self): self.assertEqual((yield self.m._dut.data[index].rs_data.rp_s2), 0) self.assertEqual((yield self.m._dut.data[index].rs_data.s1_val), value_spx) self.assertEqual((yield self.m._dut.data[index].rs_data.s2_val), value_spx) - self.assertEqual((yield self.m._dut.data[index].rec_ready), 1) + self.assertEqual((yield self.m._dut.data_ready[index]), 1) class TestRSMethodTake(TestCaseWithSimulator): @@ -305,8 +301,8 @@ def simulation_process(self): for index in range(2): yield from self.m.insert.call(rs_entry_id=index, rs_data=entry_data) yield Settle() - self.assertEqual((yield self.m._dut.data[index].rec_ready), 1) self.assertEqual((yield self.m._dut.take.ready), 1) + self.assertEqual((yield self.m._dut.data_ready[index]), 1) data = yield from self.m.take.call(rs_entry_id=0) for key in data: diff --git a/test/test_core.py b/test/test_core.py index 934a28449..a2cfd1d88 100644 --- a/test/test_core.py +++ b/test/test_core.py @@ -1,36 +1,25 @@ from amaranth import Elaboratable, Module +from amaranth.lib.wiring import connect from transactron.lib import AdapterTrans from transactron.utils import align_to_power_of_two, signed_to_int -from .common import TestCaseWithSimulator, TestbenchIO +from transactron.testing import TestCaseWithSimulator, TestbenchIO from coreblocks.core import Core from coreblocks.params import GenParams from coreblocks.params.configurations import CoreConfiguration, basic_core_config, full_core_config -from coreblocks.peripherals.wishbone import WishboneBus, WishboneMemorySlave +from coreblocks.peripherals.wishbone import WishboneSignature, 
WishboneMemorySlave -from typing import Optional, cast +from typing import Optional import random import subprocess import tempfile from parameterized import parameterized_class from riscvmodel.insn import ( InstructionADDI, - InstructionSLTI, - InstructionSLTIU, - InstructionXORI, - InstructionORI, - InstructionANDI, - InstructionSLLI, - InstructionSRLI, - InstructionSRAI, InstructionLUI, - InstructionJAL, ) -from riscvmodel.model import Model -from riscvmodel.isa import Instruction, InstructionRType, get_insns -from riscvmodel.variant import RV32I class CoreTestElaboratable(Elaboratable): @@ -45,8 +34,8 @@ def __init__(self, gen_params: GenParams, instr_mem: list[int] = [0], data_mem: def elaborate(self, platform): m = Module() - wb_instr_bus = WishboneBus(self.gen_params.wb_params) - wb_data_bus = WishboneBus(self.gen_params.wb_params) + wb_instr_bus = WishboneSignature(self.gen_params.wb_params).create() + wb_data_bus = WishboneSignature(self.gen_params.wb_params).create() # Align the size of the memory to the length of a cache line. 
instr_mem_depth = align_to_power_of_two(len(self.instr_mem), self.gen_params.icache_params.block_size_bits) @@ -58,69 +47,33 @@ def elaborate(self, platform): ) self.core = Core(gen_params=self.gen_params, wb_instr_bus=wb_instr_bus, wb_data_bus=wb_data_bus) self.io_in = TestbenchIO(AdapterTrans(self.core.fetch_continue.method)) - self.rf_write = TestbenchIO(AdapterTrans(self.core.RF.write)) self.interrupt = TestbenchIO(AdapterTrans(self.core.interrupt_controller.report_interrupt)) m.submodules.wb_mem_slave = self.wb_mem_slave m.submodules.wb_mem_slave_data = self.wb_mem_slave_data m.submodules.c = self.core m.submodules.io_in = self.io_in - m.submodules.rf_write = self.rf_write m.submodules.interrupt = self.interrupt - m.d.comb += wb_instr_bus.connect(self.wb_mem_slave.bus) - m.d.comb += wb_data_bus.connect(self.wb_mem_slave_data.bus) + connect(m, wb_instr_bus, self.wb_mem_slave.bus) + connect(m, wb_data_bus, self.wb_mem_slave_data.bus) return m -def gen_riscv_add_instr(dst, src1, src2): - return 0b0110011 | dst << 7 | src1 << 15 | src2 << 20 - - -def gen_riscv_lui_instr(dst, imm): - return 0b0110111 | dst << 7 | imm << 12 - - class TestCoreBase(TestCaseWithSimulator): gen_params: GenParams m: CoreTestElaboratable - def check_RAT_alloc(self, rat, expected_alloc_count=None): # noqa: N802 - allocated = [] - for i in range(self.m.gen_params.isa.reg_cnt): - allocated.append((yield rat.entries[i])) - filtered_zeros = list(filter(lambda x: x != 0, allocated)) - - # check if 0th register is set to 0 - self.assertEqual(allocated[0], 0) - # check if there are no duplicate physical registers allocated for two different architectural registers - self.assertEqual(len(filtered_zeros), len(set(filtered_zeros))) - # check if the expected number of allocated registers matches reality - if expected_alloc_count: - self.assertEqual(len(filtered_zeros), expected_alloc_count) - def get_phys_reg_rrat(self, reg_id): return (yield self.m.core.RRAT.entries[reg_id]) - def 
get_phys_reg_frat(self, reg_id): - return (yield self.m.core.FRAT.entries[reg_id]) - def get_arch_reg_val(self, reg_id): return (yield self.m.core.RF.entries[(yield from self.get_phys_reg_rrat(reg_id))].reg_val) - def get_phys_reg_val(self, reg_id): - return (yield self.m.core.RF.entries[reg_id].reg_val) - def push_instr(self, opcode): yield from self.m.io_in.call(instr=opcode) - def compare_core_states(self, sw_core): - for i in range(self.gen_params.isa.reg_cnt): - reg_val = sw_core.state.intreg.regs[i].value - unsigned_val = reg_val & 0xFFFFFFFF - self.assertEqual((yield from self.get_arch_reg_val(i)), unsigned_val) - def push_register_load_imm(self, reg_id, val): addi_imm = signed_to_int(val & 0xFFF, 12) lui_imm = (val & 0xFFFFF000) >> 12 @@ -132,123 +85,6 @@ def push_register_load_imm(self, reg_id, val): yield from self.push_instr(InstructionADDI(reg_id, reg_id, addi_imm).encode()) -class TestCoreSimple(TestCoreBase): - def simple_test(self): - # this test first provokes allocation of physical registers, - # then sets the values in those registers, and finally runs - # an actual computation. 
- - # The test sets values in the reg file by hand - - # provoking allocation of physical register - for i in range(self.m.gen_params.isa.reg_cnt - 1): - yield from self.push_instr(gen_riscv_add_instr(i + 1, 0, 0)) - - # waiting for the retirement rat to be set - for i in range(100): - yield - - # checking if all registers have been allocated - yield from self.check_RAT_alloc(self.m.core.FRAT, 31) - yield from self.check_RAT_alloc(self.m.core.RRAT, 31) - - # writing values to physical registers - yield from self.m.rf_write.call(reg_id=(yield from self.get_phys_reg_rrat(1)), reg_val=1) - yield from self.m.rf_write.call(reg_id=(yield from self.get_phys_reg_rrat(2)), reg_val=2) - yield from self.m.rf_write.call(reg_id=(yield from self.get_phys_reg_rrat(3)), reg_val=3) - - # waiting for potential conflicts on rf_write - for i in range(10): - yield - - self.assertEqual((yield from self.get_arch_reg_val(1)), 1) - self.assertEqual((yield from self.get_arch_reg_val(2)), 2) - self.assertEqual((yield from self.get_arch_reg_val(3)), 3) - - # issuing actual instructions for the test - yield from self.push_instr(gen_riscv_add_instr(4, 1, 2)) - yield from self.push_instr(gen_riscv_add_instr(4, 3, 4)) - yield from self.push_instr(gen_riscv_lui_instr(5, 1)) - - # waiting for the instructions to be processed - for i in range(50): - yield - - self.assertEqual((yield from self.get_arch_reg_val(1)), 1) - self.assertEqual((yield from self.get_arch_reg_val(2)), 2) - self.assertEqual((yield from self.get_arch_reg_val(3)), 3) - # 1 + 2 + 3 = 6 - self.assertEqual((yield from self.get_arch_reg_val(4)), 6) - self.assertEqual((yield from self.get_arch_reg_val(5)), 1 << 12) - - def test_simple(self): - self.gen_params = GenParams(basic_core_config) - m = CoreTestElaboratable(self.gen_params) - self.m = m - - with self.run_simulation(m) as sim: - sim.add_sync_process(self.simple_test) - - -class TestCoreRandomized(TestCoreBase): - def randomized_input(self): - infloop_addr = 
(len(self.instr_mem) - 1) * 4 - # wait for PC to go past all instruction - while (yield self.m.core.fetch.pc) != infloop_addr: - yield - - # finish calculations - yield from self.tick(50) - - yield from self.compare_core_states(self.software_core) - - def test_randomized(self): - self.gen_params = GenParams(basic_core_config) - self.instr_count = 300 - random.seed(42) - - # cast is there to avoid stubbing riscvmodel - instructions = cast(list[type[Instruction]], get_insns(cls=InstructionRType, variant=RV32I)) - instructions += [ - InstructionADDI, - InstructionSLTI, - InstructionSLTIU, - InstructionXORI, - InstructionORI, - InstructionANDI, - InstructionSLLI, - InstructionSRLI, - InstructionSRAI, - InstructionLUI, - ] - - # allocate some random values for registers - init_instr_list = list( - InstructionADDI(rd=i, rs1=0, imm=random.randint(-(2**11), 2**11 - 1)) - for i in range(self.gen_params.isa.reg_cnt) - ) - - # generate random instruction stream - instr_list = list(random.choice(instructions)() for _ in range(self.instr_count)) - for instr in instr_list: - instr.randomize(RV32I) - - self.software_core = Model(RV32I) - self.software_core.execute(init_instr_list) - self.software_core.execute(instr_list) - - # We add JAL instruction at the end to effectively create a infinite loop at the end of the program. 
- all_instr = init_instr_list + instr_list + [InstructionJAL(rd=0, imm=0)] - - self.instr_mem = list(map(lambda x: x.encode(), all_instr)) - - m = CoreTestElaboratable(self.gen_params, instr_mem=self.instr_mem) - self.m = m - - with self.run_simulation(m) as sim: - sim.add_sync_process(self.randomized_input) - - class TestCoreAsmSourceBase(TestCoreBase): base_dir: str = "test/asm/" @@ -298,12 +134,12 @@ def prepare_source(self, filename): @parameterized_class( ("name", "source_file", "cycle_count", "expected_regvals", "configuration"), [ - ("fibonacci", "fibonacci.asm", 1200, {2: 2971215073}, basic_core_config), - ("fibonacci_mem", "fibonacci_mem.asm", 610, {3: 55}, basic_core_config), + ("fibonacci", "fibonacci.asm", 1200 * 2, {2: 2971215073}, basic_core_config), + ("fibonacci_mem", "fibonacci_mem.asm", 610 * 2, {3: 55}, basic_core_config), ("csr", "csr.asm", 200, {1: 1, 2: 4}, full_core_config), - ("exception", "exception.asm", 200, {1: 1, 2: 2}, basic_core_config), - ("exception_mem", "exception_mem.asm", 200, {1: 1, 2: 2}, basic_core_config), - ("exception_handler", "exception_handler.asm", 1500, {2: 987, 11: 0xAAAA, 15: 16}, full_core_config), + ("exception", "exception.asm", 200 * 2, {1: 1, 2: 2}, basic_core_config), + ("exception_mem", "exception_mem.asm", 200 * 2, {1: 1, 2: 2}, basic_core_config), + ("exception_handler", "exception_handler.asm", int(1500 * 2.2), {2: 987, 11: 0xAAAA, 15: 16}, full_core_config), ], ) class TestCoreBasicAsm(TestCoreAsmSourceBase): @@ -333,11 +169,11 @@ def test_asm_source(self): @parameterized_class( ("source_file", "main_cycle_count", "start_regvals", "expected_regvals", "lo", "hi"), [ - ("interrupt.asm", 400, {4: 2971215073, 8: 29}, {2: 2971215073, 7: 29, 31: 0xDE}, 300, 500), - ("interrupt.asm", 700, {4: 24157817, 8: 199}, {2: 24157817, 7: 199, 31: 0xDE}, 100, 200), - ("interrupt.asm", 600, {4: 89, 8: 843}, {2: 89, 7: 843, 31: 0xDE}, 30, 50), + ("interrupt.asm", 400 * 4, {4: 2971215073, 8: 29}, {2: 2971215073, 7: 29, 31: 
0xDE}, 300, 500), + ("interrupt.asm", 700 * 4, {4: 24157817, 8: 199}, {2: 24157817, 7: 199, 31: 0xDE}, 100, 200), + ("interrupt.asm", 600 * 4, {4: 89, 8: 843}, {2: 89, 7: 843, 31: 0xDE}, 30, 50), # interrupts are only inserted on branches, we always have some forward progression. 15 for trigger variantion. - ("interrupt.asm", 80, {4: 21, 8: 9349}, {2: 21, 7: 9349, 31: 0xDE}, 0, 15), + ("interrupt.asm", 80 * 4, {4: 21, 8: 9349}, {2: 21, 7: 9349, 31: 0xDE}, 0, 15), ], ) class TestCoreInterrupt(TestCoreAsmSourceBase): diff --git a/test/transactions/test_adapter.py b/test/transactions/test_adapter.py index 48728cb02..7c4849657 100644 --- a/test/transactions/test_adapter.py +++ b/test/transactions/test_adapter.py @@ -3,7 +3,7 @@ from transactron import Method, def_method, TModule -from ..common import TestCaseWithSimulator, data_layout, SimpleTestCircuit, ModuleConnector +from transactron.testing import TestCaseWithSimulator, data_layout, SimpleTestCircuit, ModuleConnector class Echo(Elaboratable): diff --git a/test/transactions/test_assign.py b/test/transactions/test_assign.py index 0659a01e4..73d5b28f9 100644 --- a/test/transactions/test_assign.py +++ b/test/transactions/test_assign.py @@ -1,9 +1,9 @@ from typing import Callable from amaranth import * from amaranth.lib import data -from amaranth.hdl.ast import ArrayProxy, Slice +from amaranth.hdl._ast import ArrayProxy, Slice -from transactron.utils._typing import LayoutLike +from transactron.utils._typing import MethodLayout from transactron.utils import AssignType, assign from transactron.utils.assign import AssignArg, AssignFields @@ -24,7 +24,7 @@ def mkproxy(layout): - arr = Array([Record(layout) for _ in range(4)]) + arr = Array([Signal(reclayout2datalayout(layout)) for _ in range(4)]) sig = Signal(2) return arr[sig] @@ -40,7 +40,6 @@ def mkstruct(layout): params_mk = [ - ("rec", Record), ("proxy", mkproxy), ("struct", mkstruct), ] @@ -55,15 +54,15 @@ def mkstruct(layout): ], ) class TestAssign(TestCase): - # 
constructs `assign` arguments (records, proxies, dicts) which have an "inner" and "outer" part - # parameterized with a Record-like constructor and a layout of the inner part - build: Callable[[Callable[[LayoutLike], AssignArg], LayoutLike], AssignArg] + # constructs `assign` arguments (views, proxies, dicts) which have an "inner" and "outer" part + # parameterized with a constructor and a layout of the inner part + build: Callable[[Callable[[MethodLayout], AssignArg], MethodLayout], AssignArg] # constructs field specifications for `assign`, takes field specifications for the inner part wrap: Callable[[AssignFields], AssignFields] # extracts the inner part of the structure - extr: Callable[[AssignArg], Record | ArrayProxy] - # Record-like constructor, takes a record layout - mk: Callable[[LayoutLike], AssignArg] + extr: Callable[[AssignArg], ArrayProxy] + # constructor, takes a layout + mk: Callable[[MethodLayout], AssignArg] def test_rhs_exception(self): with self.assertRaises(KeyError): @@ -100,7 +99,7 @@ def test_wrong_bits(self): ("list", layout_ab, layout_ab, ["a", "a"]), ] ) - def test_assign_a(self, name, layout1: LayoutLike, layout2: LayoutLike, atype: AssignType): + def test_assign_a(self, name, layout1: MethodLayout, layout2: MethodLayout, atype: AssignType): lhs = self.build(self.mk, layout1) rhs = self.build(self.mk, layout2) alist = list(assign(lhs, rhs, fields=self.wrap(atype))) diff --git a/test/transactions/test_branches.py b/test/transactions/test_branches.py index ba2a4545a..f66b954b7 100644 --- a/test/transactions/test_branches.py +++ b/test/transactions/test_branches.py @@ -11,7 +11,8 @@ def_method, ) from unittest import TestCase -from ..common import TestCaseWithSimulator +from transactron.testing import TestCaseWithSimulator +from transactron.utils.dependencies import DependencyContext class TestExclusivePath(TestCase): @@ -87,9 +88,9 @@ def test_conflict_removal(self): circ = ExclusiveConflictRemovalCircuit() tm = TransactionManager() - dut 
= TransactionModule(circ, tm) + dut = TransactionModule(circ, DependencyContext.get(), tm) - with self.run_simulation(dut): + with self.run_simulation(dut, add_transaction_module=False): pass cgr, _, _ = tm._conflict_graph(MethodMap(tm.transactions)) diff --git a/test/transactions/test_methods.py b/test/transactions/test_methods.py index 838199e0a..13ea4cdb0 100644 --- a/test/transactions/test_methods.py +++ b/test/transactions/test_methods.py @@ -2,9 +2,10 @@ from amaranth import * from amaranth.sim import * -from ..common import TestCaseWithSimulator, TestbenchIO, data_layout +from transactron.testing import TestCaseWithSimulator, TestbenchIO, data_layout from transactron import * +from transactron.utils import MethodStruct from transactron.lib import * from parameterized import parameterized @@ -41,7 +42,7 @@ def definition(arg): self.do_test_definition(definition) def test_fields_valid2(self): - rec = Record([("bar1", 4), ("bar2", 6)]) + rec = Signal(from_method_layout([("bar1", 4), ("bar2", 6)])) def definition(arg): return {"foo1": Signal(3), "foo2": rec} @@ -55,7 +56,7 @@ def definition(arg): self.do_test_definition(definition) def test_fields_valid4(self): - def definition(arg: Record): + def definition(arg: MethodStruct): return arg self.do_test_definition(definition) @@ -345,11 +346,10 @@ def _(arg): class TestQuadrupleCircuits(TestCaseWithSimulator): - def test(self): - self.work(QuadrupleCircuit(Quadruple())) - self.work(QuadrupleCircuit(Quadruple2())) + @parameterized.expand([(Quadruple,), (Quadruple2,)]) + def test(self, quadruple): + circ = QuadrupleCircuit(quadruple()) - def work(self, circ): def process(): for n in range(1 << (WIDTH - 2)): out = yield from circ.tb.call(data=n) @@ -560,8 +560,8 @@ def __init__(self, n=2, ready_function=lambda arg: arg.data != 3): self.method = Method(i=data_layout(n)) self.ready_function = ready_function - self.in_t1 = Record(data_layout(n)) - self.in_t2 = Record(data_layout(n)) + self.in_t1 = 
Signal(from_method_layout(data_layout(n))) + self.in_t2 = Signal(from_method_layout(data_layout(n))) self.ready = Signal() self.req_t1 = Signal() self.req_t2 = Signal() diff --git a/test/transactions/test_simultaneous.py b/test/transactions/test_simultaneous.py index 9a55bce36..71fe3e9eb 100644 --- a/test/transactions/test_simultaneous.py +++ b/test/transactions/test_simultaneous.py @@ -5,7 +5,7 @@ from transactron.utils import ModuleConnector -from ..common import SimpleTestCircuit, TestCaseWithSimulator, TestbenchIO, def_method_mock +from transactron.testing import SimpleTestCircuit, TestCaseWithSimulator, TestbenchIO, def_method_mock from transactron import * from transactron.lib import Adapter, Connect, ConnectTrans @@ -108,7 +108,7 @@ def elaborate(self, platform): m = TModule() with Transaction().body(m, request=self.request): - self.target(m, self.data ^ self.source(m)) + self.target(m, self.data ^ self.source(m).data) return m diff --git a/test/transactions/test_transaction_lib.py b/test/transactions/test_transaction_lib.py index 8c05cea66..058557f22 100644 --- a/test/transactions/test_transaction_lib.py +++ b/test/transactions/test_transaction_lib.py @@ -13,9 +13,9 @@ from transactron.core import RecordDict from transactron.lib import * from coreblocks.utils import * -from transactron.utils._typing import LayoutLike, ModuleLike +from transactron.utils._typing import ModuleLike, MethodStruct from transactron.utils import ModuleConnector -from ..common import ( +from transactron.testing import ( SimpleTestCircuit, TestCaseWithSimulator, TestbenchIO, @@ -25,7 +25,7 @@ class RevConnect(Elaboratable): - def __init__(self, layout: LayoutLike): + def __init__(self, layout: MethodLayout): self.connect = Connect(rev_layout=layout) self.read = self.connect.write self.write = self.connect.read @@ -235,7 +235,7 @@ def process(): class ManyToOneConnectTransTestCircuit(Elaboratable): - def __init__(self, count: int, lay: LayoutLike): + def __init__(self, count: int, 
lay: MethodLayout): self.count = count self.lay = lay self.inputs = [] @@ -353,20 +353,20 @@ def elaborate(self, platform): layout = data_layout(self.iosize) - def itransform_rec(m: ModuleLike, v: Record) -> Record: - s = Record.like(v) + def itransform_rec(m: ModuleLike, v: MethodStruct) -> MethodStruct: + s = Signal.like(v) m.d.comb += s.data.eq(v.data + 1) return s - def otransform_rec(m: ModuleLike, v: Record) -> Record: - s = Record.like(v) + def otransform_rec(m: ModuleLike, v: MethodStruct) -> MethodStruct: + s = Signal.like(v) m.d.comb += s.data.eq(v.data - 1) return s - def itransform_dict(_, v: Record) -> RecordDict: + def itransform_dict(_, v: MethodStruct) -> RecordDict: return {"data": v.data + 1} - def otransform_dict(_, v: Record) -> RecordDict: + def otransform_dict(_, v: MethodStruct) -> RecordDict: return {"data": v.data - 1} if self.use_dicts: @@ -383,11 +383,11 @@ def otransform_dict(_, v: Record) -> RecordDict: ometh = Method(i=layout, o=layout) @def_method(m, imeth) - def _(arg: Record): + def _(arg: MethodStruct): return itransform(m, arg) @def_method(m, ometh) - def _(arg: Record): + def _(arg: MethodStruct): return otransform(m, arg) trans = MethodMap(self.target.adapter.iface, i_transform=(layout, imeth), o_transform=(layout, ometh)) @@ -471,7 +471,7 @@ def test_method_filter(self, use_condition): self.initialize() def condition(_, v): - return v[0] + return v.data[0] self.tc = SimpleTestCircuit(MethodFilter(self.target.adapter.iface, condition, use_condition=use_condition)) m = ModuleConnector(test_circuit=self.tc, target=self.target, cmeth=self.cmeth) @@ -501,7 +501,7 @@ def elaborate(self, platform): combiner = None if self.add_combiner: - combiner = (layout, lambda _, vs: {"data": sum(vs)}) + combiner = (layout, lambda _, vs: {"data": sum(x.data for x in vs)}) product = MethodProduct(methods, combiner) diff --git a/test/transactions/test_transactions.py b/test/transactions/test_transactions.py index ee131868e..c73a1642b 100644 --- 
a/test/transactions/test_transactions.py +++ b/test/transactions/test_transactions.py @@ -9,7 +9,7 @@ from typing import Iterable, Callable from parameterized import parameterized, parameterized_class -from ..common import TestCaseWithSimulator, TestbenchIO, data_layout +from transactron.testing import TestCaseWithSimulator, TestbenchIO, data_layout from transactron import * from transactron.lib import Adapter, AdapterTrans @@ -21,6 +21,7 @@ trivial_roundrobin_cc_scheduler, eager_deterministic_cc_scheduler, ) +from transactron.utils.dependencies import DependencyContext class TestNames(TestCase): @@ -110,7 +111,7 @@ def __init__(self, scheduler): def elaborate(self, platform): m = TModule() - tm = TransactionModule(m, TransactionManager(self.scheduler)) + tm = TransactionModule(m, DependencyContext.get(), TransactionManager(self.scheduler)) adapter = Adapter(i=data_layout(32), o=data_layout(32)) m.submodules.out = self.out = TestbenchIO(adapter) m.submodules.in1 = self.in1 = TestbenchIO(AdapterTrans(adapter.iface)) @@ -317,7 +318,7 @@ def elaborate(self, platform): m = TModule() tm = TransactionModule(m) - with tm.transaction_context(): + with tm.context(): with Transaction().body(m, request=self.r1): m.d.comb += self.t1.eq(1) with Transaction().body(m, request=self.r2): @@ -342,7 +343,7 @@ def _(): def _(): m.d.comb += self.t2.eq(1) - with tm.transaction_context(): + with tm.context(): with Transaction().body(m): method1(m) @@ -389,7 +390,7 @@ def elaborate(self, platform): def _(): pass - with tm.transaction_context(): + with tm.context(): with (t1 := Transaction()).body(m, request=self.r1): method(m) m.d.comb += self.t1.eq(1) diff --git a/test/transactron/__init__.py b/test/transactron/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/test/transactron/test_metrics.py b/test/transactron/test_metrics.py new file mode 100644 index 000000000..12acdfd27 --- /dev/null +++ b/test/transactron/test_metrics.py @@ -0,0 +1,399 @@ +import json +import 
random +import queue +from parameterized import parameterized_class + +from amaranth import * +from amaranth.sim import Passive, Settle + +from transactron.lib.metrics import * +from transactron import * +from transactron.testing import TestCaseWithSimulator, data_layout, SimpleTestCircuit +from transactron.utils.dependencies import DependencyContext + + +class CounterInMethodCircuit(Elaboratable): + def __init__(self): + self.method = Method() + self.counter = HwCounter("in_method") + + def elaborate(self, platform): + m = TModule() + + m.submodules.counter = self.counter + + @def_method(m, self.method) + def _(): + self.counter.incr(m) + + return m + + +class CounterWithConditionInMethodCircuit(Elaboratable): + def __init__(self): + self.method = Method(i=[("cond", 1)]) + self.counter = HwCounter("with_condition_in_method") + + def elaborate(self, platform): + m = TModule() + + m.submodules.counter = self.counter + + @def_method(m, self.method) + def _(cond): + self.counter.incr(m, cond=cond) + + return m + + +class CounterWithoutMethodCircuit(Elaboratable): + def __init__(self): + self.cond = Signal() + self.counter = HwCounter("with_condition_without_method") + + def elaborate(self, platform): + m = TModule() + + m.submodules.counter = self.counter + + with Transaction().body(m): + self.counter.incr(m, cond=self.cond) + + return m + + +class TestHwCounter(TestCaseWithSimulator): + def setUp(self) -> None: + random.seed(42) + + def test_counter_in_method(self): + m = SimpleTestCircuit(CounterInMethodCircuit()) + DependencyContext.get().add_dependency(HwMetricsEnabledKey(), True) + + def test_process(): + called_cnt = 0 + for _ in range(200): + call_now = random.randint(0, 1) == 0 + + if call_now: + yield from m.method.call() + else: + yield + + # Note that it takes one cycle to update the register value, so here + # we are comparing the "previous" values. 
+ self.assertEqual(called_cnt, (yield m._dut.counter.count.value)) + + if call_now: + called_cnt += 1 + + with self.run_simulation(m) as sim: + sim.add_sync_process(test_process) + + def test_counter_with_condition_in_method(self): + m = SimpleTestCircuit(CounterWithConditionInMethodCircuit()) + DependencyContext.get().add_dependency(HwMetricsEnabledKey(), True) + + def test_process(): + called_cnt = 0 + for _ in range(200): + call_now = random.randint(0, 1) == 0 + condition = random.randint(0, 1) + + if call_now: + yield from m.method.call(cond=condition) + else: + yield + + # Note that it takes one cycle to update the register value, so here + # we are comparing the "previous" values. + self.assertEqual(called_cnt, (yield m._dut.counter.count.value)) + + if call_now and condition == 1: + called_cnt += 1 + + with self.run_simulation(m) as sim: + sim.add_sync_process(test_process) + + def test_counter_with_condition_without_method(self): + m = CounterWithoutMethodCircuit() + DependencyContext.get().add_dependency(HwMetricsEnabledKey(), True) + + def test_process(): + called_cnt = 0 + for _ in range(200): + condition = random.randint(0, 1) + + yield m.cond.eq(condition) + yield + + # Note that it takes one cycle to update the register value, so here + # we are comparing the "previous" values. 
+ self.assertEqual(called_cnt, (yield m.counter.count.value)) + + if condition == 1: + called_cnt += 1 + + with self.run_simulation(m) as sim: + sim.add_sync_process(test_process) + + +class ExpHistogramCircuit(Elaboratable): + def __init__(self, bucket_cnt: int, sample_width: int): + self.sample_width = sample_width + + self.method = Method(i=data_layout(32)) + self.histogram = HwExpHistogram("histogram", bucket_count=bucket_cnt, sample_width=sample_width) + + def elaborate(self, platform): + m = TModule() + + m.submodules.histogram = self.histogram + + @def_method(m, self.method) + def _(data): + self.histogram.add(m, data[0 : self.sample_width]) + + return m + + +@parameterized_class( + ("bucket_count", "sample_width"), + [ + (5, 5), # last bucket is [8, inf), max sample=31 + (8, 5), # last bucket is [64, inf), max sample=31 + (8, 6), # last bucket is [64, inf), max sample=63 + (8, 20), # last bucket is [64, inf), max sample=big + ], +) +class TestHwHistogram(TestCaseWithSimulator): + bucket_count: int + sample_width: int + + def test_histogram(self): + random.seed(42) + + m = SimpleTestCircuit(ExpHistogramCircuit(bucket_cnt=self.bucket_count, sample_width=self.sample_width)) + DependencyContext.get().add_dependency(HwMetricsEnabledKey(), True) + + max_sample_value = 2**self.sample_width - 1 + + def test_process(): + min = max_sample_value + 1 + max = 0 + sum = 0 + count = 0 + + buckets = [0] * self.bucket_count + + for _ in range(500): + if random.randrange(3) == 0: + value = random.randint(0, max_sample_value) + if value < min: + min = value + if value > max: + max = value + sum += value + count += 1 + for i in range(self.bucket_count): + if value < 2**i or i == self.bucket_count - 1: + buckets[i] += 1 + break + yield from m.method.call(data=value) + yield + else: + yield + + histogram = m._dut.histogram + # Skip the assertion if the min is still uninitialized + if min != max_sample_value + 1: + self.assertEqual(min, (yield histogram.min.value)) + + 
self.assertEqual(max, (yield histogram.max.value)) + self.assertEqual(sum, (yield histogram.sum.value)) + self.assertEqual(count, (yield histogram.count.value)) + + total_count = 0 + for i in range(self.bucket_count): + bucket_value = yield histogram.buckets[i].value + total_count += bucket_value + self.assertEqual(buckets[i], bucket_value) + + # Sanity check if all buckets sum up to the total count value + self.assertEqual(total_count, (yield histogram.count.value)) + + with self.run_simulation(m) as sim: + sim.add_sync_process(test_process) + + +@parameterized_class( + ("slots_number", "expected_consumer_wait"), + [ + (2, 5), + (2, 10), + (5, 10), + (10, 1), + (10, 10), + (5, 5), + ], +) +class TestLatencyMeasurer(TestCaseWithSimulator): + slots_number: int + expected_consumer_wait: float + + def test_latency_measurer(self): + random.seed(42) + + m = SimpleTestCircuit(LatencyMeasurer("latency", slots_number=self.slots_number, max_latency=300)) + DependencyContext.get().add_dependency(HwMetricsEnabledKey(), True) + + latencies: list[int] = [] + + event_queue = queue.Queue() + + time = 0 + + def ticker(): + nonlocal time + + yield Passive() + + while True: + yield + time += 1 + + finish = False + + def producer(): + nonlocal finish + + for _ in range(200): + yield from m._start.call() + + # Make sure that the time is updated first. + yield Settle() + event_queue.put(time) + yield from self.random_wait_geom(0.8) + + finish = True + + def consumer(): + while not finish: + yield from m._stop.call() + + # Make sure that the time is updated first. 
+ yield Settle() + latencies.append(time - event_queue.get()) + + yield from self.random_wait_geom(1.0 / self.expected_consumer_wait) + + self.assertEqual(min(latencies), (yield m._dut.histogram.min.value)) + self.assertEqual(max(latencies), (yield m._dut.histogram.max.value)) + self.assertEqual(sum(latencies), (yield m._dut.histogram.sum.value)) + self.assertEqual(len(latencies), (yield m._dut.histogram.count.value)) + + for i in range(m._dut.histogram.bucket_count): + bucket_start = 0 if i == 0 else 2 ** (i - 1) + bucket_end = 1e10 if i == m._dut.histogram.bucket_count - 1 else 2**i + + count = sum(1 for x in latencies if bucket_start <= x < bucket_end) + self.assertEqual(count, (yield m._dut.histogram.buckets[i].value)) + + with self.run_simulation(m) as sim: + sim.add_sync_process(producer) + sim.add_sync_process(consumer) + sim.add_sync_process(ticker) + + +class MetricManagerTestCircuit(Elaboratable): + def __init__(self): + self.incr_counters = Method(i=[("counter1", 1), ("counter2", 1), ("counter3", 1)]) + + self.counter1 = HwCounter("foo.counter1", "this is the description") + self.counter2 = HwCounter("bar.baz.counter2") + self.counter3 = HwCounter("bar.baz.counter3", "yet another description") + + def elaborate(self, platform): + m = TModule() + + m.submodules += [self.counter1, self.counter2, self.counter3] + + @def_method(m, self.incr_counters) + def _(counter1, counter2, counter3): + self.counter1.incr(m, cond=counter1) + self.counter2.incr(m, cond=counter2) + self.counter3.incr(m, cond=counter3) + + return m + + +class TestMetricsManager(TestCaseWithSimulator): + def test_metrics_metadata(self): + # We need to initialize the circuit to make sure that metrics are registered + # in the dependency manager. + m = MetricManagerTestCircuit() + metrics_manager = HardwareMetricsManager() + + # Run the simulation so Amaranth doesn't scream that we have unused elaboratables. 
+ with self.run_simulation(m): + pass + + self.assertEqual( + metrics_manager.get_metrics()["foo.counter1"].to_json(), # type: ignore + json.dumps( + { + "fully_qualified_name": "foo.counter1", + "description": "this is the description", + "regs": {"count": {"name": "count", "description": "the value of the counter", "width": 32}}, + } + ), + ) + + self.assertEqual( + metrics_manager.get_metrics()["bar.baz.counter2"].to_json(), # type: ignore + json.dumps( + { + "fully_qualified_name": "bar.baz.counter2", + "description": "", + "regs": {"count": {"name": "count", "description": "the value of the counter", "width": 32}}, + } + ), + ) + + self.assertEqual( + metrics_manager.get_metrics()["bar.baz.counter3"].to_json(), # type: ignore + json.dumps( + { + "fully_qualified_name": "bar.baz.counter3", + "description": "yet another description", + "regs": {"count": {"name": "count", "description": "the value of the counter", "width": 32}}, + } + ), + ) + + def test_returned_reg_values(self): + random.seed(42) + + m = SimpleTestCircuit(MetricManagerTestCircuit()) + metrics_manager = HardwareMetricsManager() + + DependencyContext.get().add_dependency(HwMetricsEnabledKey(), True) + + def test_process(): + counters = [0] * 3 + for _ in range(200): + rand = [random.randint(0, 1) for _ in range(3)] + + yield from m.incr_counters.call(counter1=rand[0], counter2=rand[1], counter3=rand[2]) + yield + + for i in range(3): + if rand[i] == 1: + counters[i] += 1 + + self.assertEqual(counters[0], (yield metrics_manager.get_register_value("foo.counter1", "count"))) + self.assertEqual(counters[1], (yield metrics_manager.get_register_value("bar.baz.counter2", "count"))) + self.assertEqual(counters[2], (yield metrics_manager.get_register_value("bar.baz.counter3", "count"))) + + with self.run_simulation(m) as sim: + sim.add_sync_process(test_process) diff --git a/test/transactron/testing/__init__.py b/test/transactron/testing/__init__.py new file mode 100644 index 000000000..e69de29bb diff 
--git a/test/common/_test/test_infrastructure.py b/test/transactron/testing/test_infrastructure.py similarity index 95% rename from test/common/_test/test_infrastructure.py rename to test/transactron/testing/test_infrastructure.py index ecf1c84d9..4e219bca9 100644 --- a/test/common/_test/test_infrastructure.py +++ b/test/transactron/testing/test_infrastructure.py @@ -1,5 +1,5 @@ from amaranth import * -from test.common import * +from transactron.testing import * class EmptyCircuit(Elaboratable): diff --git a/test/transactron/testing/test_log.py b/test/transactron/testing/test_log.py new file mode 100644 index 000000000..69f537fdd --- /dev/null +++ b/test/transactron/testing/test_log.py @@ -0,0 +1,124 @@ +from amaranth import * + +from transactron import * +from transactron.testing import TestCaseWithSimulator +from transactron.lib import logging + +LOGGER_NAME = "test_logger" + +log = logging.HardwareLogger(LOGGER_NAME) + + +class LogTest(Elaboratable): + def __init__(self): + self.input = Signal(range(100)) + self.counter = Signal(range(200)) + + def elaborate(self, platform): + m = TModule() + + with m.If(self.input == 42): + log.warning(m, True, "Log triggered under Amaranth If value+3=0x{:x}", self.input + 3) + + log.warning(m, self.input[0] == 0, "Input is even! input={}, counter={}", self.input, self.counter) + + m.d.sync += self.counter.eq(self.counter + 1) + + return m + + +class ErrorLogTest(Elaboratable): + def __init__(self): + self.input = Signal() + self.output = Signal() + + def elaborate(self, platform): + m = TModule() + + m.d.comb += self.output.eq(self.input & ~self.input) + + log.error( + m, + self.input != self.output, + "Input is different than output! 
input=0x{:x} output=0x{:x}", + self.input, + self.output, + ) + + return m + + +class AssertionTest(Elaboratable): + def __init__(self): + self.input = Signal() + self.output = Signal() + + def elaborate(self, platform): + m = TModule() + + m.d.comb += self.output.eq(self.input & ~self.input) + + log.assertion(m, self.input == self.output, "Output differs") + + return m + + +class TestLog(TestCaseWithSimulator): + def test_log(self): + m = LogTest() + + def proc(): + for i in range(50): + yield + yield m.input.eq(i) + + with self.assertLogs(LOGGER_NAME) as logs: + with self.run_simulation(m) as sim: + sim.add_sync_process(proc) + + self.assertIn( + "WARNING:test_logger:test/transactron/testing/test_log.py:21] Log triggered under Amaranth If value+3=0x2d", + logs.output, + ) + for i in range(0, 50, 2): + expected_msg = ( + "WARNING:test_logger:test/transactron/testing/test_log.py:23] " + + f"Input is even! input={i}, counter={i + 2}" + ) + self.assertIn( + expected_msg, + logs.output, + ) + + def test_error_log(self): + m = ErrorLogTest() + + def proc(): + yield + yield m.input.eq(1) + + with self.assertLogs(LOGGER_NAME) as logs: + with self.assertRaises(AssertionError): + with self.run_simulation(m) as sim: + sim.add_sync_process(proc) + + extected_out = ( + "ERROR:test_logger:test/transactron/testing/test_log.py:40] " + + "Input is different than output! 
input=0x1 output=0x0" + ) + self.assertIn(extected_out, logs.output) + + def test_assertion(self): + m = AssertionTest() + + def proc(): + yield + yield m.input.eq(1) + + with self.assertLogs(LOGGER_NAME) as logs: + with self.assertRaises(AssertionError): + with self.run_simulation(m) as sim: + sim.add_sync_process(proc) + + extected_out = "ERROR:test_logger:test/transactron/testing/test_log.py:61] Output differs" + self.assertIn(extected_out, logs.output) diff --git a/test/utils/test_fifo.py b/test/utils/test_fifo.py index 934a58475..30ca10c1c 100644 --- a/test/utils/test_fifo.py +++ b/test/utils/test_fifo.py @@ -3,7 +3,7 @@ from transactron.lib import AdapterTrans, BasicFifo -from test.common import TestCaseWithSimulator, TestbenchIO, data_layout +from transactron.testing import TestCaseWithSimulator, TestbenchIO, data_layout from collections import deque from parameterized import parameterized_class import random diff --git a/test/utils/test_onehotswitch.py b/test/utils/test_onehotswitch.py index 12d8373b9..4097b76d1 100644 --- a/test/utils/test_onehotswitch.py +++ b/test/utils/test_onehotswitch.py @@ -3,7 +3,7 @@ from transactron.utils import OneHotSwitch -from test.common import TestCaseWithSimulator +from transactron.testing import TestCaseWithSimulator from parameterized import parameterized diff --git a/test/utils/test_utils.py b/test/utils/test_utils.py index a26567003..09f3d710c 100644 --- a/test/utils/test_utils.py +++ b/test/utils/test_utils.py @@ -2,7 +2,7 @@ import random from amaranth import * -from test.common import * +from transactron.testing import * from transactron.utils import ( align_to_power_of_two, align_down_to_power_of_two, diff --git a/transactron/__init__.py b/transactron/__init__.py index ce1898da3..de27375ac 100644 --- a/transactron/__init__.py +++ b/transactron/__init__.py @@ -3,7 +3,6 @@ __all__ = [ "TModule", "TransactionManager", - "TransactionContext", "TransactionModule", "Transaction", "Method", diff --git a/transactron/core.py 
b/transactron/core.py index 0fd6fc29e..e85627437 100644 --- a/transactron/core.py +++ b/transactron/core.py @@ -12,16 +12,18 @@ Tuple, TypeVar, Protocol, + Self, runtime_checkable, ) from os import environ from graphlib import TopologicalSorter -from typing_extensions import Self from dataclasses import dataclass, replace from amaranth import * from amaranth import tracer from itertools import count, chain, filterfalse, product -from amaranth.hdl.dsl import FSM, _ModuleBuilderDomain +from amaranth.hdl._dsl import FSM + +from transactron.utils.assign import AssignArg from .graph import Owned, OwnershipGraph, Direction from transactron.utils import * @@ -32,7 +34,7 @@ "Priority", "TModule", "TransactionManager", - "TransactionContext", + "TransactionManagerKey", "TransactionModule", "Transaction", "Method", @@ -518,9 +520,9 @@ def visual_graph(self, fragment): graph = OwnershipGraph(fragment) method_map = MethodMap(self.transactions) for method, transactions in method_map.transactions_by_method.items(): - if len(method.data_in) > len(method.data_out): + if len(method.data_in.as_value()) > len(method.data_out.as_value()): direction = Direction.IN - elif len(method.data_in) < len(method.data_out): + elif method.data_in.shape().size < method.data_out.shape().size: direction = Direction.OUT else: direction = Direction.INOUT @@ -551,25 +553,9 @@ def method_debug(m: Method): } -class TransactionContext: - stack: list[TransactionManager] = [] - - def __init__(self, manager: TransactionManager): - self.manager = manager - - def __enter__(self): - self.stack.append(self.manager) - return self - - def __exit__(self, exc_type, exc_value, exc_tb): - top = self.stack.pop() - assert self.manager is top - - @classmethod - def get(cls) -> TransactionManager: - if not cls.stack: - raise RuntimeError("TransactionContext stack is empty") - return cls.stack[-1] +@dataclass(frozen=True) +class TransactionManagerKey(SimpleKey[TransactionManager]): + pass class 
TransactionModule(Elaboratable): @@ -578,56 +564,85 @@ class TransactionModule(Elaboratable): which adds support for transactions. It creates a `TransactionManager` which will handle transaction scheduling and can be used in definition of `Method`\\s and `Transaction`\\s. + The `TransactionManager` is stored in a `DependencyManager`. """ - def __init__(self, elaboratable: HasElaborate, manager: Optional[TransactionManager] = None): + def __init__( + self, + elaboratable: HasElaborate, + dependency_manager: Optional[DependencyManager] = None, + transaction_manager: Optional[TransactionManager] = None, + ): """ Parameters ---------- elaboratable: HasElaborate - The `Elaboratable` which should be wrapped to add support for - transactions and methods. + The `Elaboratable` which should be wrapped to add support for + transactions and methods. + dependency_manager: DependencyManager, optional + The `DependencyManager` to use inside the transaction module. + If omitted, a new one is created. + transaction_manager: TransactionManager, optional + The `TransactionManager` to use inside the transaction module. + If omitted, a new one is created. 
""" - if manager is None: - manager = TransactionManager() - self.transactionManager = manager + if transaction_manager is None: + transaction_manager = TransactionManager() + if dependency_manager is None: + dependency_manager = DependencyManager() + self.manager = dependency_manager + self.manager.add_dependency(TransactionManagerKey(), transaction_manager) self.elaboratable = elaboratable - def transaction_context(self) -> TransactionContext: - return TransactionContext(self.transactionManager) + def context(self) -> DependencyContext: + return DependencyContext(self.manager) def elaborate(self, platform): - with silence_mustuse(self.transactionManager): - with self.transaction_context(): + with silence_mustuse(self.manager.get_dependency(TransactionManagerKey())): + with self.context(): elaboratable = Fragment.get(self.elaboratable, platform) m = Module() m.submodules.main_module = elaboratable - m.submodules.transactionManager = self.transactionManager + m.submodules.transactionManager = self.manager.get_dependency(TransactionManagerKey()) return m +class _AvoidingModuleBuilderDomain: + """ + A wrapper over Amaranth domain to abstract away internal Amaranth implementation. + It is needed to allow for correctness check in `__setattr__` which uses `isinstance`. 
+ """ + + def __init__(self, amaranth_module_domain): + self._domain = amaranth_module_domain + + def __iadd__(self, assigns: StatementLike) -> Self: + self._domain.__iadd__(assigns) + return self + + class _AvoidingModuleBuilderDomains: _m: "TModule" def __init__(self, m: "TModule"): object.__setattr__(self, "_m", m) - def __getattr__(self, name: str) -> _ModuleBuilderDomain: + def __getattr__(self, name: str) -> _AvoidingModuleBuilderDomain: if name == "av_comb": - return self._m.avoiding_module.d["comb"] + return _AvoidingModuleBuilderDomain(self._m.avoiding_module.d["comb"]) elif name == "top_comb": - return self._m.top_module.d["comb"] + return _AvoidingModuleBuilderDomain(self._m.top_module.d["comb"]) else: - return self._m.main_module.d[name] + return _AvoidingModuleBuilderDomain(self._m.main_module.d[name]) - def __getitem__(self, name: str) -> _ModuleBuilderDomain: + def __getitem__(self, name: str) -> _AvoidingModuleBuilderDomain: return self.__getattr__(name) def __setattr__(self, name: str, value): - if not isinstance(value, _ModuleBuilderDomain): + if not isinstance(value, _AvoidingModuleBuilderDomain): raise AttributeError(f"Cannot assign 'd.{name}' attribute; did you mean 'd.{name} +='?") def __setitem__(self, name: str, value): @@ -880,8 +895,8 @@ class TransactionBase(Owned, Protocol): defined: bool = False name: str src_loc: SrcLoc - method_uses: dict["Method", tuple[Record, Signal]] - method_calls: defaultdict["Method", list[tuple[CtrlPath, Record, ValueLike]]] + method_uses: dict["Method", tuple[MethodStruct, Signal]] + method_calls: defaultdict["Method", list[tuple[CtrlPath, MethodStruct, ValueLike]]] relations: list[RelationBase] simultaneous_list: list[TransactionOrMethod] independent_list: list[TransactionOrMethod] @@ -1091,7 +1106,7 @@ def __init__( self.owner, owner_name = get_caller_class_name(default="$transaction") self.name = name or tracer.get_var_name(depth=2, default=owner_name) if manager is None: - manager = 
TransactionContext.get() + manager = DependencyContext.get().get_dependency(TransactionManagerKey()) manager.add_transaction(self) self.request = Signal(name=self.owned_name + "_request") self.runnable = Signal(name=self.owned_name + "_runnable") @@ -1145,9 +1160,10 @@ class Method(TransactionBase): behavior.) Calling a `Method` always takes a single clock cycle. Data is combinationally transferred between to and from `Method`\\s - using Amaranth `Record`\\s. The transfer can take place in both directions - at the same time: from the called `Method` to the caller (`data_out`) - and from the caller to the called `Method` (`data_in`). + using Amaranth structures (`View` with a `StructLayout`). The transfer + can take place in both directions at the same time: from the called + `Method` to the caller (`data_out`) and from the caller to the called + `Method` (`data_in`). A module which defines a `Method` should use `body` or `def_method` to describe the method's effect on the module state. @@ -1162,10 +1178,10 @@ class Method(TransactionBase): run: Signal, out Signals that the method is called in the current cycle by some `Transaction`. Defined by the `TransactionManager`. - data_in: Record, out + data_in: MethodStruct, out Contains the data passed to the `Method` by the caller (a `Transaction` or another `Method`). - data_out: Record, in + data_out: MethodStruct, in Contains the data passed from the `Method` to the caller (a `Transaction` or another `Method`). Typically defined by calling `body`. @@ -1187,12 +1203,10 @@ def __init__( name: str or None Name hint for this `Method`. If `None` (default) the name is inferred from the variable name this `Method` is assigned to. - i: record layout - The format of `data_in`. - An `int` corresponds to a `Record` with a single `data` field. - o: record layout + i: method layout The format of `data_in`. - An `int` corresponds to a `Record` with a single `data` field. + o: method layout + The format of `data_out`. 
nonexclusive: bool If true, the method is non-exclusive: it can be called by multiple transactions in the same clock cycle. If such a situation happens, @@ -1211,13 +1225,21 @@ def __init__( self.name = name or tracer.get_var_name(depth=2, default=owner_name) self.ready = Signal(name=self.owned_name + "_ready") self.run = Signal(name=self.owned_name + "_run") - self.data_in = Record(i) - self.data_out = Record(o) + self.data_in: MethodStruct = Signal(from_method_layout(i)) + self.data_out: MethodStruct = Signal(from_method_layout(o)) self.nonexclusive = nonexclusive self.single_caller = single_caller self.validate_arguments: Optional[Callable[..., ValueLike]] = None if nonexclusive: - assert len(self.data_in) == 0 + assert len(self.data_in.as_value()) == 0 + + @property + def layout_in(self): + return self.data_in.shape() + + @property + def layout_out(self): + return self.data_out.shape() @staticmethod def like(other: "Method", *, name: Optional[str] = None, src_loc: int | SrcLoc = 0) -> "Method": @@ -1241,7 +1263,7 @@ def like(other: "Method", *, name: Optional[str] = None, src_loc: int | SrcLoc = Method The freshly constructed `Method`. """ - return Method(name=name, i=other.data_in.layout, o=other.data_out.layout, src_loc=get_src_loc(src_loc)) + return Method(name=name, i=other.layout_in, o=other.layout_out, src_loc=get_src_loc(src_loc)) def proxy(self, m: TModule, method: "Method"): """Define as a proxy for another method. @@ -1269,7 +1291,7 @@ def body( ready: ValueLike = C(1), out: ValueLike = C(0, 0), validate_arguments: Optional[Callable[..., ValueLike]] = None, - ) -> Iterator[Record]: + ) -> Iterator[MethodStruct]: """Define method body The `body` context manager can be used to define the actions @@ -1288,7 +1310,7 @@ def body( Signal to indicate if the method is ready to be run. By default it is `Const(1)`, so the method is always ready. Assigned combinationially to the `ready` attribute. 
- out : Record, in + out : Value, in Data generated by the `Method`, which will be passed to the caller (a `Transaction` or another `Method`). Assigned combinationally to the `data_out` attribute. @@ -1327,14 +1349,14 @@ def body( with m.AvoidedIf(self.run): yield self.data_in - def _validate_arguments(self, arg_rec: Record) -> ValueLike: + def _validate_arguments(self, arg_rec: MethodStruct) -> ValueLike: if self.validate_arguments is not None: return self.ready & method_def_helper(self, self.validate_arguments, arg_rec) return self.ready def __call__( - self, m: TModule, arg: Optional[RecordDict] = None, enable: ValueLike = C(1), /, **kwargs: RecordDict - ) -> Record: + self, m: TModule, arg: Optional[AssignArg] = None, enable: ValueLike = C(1), /, **kwargs: AssignArg + ) -> MethodStruct: """Call a method. Methods can only be called from transaction and method bodies. @@ -1348,7 +1370,7 @@ def __call__( m : TModule Module in which operations on signals should be executed, arg : Value or dict of Values - Call argument. Can be passed as a `Record` of the method's + Call argument. Can be passed as a `View` of the method's input layout or as a dictionary. Alternative syntax uses keyword arguments. enable : Value @@ -1361,7 +1383,7 @@ def __call__( Returns ------- - data_out : Record + data_out : MethodStruct The result of the method call. Examples @@ -1370,7 +1392,7 @@ def __call__( .. code-block:: python m = Module() - with Transaction.body(m): + with Transaction().body(m): ret = my_sum_method(m, arg1=2, arg2=3) Alternative syntax: @@ -1378,10 +1400,10 @@ def __call__( .. highlight:: python .. 
code-block:: python - with Transaction.body(m): + with Transaction().body(m): ret = my_sum_method(m, {"arg1": 2, "arg2": 3}) """ - arg_rec = Record.like(self.data_in) + arg_rec = Signal.like(self.data_in) if arg is not None and kwargs: raise ValueError(f"Method '{self.name}' call with both keyword arguments and legacy record argument") @@ -1399,7 +1421,7 @@ def __call__( caller.method_calls[self].append((m.ctrl_path, arg_rec, enable_sig)) if self not in caller.method_uses: - arg_rec_use = Record.like(self.data_in) + arg_rec_use = Signal(self.layout_in) arg_rec_enable_sig = Signal() caller.method_uses[self] = (arg_rec_use, arg_rec_enable_sig) @@ -1427,9 +1449,9 @@ def def_method( The decorated function should take keyword arguments corresponding to the fields of the method's input layout. The `**kwargs` syntax is supported. Alternatively, it can take one argument named `arg`, which will be a - record with input signals. + structure with input signals. - The returned value can be either a record with the method's output layout + The returned value can be either a structure with the method's output layout or a dictionary of outputs. Parameters @@ -1469,7 +1491,7 @@ def _(arg1, arg2): def _(**args): return args["arg1"] + args["arg2"] - Alternative syntax (arg record): + Alternative syntax (arg structure): .. highlight:: python .. 
code-block:: python @@ -1479,8 +1501,8 @@ def _(arg): return {"res": arg.arg1 + arg.arg2} """ - def decorator(func: Callable[..., Optional[RecordDict]]): - out = Record.like(method.data_out) + def decorator(func: Callable[..., Optional[AssignArg]]): + out = Signal(method.layout_out) ret_out = None with method.body(m, ready=ready, out=out, validate_arguments=validate_arguments) as arg: diff --git a/transactron/graph.py b/transactron/graph.py index 4cd51d067..024e9bb0b 100644 --- a/transactron/graph.py +++ b/transactron/graph.py @@ -6,7 +6,7 @@ from collections import defaultdict from typing import Literal, Optional, Protocol -from amaranth.hdl.ir import Elaboratable, Fragment +from amaranth import Elaboratable, Fragment from .tracing import TracingFragment diff --git a/transactron/lib/__init__.py b/transactron/lib/__init__.py index c814b5e93..f6dd3ef0a 100644 --- a/transactron/lib/__init__.py +++ b/transactron/lib/__init__.py @@ -6,3 +6,4 @@ from .reqres import * # noqa: F401 from .storage import * # noqa: F401 from .simultaneous import * # noqa: F401 +from .metrics import * # noqa: F401 diff --git a/transactron/lib/adapters.py b/transactron/lib/adapters.py index 6f99e8983..ed7f2640f 100644 --- a/transactron/lib/adapters.py +++ b/transactron/lib/adapters.py @@ -1,6 +1,6 @@ from amaranth import * -from ..utils import SrcLoc, get_src_loc +from ..utils import SrcLoc, get_src_loc, MethodStruct from ..core import * from ..core import SignalBundle from ..utils._typing import type_self_kwargs_as @@ -13,8 +13,8 @@ class AdapterBase(Elaboratable): - data_in: Record - data_out: Record + data_in: MethodStruct + data_out: MethodStruct def __init__(self, iface: Method): self.iface = iface @@ -56,8 +56,8 @@ def __init__(self, iface: Method, *, src_loc: int | SrcLoc = 0): """ super().__init__(iface) self.src_loc = get_src_loc(src_loc) - self.data_in = Record.like(iface.data_in) - self.data_out = Record.like(iface.data_out) + self.data_in = Signal.like(iface.data_in) + 
self.data_out = Signal.like(iface.data_out) def elaborate(self, platform): m = TModule() @@ -105,8 +105,8 @@ def __init__(self, **kwargs): kwargs["src_loc"] = get_src_loc(kwargs.setdefault("src_loc", 0)) super().__init__(Method(**kwargs)) - self.data_in = Record.like(self.iface.data_out) - self.data_out = Record.like(self.iface.data_in) + self.data_in = Signal.like(self.iface.data_out) + self.data_out = Signal.like(self.iface.data_in) def elaborate(self, platform): m = TModule() diff --git a/transactron/lib/buttons.py b/transactron/lib/buttons.py index ec3c796ce..59bf081b5 100644 --- a/transactron/lib/buttons.py +++ b/transactron/lib/buttons.py @@ -1,4 +1,6 @@ from amaranth import * + +from transactron.utils.transactron_helpers import from_method_layout from ..core import * from ..utils import SrcLoc, get_src_loc @@ -17,10 +19,10 @@ class ClickIn(Elaboratable): ---------- get: Method The method for retrieving data from the input. Accepts an empty - argument, returns a `Record`. + argument, returns a structure. btn: Signal, in The button input. - dat: Record, in + dat: MethodStruct, in The data input. """ @@ -28,7 +30,7 @@ def __init__(self, layout: MethodLayout, src_loc: int | SrcLoc = 0): """ Parameters ---------- - layout: record layout + layout: method layout The data format for the input. src_loc: int | SrcLoc How many stack frames deep the source location is taken from. @@ -37,7 +39,7 @@ def __init__(self, layout: MethodLayout, src_loc: int | SrcLoc = 0): src_loc = get_src_loc(src_loc) self.get = Method(o=layout, src_loc=src_loc) self.btn = Signal() - self.dat = Record(layout) + self.dat = Signal(from_method_layout(layout)) def elaborate(self, platform): m = TModule() @@ -73,11 +75,11 @@ class ClickOut(Elaboratable): Attributes ---------- put: Method - The method for retrieving data from the input. Accepts a `Record`, + The method for retrieving data from the input. Accepts a structure, returns empty result. btn: Signal, in The button input. 
- dat: Record, out + dat: MethodStruct, out The data output. """ @@ -85,7 +87,7 @@ def __init__(self, layout: MethodLayout, *, src_loc: int | SrcLoc = 0): """ Parameters ---------- - layout: record layout + layout: method layout The data format for the output. src_loc: int | SrcLoc How many stack frames deep the source location is taken from. @@ -94,7 +96,7 @@ def __init__(self, layout: MethodLayout, *, src_loc: int | SrcLoc = 0): src_loc = get_src_loc(src_loc) self.put = Method(i=layout, src_loc=src_loc) self.btn = Signal() - self.dat = Record(layout) + self.dat = Signal(from_method_layout(layout)) def elaborate(self, platform): m = TModule() diff --git a/transactron/lib/connectors.py b/transactron/lib/connectors.py index 81918c2a6..b9a6eb204 100644 --- a/transactron/lib/connectors.py +++ b/transactron/lib/connectors.py @@ -1,5 +1,7 @@ from amaranth import * import amaranth.lib.fifo + +from transactron.utils.transactron_helpers import from_method_layout from ..core import * from ..utils import SrcLoc, get_src_loc @@ -24,9 +26,9 @@ class FIFO(Elaboratable): Attributes ---------- read: Method - The read method. Accepts an empty argument, returns a `Record`. + The read method. Accepts an empty argument, returns a structure. write: Method - The write method. Accepts a `Record`, returns empty result. + The write method. Accepts a structure, returns empty result. """ def __init__( @@ -35,8 +37,8 @@ def __init__( """ Parameters ---------- - layout: record layout - The format of records stored in the FIFO. + layout: method layout + The format of structures stored in the FIFO. depth: int Size of the FIFO. fifoType: Elaboratable @@ -46,7 +48,7 @@ def __init__( How many stack frames deep the source location is taken from. Alternatively, the source location to use instead of the default. 
""" - self.width = len(Record(layout)) + self.width = from_method_layout(layout).size self.depth = depth self.fifoType = fifo_type @@ -59,8 +61,6 @@ def elaborate(self, platform): m.submodules.fifo = fifo = self.fifoType(width=self.width, depth=self.depth) - assert fifo.fwft # the read method requires FWFT behavior - @def_method(m, self.write, ready=fifo.w_rdy) def _(arg): m.d.comb += fifo.w_en.eq(1) @@ -92,17 +92,17 @@ class Forwarder(Elaboratable): Attributes ---------- read: Method - The read method. Accepts an empty argument, returns a `Record`. + The read method. Accepts an empty argument, returns a structure. write: Method - The write method. Accepts a `Record`, returns empty result. + The write method. Accepts a structure, returns empty result. """ def __init__(self, layout: MethodLayout, *, src_loc: int | SrcLoc = 0): """ Parameters ---------- - layout: record layout - The format of records forwarded. + layout: method layout + The format of structures forwarded. src_loc: int | SrcLoc How many stack frames deep the source location is taken from. Alternatively, the source location to use instead of the default. 
@@ -111,7 +111,7 @@ def __init__(self, layout: MethodLayout, *, src_loc: int | SrcLoc = 0): self.read = Method(o=layout, src_loc=src_loc) self.write = Method(i=layout, src_loc=src_loc) self.clear = Method(src_loc=src_loc) - self.head = Record.like(self.read.data_out) + self.head = Signal.like(self.read.data_out) self.clear.add_conflict(self.read, Priority.LEFT) self.clear.add_conflict(self.write, Priority.LEFT) @@ -119,9 +119,9 @@ def __init__(self, layout: MethodLayout, *, src_loc: int | SrcLoc = 0): def elaborate(self, platform): m = TModule() - reg = Record.like(self.read.data_out) + reg = Signal.like(self.read.data_out) reg_valid = Signal() - read_value = Record.like(self.read.data_out) + read_value = Signal.like(self.read.data_out) m.d.comb += self.head.eq(read_value) self.write.schedule_before(self.read) # to avoid combinational loops @@ -159,21 +159,21 @@ class Connect(Elaboratable): Attributes ---------- read: Method - The read method. Accepts a (possibly empty) `Record`, returns - a `Record`. + The read method. Accepts a (possibly empty) structure, returns + a structure. write: Method - The write method. Accepts a `Record`, returns a (possibly empty) - `Record`. + The write method. Accepts a structure, returns a (possibly empty) + structure. """ def __init__(self, layout: MethodLayout = (), rev_layout: MethodLayout = (), *, src_loc: int | SrcLoc = 0): """ Parameters ---------- - layout: record layout - The format of records forwarded. - rev_layout: record layout - The format of records forwarded in the reverse direction. + layout: method layout + The format of structures forwarded. + rev_layout: method layout + The format of structures forwarded in the reverse direction. src_loc: int | SrcLoc How many stack frames deep the source location is taken from. Alternatively, the source location to use instead of the default. 
@@ -185,8 +185,8 @@ def __init__(self, layout: MethodLayout = (), rev_layout: MethodLayout = (), *, def elaborate(self, platform): m = TModule() - read_value = Record.like(self.read.data_out) - rev_read_value = Record.like(self.write.data_out) + read_value = Signal.like(self.read.data_out) + rev_read_value = Signal.like(self.write.data_out) self.write.simultaneous(self.read) @@ -232,8 +232,8 @@ def elaborate(self, platform): m = TModule() with Transaction(src_loc=self.src_loc).body(m): - data1 = Record.like(self.method1.data_out) - data2 = Record.like(self.method2.data_out) + data1 = Signal.like(self.method1.data_out) + data2 = Signal.like(self.method2.data_out) m.d.top_comb += data1.eq(self.method1(m, data2)) m.d.top_comb += data2.eq(self.method2(m, data1)) diff --git a/transactron/lib/fifo.py b/transactron/lib/fifo.py index 3c74dad04..92ac0f7bb 100644 --- a/transactron/lib/fifo.py +++ b/transactron/lib/fifo.py @@ -1,8 +1,8 @@ from amaranth import * from transactron import Method, def_method, Priority, TModule -from transactron.utils._typing import ValueLike, MethodLayout, SrcLoc +from transactron.utils._typing import ValueLike, MethodLayout, SrcLoc, MethodStruct from transactron.utils.amaranth_ext import mod_incr -from transactron.utils.transactron_helpers import get_src_loc +from transactron.utils.transactron_helpers import from_method_layout, get_src_loc class BasicFifo(Elaboratable): @@ -11,10 +11,10 @@ class BasicFifo(Elaboratable): Attributes ---------- read: Method - Reads from the FIFO. Accepts an empty argument, returns a `Record`. + Reads from the FIFO. Accepts an empty argument, returns a structure. Ready only if the FIFO is not empty. write: Method - Writes to the FIFO. Accepts a `Record`, returns empty result. + Writes to the FIFO. Accepts a structure, returns empty result. Ready only if the FIFO is not full. clear: Method Clears the FIFO entries. Has priority over `read` and `write` methods. 
@@ -26,9 +26,8 @@ def __init__(self, layout: MethodLayout, depth: int, *, src_loc: int | SrcLoc = """ Parameters ---------- - layout: record layout + layout: method layout Layout of data stored in the FIFO. - If integer is given, Record with field `data` and width of this paramter is used as internal layout. depth: int Size of the FIFO. src_loc: int | SrcLoc @@ -36,14 +35,14 @@ def __init__(self, layout: MethodLayout, depth: int, *, src_loc: int | SrcLoc = Alternatively, the source location to use instead of the default. """ self.layout = layout - self.width = len(Record(self.layout)) + self.width = from_method_layout(self.layout).size self.depth = depth src_loc = get_src_loc(src_loc) self.read = Method(o=self.layout, src_loc=src_loc) self.write = Method(i=self.layout, src_loc=src_loc) self.clear = Method(src_loc=src_loc) - self.head = Record(self.layout) + self.head = Signal(from_method_layout(layout)) self.buff = Memory(width=self.width, depth=self.depth) @@ -82,7 +81,7 @@ def elaborate(self, platform): m.d.comb += self.head.eq(self.buff_rdport.data) @def_method(m, self.write, ready=self.write_ready) - def _(arg: Record) -> None: + def _(arg: MethodStruct) -> None: m.d.top_comb += self.buff_wrport.addr.eq(self.write_idx) m.d.top_comb += self.buff_wrport.data.eq(arg) m.d.comb += self.buff_wrport.en.eq(1) diff --git a/transactron/lib/logging.py b/transactron/lib/logging.py new file mode 100644 index 000000000..7eb06deb1 --- /dev/null +++ b/transactron/lib/logging.py @@ -0,0 +1,229 @@ +import os +import re +import operator +import logging +from functools import reduce +from dataclasses import dataclass, field +from dataclasses_json import dataclass_json +from typing import TypeAlias + +from amaranth import * +from amaranth.tracer import get_src_loc + +from transactron.utils import SrcLoc +from transactron.utils._typing import ModuleLike, ValueLike +from transactron.utils.dependencies import DependencyContext, ListKey + +LogLevel: TypeAlias = int + + +@dataclass_json 
+@dataclass +class LogRecordInfo: + """Simulator-backend-agnostic information about a log record that can + be serialized and used outside the Amaranth context. + + Attributes + ---------- + logger_name: str + + level: LogLevel + The severity level of the log. + format_str: str + The template of the message. Should follow PEP 3101 standard. + location: SrcLoc + Source location of the log. + """ + + logger_name: str + level: LogLevel + format_str: str + location: SrcLoc + + def format(self, *args) -> str: + """Format the log message with a set of concrete arguments.""" + + return self.format_str.format(*args) + + +@dataclass +class LogRecord(LogRecordInfo): + """A LogRecord instance represents an event being logged. + + Attributes + ---------- + trigger: Signal + Amaranth signal triggering the log. + fields: Signal + Amaranth signals that will be used to format the message. + """ + + trigger: Signal + fields: list[Signal] = field(default_factory=list) + + +@dataclass(frozen=True) +class LogKey(ListKey[LogRecord]): + pass + + +class HardwareLogger: + """A class for creating log messages in the hardware. + + Intuitively, the hardware logger works similarly to a normal software + logger. You can log a message anywhere in the circuit, but due to the + parallel nature of the hardware you must specify a special trigger signal + which will indicate if a message shall be reported in that cycle. + + Hardware logs are evaluated and printed during simulation, so both + the trigger and the format fields are Amaranth values, i.e. + signals or arbitrary Amaranth expressions. + + Instances of the HardwareLogger class represent a logger for a single + submodule of the circuit. Exactly how a "submodule" is defined is up + to the developer. Submodule are identified by a unique string and + the names can be nested. Names are organized into a namespace hierarchy + where levels are separated by periods, much like the Python package + namespace. 
So in the instance, submodules names might be "frontend" + for the upper level, and "frontend.icache" and "frontend.bpu" for + the sub-levels. There is no arbitrary limit to the depth of nesting. + + Attributes + ---------- + name: str + Name of this logger. + """ + + def __init__(self, name: str): + """ + Parameters + ---------- + name: str + Name of this logger. Hierarchy levels are separated by periods, + e.g. "backend.fu.jumpbranch". + """ + self.name = name + + def log(self, m: ModuleLike, level: LogLevel, trigger: ValueLike, format: str, *args, src_loc_at: int = 0): + """Registers a hardware log record with the given severity. + + Parameters + ---------- + m: ModuleLike + The module for which the log record is added. + trigger: ValueLike + If the value of this Amaranth expression is true, the log will reported. + format: str + The format of the message as defined in PEP 3101. + *args + Amaranth values that will be read during simulation and used to format + the message. + src_loc_at: int, optional + How many stack frames below to look for the source location, used to + identify the failing assertion. + """ + + def local_src_loc(src_loc: SrcLoc): + return (os.path.relpath(src_loc[0]), src_loc[1]) + + src_loc = local_src_loc(get_src_loc(src_loc_at + 1)) + + trigger_signal = Signal() + m.d.comb += trigger_signal.eq(trigger) + + record = LogRecord( + logger_name=self.name, level=level, format_str=format, location=src_loc, trigger=trigger_signal + ) + + for arg in args: + sig = Signal.like(arg) + m.d.top_comb += sig.eq(arg) + record.fields.append(sig) + + dependencies = DependencyContext.get() + dependencies.add_dependency(LogKey(), record) + + def debug(self, m: ModuleLike, trigger: ValueLike, format: str, *args, **kwargs): + """Log a message with severity 'DEBUG'. + + See `HardwareLogger.log` function for more details. 
+ """ + self.log(m, logging.DEBUG, trigger, format, *args, **kwargs) + + def info(self, m: ModuleLike, trigger: ValueLike, format: str, *args, **kwargs): + """Log a message with severity 'INFO'. + + See `HardwareLogger.log` function for more details. + """ + self.log(m, logging.INFO, trigger, format, *args, **kwargs) + + def warning(self, m: ModuleLike, trigger: ValueLike, format: str, *args, **kwargs): + """Log a message with severity 'WARNING'. + + See `HardwareLogger.log` function for more details. + """ + self.log(m, logging.WARNING, trigger, format, *args, **kwargs) + + def error(self, m: ModuleLike, trigger: ValueLike, format: str, *args, **kwargs): + """Log a message with severity 'ERROR'. + + This severity level has special semantics. If a log with this serverity + level is triggered, the simulation will be terminated. + + See `HardwareLogger.log` function for more details. + """ + self.log(m, logging.ERROR, trigger, format, *args, **kwargs) + + def assertion(self, m: ModuleLike, value: Value, format: str = "", *args, src_loc_at: int = 0, **kwargs): + """Define an assertion. + + This function might help find some hardware bugs which might otherwise be + hard to detect. If `value` is false, it will terminate the simulation or + it can also be used to turn on a warning LED on a board. + + Internally, this is a convenience wrapper over log.error. + + See `HardwareLogger.log` function for more details. + """ + self.error(m, ~value, format, *args, **kwargs, src_loc_at=src_loc_at + 1) + + +def get_log_records(level: LogLevel, namespace_regexp: str = ".*") -> list[LogRecord]: + """Get log records in for the given severity level and in the + specified namespace. + + This function returns all log records with the severity bigger or equal + to the specified level and belonging to the specified namespace. + + Parameters + ---------- + level: LogLevel + The minimum severity level. + namespace: str, optional + The regexp of the namespace. 
If not specified, logs from all namespaces + will be processed. + """ + + dependencies = DependencyContext.get() + all_logs = dependencies.get_dependency(LogKey()) + return [rec for rec in all_logs if rec.level >= level and re.search(namespace_regexp, rec.logger_name)] + + +def get_trigger_bit(level: LogLevel, namespace_regexp: str = ".*") -> Value: + """Get a trigger bit for logs of the given severity level and + in the specified namespace. + + The signal returned by this function is high whenever the trigger signal + of any of the records with the severity bigger or equal to the specified + level is high. + + Parameters + ---------- + level: LogLevel + The minimum severity level. + namespace: str, optional + The regexp of the namespace. If not specified, logs from all namespaces + will be processed. + """ + + return reduce(operator.or_, [rec.trigger for rec in get_log_records(level, namespace_regexp)], C(0)) diff --git a/transactron/lib/metrics.py b/transactron/lib/metrics.py new file mode 100644 index 000000000..2e706e0a3 --- /dev/null +++ b/transactron/lib/metrics.py @@ -0,0 +1,558 @@ +from dataclasses import dataclass, field +from dataclasses_json import dataclass_json +from typing import Optional +from abc import ABC + +from amaranth import * +from amaranth.utils import bits_for + +from transactron.utils import ValueLike +from transactron import Method, def_method, TModule +from transactron.utils import SignalBundle +from transactron.lib import FIFO +from transactron.utils.dependencies import ListKey, DependencyContext, SimpleKey + +__all__ = [ + "MetricRegisterModel", + "MetricModel", + "HwMetric", + "HwCounter", + "HwExpHistogram", + "LatencyMeasurer", + "HardwareMetricsManager", + "HwMetricsEnabledKey", +] + + +@dataclass_json +@dataclass(frozen=True) +class MetricRegisterModel: + """ + Represents a single register of a metric, serving as a fundamental + building block that holds a singular value. 
+ + Attributes + ---------- + name: str + The unique identifier for the register (among remaning + registers of a specific metric). + description: str + A brief description of the metric's purpose. + width: int + The bit-width of the register. + """ + + name: str + description: str + width: int + + +@dataclass_json +@dataclass +class MetricModel: + """ + Provides information about a metric exposed by the circuit. Each metric + comprises multiple registers, each dedicated to storing specific values. + + The configuration of registers is internally determined by a + specific metric type and is not user-configurable. + + Attributes + ---------- + fully_qualified_name: str + The fully qualified name of the metric, with name components joined by dots ('.'), + e.g., 'foo.bar.requests'. + description: str + A human-readable description of the metric's functionality. + regs: list[MetricRegisterModel] + A list of registers associated with the metric. + """ + + fully_qualified_name: str + description: str + regs: dict[str, MetricRegisterModel] = field(default_factory=dict) + + +class HwMetricRegister(MetricRegisterModel): + """ + A concrete implementation of a metric register that holds its value as Amaranth signal. + + Attributes + ---------- + value: Signal + Amaranth signal representing the value of the register. + """ + + def __init__(self, name: str, width_bits: int, description: str = "", reset: int = 0): + """ + Parameters + ---------- + name: str + The unique identifier for the register (among remaning + registers of a specific metric). + width: int + The bit-width of the register. + description: str + A brief description of the metric's purpose. + reset: int + The reset value of the register. 
+ """ + super().__init__(name, description, width_bits) + + self.value = Signal(width_bits, reset=reset, name=name) + + +@dataclass(frozen=True) +class HwMetricsListKey(ListKey["HwMetric"]): + """DependencyManager key collecting hardware metrics globally as a list.""" + + pass + + +@dataclass(frozen=True) +class HwMetricsEnabledKey(SimpleKey[bool]): + """ + DependencyManager key for enabling hardware metrics. If metrics are disabled, + none of theirs signals will be synthesized. + """ + + lock_on_get = False + empty_valid = True + default_value = False + + +class HwMetric(ABC, MetricModel): + """ + A base for all metric implementations. It should be only used for declaring + new types of metrics. + + It takes care of registering the metric in the dependency manager. + + Attributes + ---------- + signals: dict[str, Signal] + A mapping from a register name to a Signal containing the value of that register. + """ + + def __init__(self, fully_qualified_name: str, description: str): + """ + Parameters + ---------- + fully_qualified_name: str + The fully qualified name of the metric. + description: str + A human-readable description of the metric's functionality. + """ + super().__init__(fully_qualified_name, description) + + self.signals: dict[str, Signal] = {} + + # add the metric to the global list of all metrics + DependencyContext.get().add_dependency(HwMetricsListKey(), self) + + # So Amaranth doesn't report that the module is unused when metrics are disabled + self._MustUse__silence = True # type: ignore + + def add_registers(self, regs: list[HwMetricRegister]): + """ + Adds registers to a metric. Should be only called by inheriting classes + during initialization. + + Parameters + ---------- + regs: list[HwMetricRegister] + A list of registers to be registered. 
+ """ + for reg in regs: + if reg.name in self.regs: + raise RuntimeError(f"Register {reg.name}' is already added to the metric {self.fully_qualified_name}") + + self.regs[reg.name] = reg + self.signals[reg.name] = reg.value + + def metrics_enabled(self) -> bool: + return DependencyContext.get().get_dependency(HwMetricsEnabledKey()) + + +class HwCounter(Elaboratable, HwMetric): + """Hardware Counter + + The most basic hardware metric that can just increase its value. + """ + + def __init__(self, fully_qualified_name: str, description: str = "", *, width_bits: int = 32): + """ + Parameters + ---------- + fully_qualified_name: str + The fully qualified name of the metric. + description: str + A human-readable description of the metric's functionality. + width_bits: int + The bit-width of the register. Defaults to 32 bits. + """ + + super().__init__(fully_qualified_name, description) + + self.count = HwMetricRegister("count", width_bits, "the value of the counter") + + self.add_registers([self.count]) + + self._incr = Method() + + def elaborate(self, platform): + if not self.metrics_enabled(): + return TModule() + + m = TModule() + + @def_method(m, self._incr) + def _(): + m.d.sync += self.count.value.eq(self.count.value + 1) + + return m + + def incr(self, m: TModule, *, cond: ValueLike = C(1)): + """ + Increases the value of the counter by 1. + + Should be called in the body of either a transaction or a method. + + Parameters + ---------- + m: TModule + Transactron module + """ + if not self.metrics_enabled(): + return + + with m.If(cond): + self._incr(m) + + +class HwExpHistogram(Elaboratable, HwMetric): + """Hardware Exponential Histogram + + Represents the distribution of sampled data through a histogram. A histogram + samples observations (usually things like request durations or queue sizes) and counts + them in a configurable number of buckets. The buckets are of exponential size. 
For example, + a histogram with 5 buckets would have the following value ranges: + [0, 1); [1, 2); [2, 4); [4, 8); [8, +inf). + + Additionally, the histogram tracks the number of observations, the sum + of observed values, and the minimum and maximum values. + """ + + def __init__( + self, + fully_qualified_name: str, + description: str = "", + *, + bucket_count: int, + sample_width: int = 32, + registers_width: int = 32, + ): + """ + Parameters + ---------- + fully_qualified_name: str + The fully qualified name of the metric. + description: str + A human-readable description of the metric's functionality. + max_value: int + The maximum value that the histogram would be able to count. This + value is used to calculate the number of buckets. + """ + + super().__init__(fully_qualified_name, description) + self.bucket_count = bucket_count + self.sample_width = sample_width + + self._add = Method(i=[("sample", self.sample_width)]) + + self.count = HwMetricRegister("count", registers_width, "the count of events that have been observed") + self.sum = HwMetricRegister("sum", registers_width, "the total sum of all observed values") + self.min = HwMetricRegister( + "min", + self.sample_width, + "the minimum of all observed values", + reset=(1 << self.sample_width) - 1, + ) + self.max = HwMetricRegister("max", self.sample_width, "the maximum of all observed values") + + self.buckets = [] + for i in range(self.bucket_count): + bucket_start = 0 if i == 0 else 2 ** (i - 1) + bucket_end = "inf" if i == self.bucket_count - 1 else 2**i + + self.buckets.append( + HwMetricRegister( + f"bucket-{bucket_end}", + registers_width, + f"the cumulative counter for the observation bucket [{bucket_start}, {bucket_end})", + ) + ) + + self.add_registers([self.count, self.sum, self.max, self.min] + self.buckets) + + def elaborate(self, platform): + if not self.metrics_enabled(): + return TModule() + + m = TModule() + + @def_method(m, self._add) + def _(sample): + m.d.sync += 
self.count.value.eq(self.count.value + 1) + m.d.sync += self.sum.value.eq(self.sum.value + sample) + + with m.If(sample > self.max.value): + m.d.sync += self.max.value.eq(sample) + + with m.If(sample < self.min.value): + m.d.sync += self.min.value.eq(sample) + + # todo: perhaps replace with a recursive implementation of the priority encoder + bucket_idx = Signal(range(self.sample_width)) + for i in range(self.sample_width): + with m.If(sample[i]): + m.d.av_comb += bucket_idx.eq(i) + + for i, bucket in enumerate(self.buckets): + should_incr = C(0) + if i == 0: + # The first bucket has a range [0, 1). + should_incr = sample == 0 + elif i == self.bucket_count - 1: + # The last bucket should count values bigger or equal to 2**(self.bucket_count-1) + should_incr = (bucket_idx >= i - 1) & (sample != 0) + else: + should_incr = (bucket_idx == i - 1) & (sample != 0) + + with m.If(should_incr): + m.d.sync += bucket.value.eq(bucket.value + 1) + + return m + + def add(self, m: TModule, sample: Value): + """ + Adds a new sample to the histogram. + + Should be called in the body of either a transaction or a method. + + Parameters + ---------- + m: TModule + Transactron module + sample: ValueLike + The value that will be added to the histogram + """ + + if not self.metrics_enabled(): + return + + self._add(m, sample) + + +class LatencyMeasurer(Elaboratable): + """ + Measures duration between two events, e.g. request processing latency. + It can track multiple events at the same time, i.e. the second event can + be registered as started, before the first finishes. However, they must be + processed in the FIFO order. + + The module exposes an exponential histogram of the measured latencies. + """ + + def __init__( + self, + fully_qualified_name: str, + description: str = "", + *, + slots_number: int, + max_latency: int, + ): + """ + Parameters + ---------- + fully_qualified_name: str + The fully qualified name of the metric. 
+ description: str + A human-readable description of the metric's functionality. + slots_number: str + A number of events that the module can track simultaneously. + max_latency: int + The maximum latency of an event. Used to set signal widths and + number of buckets in the histogram. If a latency turns to be + bigger than the maximum, it will overflow and result in a false + measurement. + """ + self.fully_qualified_name = fully_qualified_name + self.description = description + self.slots_number = slots_number + self.max_latency = max_latency + + self._start = Method() + self._stop = Method() + + # This bucket count gives us the best possible granularity. + bucket_count = bits_for(self.max_latency) + 1 + self.histogram = HwExpHistogram( + self.fully_qualified_name, + self.description, + bucket_count=bucket_count, + sample_width=bits_for(self.max_latency), + ) + + def elaborate(self, platform): + if not self.metrics_enabled(): + return TModule() + + m = TModule() + + epoch_width = bits_for(self.max_latency) + + m.submodules.fifo = self.fifo = FIFO([("epoch", epoch_width)], self.slots_number) + m.submodules.histogram = self.histogram + + epoch = Signal(epoch_width) + + m.d.sync += epoch.eq(epoch + 1) + + @def_method(m, self._start) + def _(): + self.fifo.write(m, epoch) + + @def_method(m, self._stop) + def _(): + ret = self.fifo.read(m) + # The result of substracting two unsigned n-bit is a signed (n+1)-bit value, + # so we need to cast the result and discard the most significant bit. + duration = (epoch - ret.epoch).as_unsigned()[:-1] + self.histogram.add(m, duration) + + return m + + def start(self, m: TModule): + """ + Registers the start of an event. Can be called before the previous events + finish. If there are no slots available, the method will be blocked. + + Should be called in the body of either a transaction or a method. 
+ + Parameters + ---------- + m: TModule + Transactron module + """ + + if not self.metrics_enabled(): + return + + self._start(m) + + def stop(self, m: TModule): + """ + Registers the end of the oldest event (the FIFO order). If there are no + started events in the queue, the method will block. + + Should be called in the body of either a transaction or a method. + + Parameters + ---------- + m: TModule + Transactron module + """ + + if not self.metrics_enabled(): + return + + self._stop(m) + + def metrics_enabled(self) -> bool: + return DependencyContext.get().get_dependency(HwMetricsEnabledKey()) + + +class HardwareMetricsManager: + """ + Collects all metrics registered in the circuit and provides an easy + access to them. + """ + + def __init__(self): + self._metrics: Optional[dict[str, HwMetric]] = None + + def _collect_metrics(self) -> dict[str, HwMetric]: + # We lazily collect all metrics so that the metrics manager can be + # constructed at any time. Otherwise, if a metric object was created + # after the manager object had been created, that metric wouldn't end up + # being registered. + metrics: dict[str, HwMetric] = {} + for metric in DependencyContext.get().get_dependency(HwMetricsListKey()): + if metric.fully_qualified_name in metrics: + raise RuntimeError(f"Metric '{metric.fully_qualified_name}' is already registered") + + metrics[metric.fully_qualified_name] = metric + + return metrics + + def get_metrics(self) -> dict[str, HwMetric]: + """ + Returns all metrics registered in the circuit. + """ + if self._metrics is None: + self._metrics = self._collect_metrics() + return self._metrics + + def get_register_value(self, metric_name: str, reg_name: str) -> Signal: + """ + Returns the signal holding the register value of the given metric. + + Parameters + ---------- + metric_name: str + The fully qualified name of the metric, for example 'frontend.icache.loads'. 
+ reg_name: str + The name of the register from that metric, for example if + the metric is a histogram, the 'reg_name' could be 'min' + or 'bucket-32'. + """ + + metrics = self.get_metrics() + if metric_name not in metrics: + raise RuntimeError(f"Couldn't find metric '{metric_name}'") + return metrics[metric_name].signals[reg_name] + + def debug_signals(self) -> SignalBundle: + """ + Returns tree-like SignalBundle composed of all metric registers. + """ + metrics = self.get_metrics() + + def rec(metric_names: list[str], depth: int = 1): + bundle: list[SignalBundle] = [] + components: dict[str, list[str]] = {} + + for metric in metric_names: + parts = metric.split(".") + + if len(parts) == depth: + signals = metrics[metric].signals + reg_values = [signals[reg_name] for reg_name in signals] + + bundle.append({metric: reg_values}) + + continue + + component_prefix = ".".join(parts[:depth]) + + if component_prefix not in components: + components[component_prefix] = [] + components[component_prefix].append(metric) + + for component_name, elements in components.items(): + bundle.append({component_name: rec(elements, depth + 1)}) + + return bundle + + return {"metrics": rec(list(self.get_metrics().keys()))} diff --git a/transactron/lib/reqres.py b/transactron/lib/reqres.py index 9bfcdee2b..518d53443 100644 --- a/transactron/lib/reqres.py +++ b/transactron/lib/reqres.py @@ -45,16 +45,16 @@ class ArgumentsToResultsZipper(Elaboratable): Method to save results with `results_layout` in the Forwarder. read: Method Reads latest entries from the fifo and the forwarder and return them as - record with two fields: 'args' and 'results'. + a structure with two fields: 'args' and 'results'. """ def __init__(self, args_layout: MethodLayout, results_layout: MethodLayout, src_loc: int | SrcLoc = 0): """ Parameters ---------- - args_layout: record layout + args_layout: method layout The format of arguments. 
- results_layout: record layout + results_layout: method layout The format of results. src_loc: int | SrcLoc How many stack frames deep the source location is taken from. @@ -147,7 +147,7 @@ def __init__( self.depth = depth - self.id_layout = [("id", log2_int(self.port_count))] + self.id_layout = [("id", exact_log2(self.port_count))] self.clear = Method() self.serialize_in = [ diff --git a/transactron/lib/storage.py b/transactron/lib/storage.py index c3f8b00d7..4cdb080eb 100644 --- a/transactron/lib/storage.py +++ b/transactron/lib/storage.py @@ -1,9 +1,11 @@ from amaranth import * from amaranth.utils import * + +from transactron.utils.transactron_helpers import from_method_layout, make_layout from ..core import * from ..utils import SrcLoc, get_src_loc, MultiPriorityEncoder from typing import Optional -from transactron.utils import assign, AssignType, LayoutLike +from transactron.utils import assign, AssignType, LayoutList from .reqres import ArgumentsToResultsZipper __all__ = ["MemoryBank", "ContentAddressableMemory"] @@ -21,7 +23,7 @@ class MemoryBank(Elaboratable): The read request method. Accepts an `addr` from which data should be read. Only ready if there is there is a place to buffer response. read_resp: Method - The read response method. Return `data_layout` Record which was saved on `addr` given by last + The read response method. Return `data_layout` View which was saved on `addr` given by last `read_req` method call. Only ready after `read_req` call. write: Method The write method. Accepts `addr` where data should be saved, `data` in form of `data_layout` @@ -31,7 +33,7 @@ class MemoryBank(Elaboratable): def __init__( self, *, - data_layout: MethodLayout, + data_layout: LayoutList, elem_count: int, granularity: Optional[int] = None, safe_writes: bool = True, @@ -40,12 +42,12 @@ def __init__( """ Parameters ---------- - data_layout: record layout - The format of records stored in the Memory. 
+ data_layout: method layout + The format of structures stored in the Memory. elem_count: int Number of elements stored in Memory. granularity: Optional[int] - Granularity of write, forwarded to Amaranth. If `None` the whole record is always saved at once. + Granularity of write, forwarded to Amaranth. If `None` the whole structure is always saved at once. If not, the width of `data_layout` is split into `granularity` parts, which can be saved independently. safe_writes: bool Set to `False` if an optimisation can be done to increase throughput of writes. This will cause that @@ -56,17 +58,18 @@ def __init__( Alternatively, the source location to use instead of the default. """ self.src_loc = get_src_loc(src_loc) - self.data_layout = data_layout + self.data_layout = make_layout(*data_layout) self.elem_count = elem_count self.granularity = granularity - self.width = len(Record(self.data_layout)) + self.width = from_method_layout(self.data_layout).size self.addr_width = bits_for(self.elem_count - 1) self.safe_writes = safe_writes - self.read_req_layout = [("addr", self.addr_width)] - self.write_layout = [("addr", self.addr_width), ("data", self.data_layout)] + self.read_req_layout: LayoutList = [("addr", self.addr_width)] + write_layout = [("addr", self.addr_width), ("data", self.data_layout)] if self.granularity is not None: - self.write_layout.append(("mask", self.width // self.granularity)) + write_layout.append(("mask", self.width // self.granularity)) + self.write_layout = make_layout(*write_layout) self.read_req = Method(i=self.read_req_layout, src_loc=self.src_loc) self.read_resp = Method(o=self.data_layout, src_loc=self.src_loc) @@ -83,8 +86,8 @@ def elaborate(self, platform) -> TModule: prev_read_addr = Signal(self.addr_width) write_pending = Signal() write_req = Signal() - write_args = Record(self.write_layout) - write_args_prev = Record(self.write_layout) + write_args = Signal(self.write_layout) + write_args_prev = Signal(self.write_layout) m.d.comb += 
read_port.addr.eq(prev_read_addr) zipper = ArgumentsToResultsZipper([("valid", 1)], self.data_layout) @@ -160,7 +163,7 @@ class ContentAddressableMemory(Elaboratable): Inserts new data. """ - def __init__(self, address_layout: LayoutLike, data_layout: LayoutLike, entries_number: int): + def __init__(self, address_layout: LayoutList, data_layout: LayoutList, entries_number: int): """ Parameters ---------- diff --git a/transactron/lib/transformers.py b/transactron/lib/transformers.py index c0034b67e..a1445fcf5 100644 --- a/transactron/lib/transformers.py +++ b/transactron/lib/transformers.py @@ -6,7 +6,7 @@ from ..utils import SrcLoc from typing import Optional, Protocol from collections.abc import Callable -from transactron.utils import ValueLike, assign, AssignType, ModuleLike, HasElaborate +from transactron.utils import ValueLike, assign, AssignType, ModuleLike, MethodStruct, HasElaborate from .connectors import Forwarder, ManyToOneConnectTrans, ConnectTrans from .simultaneous import condition @@ -62,7 +62,7 @@ class MethodMap(Elaboratable, Transformer): Takes a target method and creates a transformed method which calls the original target method, mapping the input and output values with functions. The mapping functions take two parameters, a `Module` and the - `Record` being transformed. Alternatively, a `Method` can be + structure being transformed. Alternatively, a `Method` can be passed. Attributes @@ -75,8 +75,8 @@ def __init__( self, target: Method, *, - i_transform: Optional[tuple[MethodLayout, Callable[[TModule, Record], RecordDict]]] = None, - o_transform: Optional[tuple[MethodLayout, Callable[[TModule, Record], RecordDict]]] = None, + i_transform: Optional[tuple[MethodLayout, Callable[[TModule, MethodStruct], RecordDict]]] = None, + o_transform: Optional[tuple[MethodLayout, Callable[[TModule, MethodStruct], RecordDict]]] = None, src_loc: int | SrcLoc = 0 ): """ @@ -84,11 +84,11 @@ def __init__( ---------- target: Method The target method. 
- i_transform: (record layout, function or Method), optional + i_transform: (method layout, function or Method), optional Input mapping function. If specified, it should be a pair of a function and a input layout for the transformed method. If not present, input is passed unmodified. - o_transform: (record layout, function or Method), optional + o_transform: (method layout, function or Method), optional Output mapping function. If specified, it should be a pair of a function and a output layout for the transformed method. If not present, output is passed unmodified. @@ -97,9 +97,9 @@ def __init__( Alternatively, the source location to use instead of the default. """ if i_transform is None: - i_transform = (target.data_in.layout, lambda _, x: x) + i_transform = (target.layout_in, lambda _, x: x) if o_transform is None: - o_transform = (target.data_out.layout, lambda _, x: x) + o_transform = (target.layout_out, lambda _, x: x) self.target = target src_loc = get_src_loc(src_loc) @@ -122,7 +122,7 @@ class MethodFilter(Elaboratable, Transformer): Takes a target method and creates a method which calls the target method only when some condition is true. The condition function takes two - parameters, a module and the input `Record` of the method. Non-zero + parameters, a module and the input structure of the method. Non-zero return value is interpreted as true. Alternatively to using a function, a `Method` can be passed as a condition. By default, the target method is locked for use even if it is not called. @@ -139,7 +139,7 @@ class MethodFilter(Elaboratable, Transformer): def __init__( self, target: Method, - condition: Callable[[TModule, Record], ValueLike], + condition: Callable[[TModule, MethodStruct], ValueLike], default: Optional[RecordDict] = None, *, use_condition: bool = False, @@ -164,21 +164,19 @@ def __init__( Alternatively, the source location to use instead of the default. 
""" if default is None: - default = Record.like(target.data_out) + default = Signal.like(target.data_out) self.target = target self.use_condition = use_condition src_loc = get_src_loc(src_loc) - self.method = Method( - i=target.data_in.layout, o=target.data_out.layout, single_caller=self.use_condition, src_loc=src_loc - ) + self.method = Method(i=target.layout_in, o=target.layout_out, single_caller=self.use_condition, src_loc=src_loc) self.condition = condition self.default = default def elaborate(self, platform): m = TModule() - ret = Record.like(self.target.data_out) + ret = Signal.like(self.target.data_out) m.d.comb += assign(ret, self.default, fields=AssignType.ALL) @def_method(m, self.method) @@ -203,7 +201,7 @@ class MethodProduct(Elaboratable, Unifier): def __init__( self, targets: list[Method], - combiner: Optional[tuple[MethodLayout, Callable[[TModule, list[Record]], RecordDict]]] = None, + combiner: Optional[tuple[MethodLayout, Callable[[TModule, list[MethodStruct]], RecordDict]]] = None, *, src_loc: int | SrcLoc = 0 ): @@ -234,11 +232,11 @@ def __init__( The product method. 
""" if combiner is None: - combiner = (targets[0].data_out.layout, lambda _, x: x[0]) + combiner = (targets[0].layout_out, lambda _, x: x[0]) self.targets = targets self.combiner = combiner src_loc = get_src_loc(src_loc) - self.method = Method(i=targets[0].data_in.layout, o=combiner[0], src_loc=src_loc) + self.method = Method(i=targets[0].layout_in, o=combiner[0], src_loc=src_loc) def elaborate(self, platform): m = TModule() @@ -257,7 +255,9 @@ class MethodTryProduct(Elaboratable, Unifier): def __init__( self, targets: list[Method], - combiner: Optional[tuple[MethodLayout, Callable[[TModule, list[tuple[Value, Record]]], RecordDict]]] = None, + combiner: Optional[ + tuple[MethodLayout, Callable[[TModule, list[tuple[Value, MethodStruct]]], RecordDict]] + ] = None, *, src_loc: int | SrcLoc = 0 ): @@ -293,14 +293,14 @@ def __init__( self.targets = targets self.combiner = combiner self.src_loc = get_src_loc(src_loc) - self.method = Method(i=targets[0].data_in.layout, o=combiner[0], src_loc=self.src_loc) + self.method = Method(i=targets[0].layout_in, o=combiner[0], src_loc=self.src_loc) def elaborate(self, platform): m = TModule() @def_method(m, self.method) def _(arg): - results: list[tuple[Value, Record]] = [] + results: list[tuple[Value, MethodStruct]] = [] for target in self.targets: success = Signal() with Transaction(src_loc=self.src_loc).body(m): @@ -335,18 +335,18 @@ def __init__(self, targets: list[Method], *, src_loc: int | SrcLoc = 0): Alternatively, the source location to use instead of the default. 
""" self.method_list = targets - layout = targets[0].data_out.layout + layout = targets[0].layout_out self.src_loc = get_src_loc(src_loc) self.method = Method(o=layout, src_loc=self.src_loc) for method in targets: - if layout != method.data_out.layout: + if layout != method.layout_out: raise Exception("Not all methods have this same layout") def elaborate(self, platform): m = TModule() - m.submodules.forwarder = forwarder = Forwarder(self.method.data_out.layout, src_loc=self.src_loc) + m.submodules.forwarder = forwarder = Forwarder(self.method.layout_out, src_loc=self.src_loc) m.submodules.connect = ManyToOneConnectTrans( get_results=[get for get in self.method_list], put_result=forwarder.write, src_loc=self.src_loc @@ -386,7 +386,7 @@ def elaborate(self, platform): with Transaction().body(m): sdata1 = self.src1(m) sdata2 = self.src2(m) - ddata = Record.like(self.dst.data_in) + ddata = Signal.like(self.dst.data_in) self.dst(m, ddata) m.d.comb += ddata.eq(Cat(sdata1, sdata2)) @@ -400,7 +400,7 @@ class ConnectAndMapTrans(Elaboratable): Behaves like `ConnectTrans`, but modifies the transferred data using functions or `Method`s. Equivalent to a combination of `ConnectTrans` and `MethodMap`. The mapping functions take two parameters, a `Module` - and the `Record` being transformed. + and the structure being transformed. 
""" def __init__( @@ -408,8 +408,8 @@ def __init__( method1: Method, method2: Method, *, - i_fun: Optional[Callable[[TModule, Record], RecordDict]] = None, - o_fun: Optional[Callable[[TModule, Record], RecordDict]] = None, + i_fun: Optional[Callable[[TModule, MethodStruct], RecordDict]] = None, + o_fun: Optional[Callable[[TModule, MethodStruct], RecordDict]] = None, src_loc: int | SrcLoc = 0 ): """ @@ -438,8 +438,8 @@ def elaborate(self, platform): m.submodules.transformer = transformer = MethodMap( self.method2, - i_transform=(self.method1.data_out.layout, self.i_fun), - o_transform=(self.method1.data_in.layout, self.o_fun), + i_transform=(self.method1.layout_out, self.i_fun), + o_transform=(self.method1.layout_in, self.o_fun), src_loc=self.src_loc, ) m.submodules.connect = ConnectTrans(self.method1, transformer.method) diff --git a/transactron/profiler.py b/transactron/profiler.py index 410538ac4..0132b2ef7 100644 --- a/transactron/profiler.py +++ b/transactron/profiler.py @@ -1,18 +1,31 @@ +import os from collections import defaultdict from typing import Optional from dataclasses import dataclass, field -from transactron.utils import SrcLoc from dataclasses_json import dataclass_json +from transactron.utils import SrcLoc, IdGenerator +from transactron.core import MethodMap, TransactionManager -__all__ = ["ProfileInfo", "Profile"] +__all__ = [ + "ProfileInfo", + "ProfileData", + "RunStat", + "RunStatNode", + "Profile", + "TransactionSamples", + "MethodSamples", + "ProfileSamples", +] @dataclass_json @dataclass class ProfileInfo: - """Information about transactions and methods. In `Profile`, transactions - and methods are referred to by their unique ID numbers. + """Information about transactions and methods. + + In `Profile`, transactions and methods are referred to by their unique ID + numbers. 
Attributes ---------- @@ -29,6 +42,68 @@ class ProfileInfo: is_transaction: bool +@dataclass +class ProfileData: + """Information about transactions and methods from the transaction manager. + + This data is required for transaction profile generation in simulators. + Transactions and methods are referred to by their unique ID numbers. + + Attributes + ---------- + transactions_and_methods: dict[int, ProfileInfo] + Information about individual transactions and methods. + method_parents: dict[int, list[int]] + Lists the callers (transactions and methods) for each method. Key is + method ID. + transactions_by_method: dict[int, list[int]] + Lists which transactions are calling each method. Key is method ID. + transaction_conflicts: dict[int, list[int]] + List which other transactions conflict with each transaction. + """ + + transactions_and_methods: dict[int, ProfileInfo] + method_parents: dict[int, list[int]] + transactions_by_method: dict[int, list[int]] + transaction_conflicts: dict[int, list[int]] + + @staticmethod + def make(transaction_manager: TransactionManager): + transactions_and_methods = dict[int, ProfileInfo]() + method_parents = dict[int, list[int]]() + transactions_by_method = dict[int, list[int]]() + transaction_conflicts = dict[int, list[int]]() + + method_map = MethodMap(transaction_manager.transactions) + cgr, _, _ = TransactionManager._conflict_graph(method_map) + get_id = IdGenerator() + + def local_src_loc(src_loc: SrcLoc): + return (os.path.relpath(src_loc[0]), src_loc[1]) + + for transaction in method_map.transactions: + transactions_and_methods[get_id(transaction)] = ProfileInfo( + transaction.owned_name, local_src_loc(transaction.src_loc), True + ) + + for method in method_map.methods: + transactions_and_methods[get_id(method)] = ProfileInfo( + method.owned_name, local_src_loc(method.src_loc), False + ) + method_parents[get_id(method)] = [get_id(t_or_m) for t_or_m in method_map.method_parents[method]] + transactions_by_method[get_id(method)] 
= [ + get_id(t_or_m) for t_or_m in method_map.transactions_by_method[method] + ] + + for transaction, transactions in cgr.items(): + transaction_conflicts[get_id(transaction)] = [get_id(transaction2) for transaction2 in transactions] + + return ( + ProfileData(transactions_and_methods, method_parents, transactions_by_method, transaction_conflicts), + get_id, + ) + + @dataclass class RunStat: """Collected statistics about a transaction or method. @@ -76,6 +151,54 @@ def make(info: ProfileInfo): return RunStatNode(RunStat.make(info)) +@dataclass +class TransactionSamples: + """Runtime value of transaction control signals in a given clock cycle. + + Attributes + ---------- + request: bool + The value of the transaction's ``request`` signal. + runnable: bool + The value of the transaction's ``runnable`` signal. + grant: bool + The value of the transaction's ``grant`` signal. + """ + + request: bool + runnable: bool + grant: bool + + +@dataclass +class MethodSamples: + """Runtime value of method control signals in a given clock cycle. + + Attributes + ---------- + run: bool + The value of the method's ``run`` signal. + """ + + run: bool + + +@dataclass +class ProfileSamples: + """Runtime values of all transaction and method control signals. + + Attributes + ---------- + transactions: dict[int, TransactionSamples] + Runtime values of transaction control signals for each transaction. + methods: dict[int, MethodSamples] + Runtime values of method control signals for each method. 
+ """ + + transactions: dict[int, TransactionSamples] = field(default_factory=dict) + methods: dict[int, MethodSamples] = field(default_factory=dict) + + @dataclass_json @dataclass class CycleProfile: @@ -98,6 +221,44 @@ class CycleProfile: locked: dict[int, int] = field(default_factory=dict) running: dict[int, Optional[int]] = field(default_factory=dict) + @staticmethod + def make(samples: ProfileSamples, data: ProfileData): + cprof = CycleProfile() + + for transaction_id, transaction_samples in samples.transactions.items(): + if transaction_samples.grant: + cprof.running[transaction_id] = None + elif transaction_samples.request and transaction_samples.runnable: + for transaction2_id in data.transaction_conflicts[transaction_id]: + if samples.transactions[transaction2_id].grant: + cprof.locked[transaction_id] = transaction2_id + + running = set(cprof.running) + for method_id, method_samples in samples.methods.items(): + if method_samples.run: + running.add(method_id) + + locked_methods = set[int]() + for method_id in samples.methods.keys(): + if method_id not in running: + if any(transaction_id in running for transaction_id in data.transactions_by_method[method_id]): + locked_methods.add(method_id) + + for method_id in samples.methods.keys(): + if method_id in running: + for t_or_m_id in data.method_parents[method_id]: + if t_or_m_id in running: + cprof.running[method_id] = t_or_m_id + elif method_id in locked_methods: + caller = next( + t_or_m_id + for t_or_m_id in data.method_parents[method_id] + if t_or_m_id in running or t_or_m_id in locked_methods + ) + cprof.locked[method_id] = caller + + return cprof + @dataclass_json @dataclass diff --git a/test/common/__init__.py b/transactron/testing/__init__.py similarity index 88% rename from test/common/__init__.py rename to transactron/testing/__init__.py index dea6ec6e0..aa215228e 100644 --- a/test/common/__init__.py +++ b/transactron/testing/__init__.py @@ -4,4 +4,5 @@ from .sugar import * # noqa: F401 from 
.testbenchio import * # noqa: F401 from .profiler import * # noqa: F401 +from .logging import * # noqa: F401 from transactron.utils import data_layout # noqa: F401 diff --git a/transactron/testing/functions.py b/transactron/testing/functions.py new file mode 100644 index 000000000..7d5bcfb92 --- /dev/null +++ b/transactron/testing/functions.py @@ -0,0 +1,39 @@ +from amaranth import * +from amaranth.lib.data import Layout, StructLayout, View +from amaranth.sim.core import Command +from typing import TypeVar, Any, Generator, TypeAlias, TYPE_CHECKING, Union +from transactron.utils._typing import RecordValueDict, RecordIntDict + + +if TYPE_CHECKING: + from amaranth.hdl._ast import Statement + from .infrastructure import CoreblocksCommand + + +T = TypeVar("T") +TestGen: TypeAlias = Generator[Union[Command, Value, "Statement", "CoreblocksCommand", None], Any, T] + + +def set_inputs(values: RecordValueDict, field: View) -> TestGen[None]: + for name, value in values.items(): + if isinstance(value, dict): + yield from set_inputs(value, getattr(field, name)) + else: + yield getattr(field, name).eq(value) + + +def get_outputs(field: View) -> TestGen[RecordIntDict]: + # return dict of all signal values in a record because amaranth's simulator can't read all + # values of a View in a single yield - it can only read Values (Signals) + result = {} + layout = field.shape() + assert isinstance(layout, StructLayout) + for name, fld in layout: + val = field[name] + if isinstance(fld.shape, Layout): + result[name] = yield from get_outputs(View(fld.shape, val)) + elif isinstance(val, Value): + result[name] = yield val + else: + raise ValueError + return result diff --git a/test/gtkw_extension.py b/transactron/testing/gtkw_extension.py similarity index 78% rename from test/gtkw_extension.py rename to transactron/testing/gtkw_extension.py index 1229bad2c..835886273 100644 --- a/test/gtkw_extension.py +++ b/transactron/testing/gtkw_extension.py @@ -1,5 +1,6 @@ from typing import Iterable, 
Mapping from contextlib import contextmanager +from amaranth.lib.data import View from amaranth.sim.pysim import _VCDWriter from amaranth import * from transactron.utils import flatten_signals @@ -13,6 +14,12 @@ def __init__(self, fragment, *, vcd_file, gtkw_file, traces): self._tree_traces = traces def close(self, timestamp): + def save_signal(value: Value): + for signal in value._rhs_signals(): # type: ignore + if signal in self.gtkw_names: + for name in self.gtkw_names[signal]: + self.gtkw_save.trace(name) + def gtkw_traces(traces): if isinstance(traces, Mapping): for k, v in traces.items(): @@ -28,12 +35,12 @@ def gtkw_traces(traces): gtkw_traces(v) elif len(traces.fields) == 1: # to make gtkwave view less verbose gtkw_traces(next(iter(traces.fields.values()))) - elif isinstance(traces, Signal): - if len(traces) > 1 and not traces.decoder: - suffix = "[{}:0]".format(len(traces) - 1) - else: - suffix = "" - self.gtkw_save.trace(".".join(self.gtkw_names[traces]) + suffix) + elif isinstance(traces, View): + v = Value.cast(traces) + with self.gtkw_save.group(v.name if isinstance(v, Signal) else ""): + save_signal(v) + elif isinstance(traces, Value): + save_signal(traces) if self.vcd_writer is not None: self.vcd_writer.close(timestamp) diff --git a/test/common/infrastructure.py b/transactron/testing/infrastructure.py similarity index 82% rename from test/common/infrastructure.py rename to transactron/testing/infrastructure.py index 51dd3e8ce..a769bba13 100644 --- a/test/common/infrastructure.py +++ b/transactron/testing/infrastructure.py @@ -8,13 +8,16 @@ from abc import ABC from amaranth import * from amaranth.sim import * + +from transactron.utils.dependencies import DependencyContext, DependencyManager from .testbenchio import TestbenchIO from .profiler import profiler_process, Profile from .functions import TestGen -from ..gtkw_extension import write_vcd_ext +from .logging import make_logging_process, parse_logging_level +from .gtkw_extension import 
write_vcd_ext from transactron import Method from transactron.lib import AdapterTrans -from transactron.core import TransactionModule +from transactron.core import TransactionManagerKey, TransactionModule from transactron.utils import ModuleConnector, HasElaborate, auto_debug_signals, HasDebugSignals T = TypeVar("T") @@ -90,8 +93,12 @@ def debug_signals(self): class _TestModule(Elaboratable): - def __init__(self, tested_module: HasElaborate, add_transaction_module): - self.tested_module = TransactionModule(tested_module) if add_transaction_module else tested_module + def __init__(self, tested_module: HasElaborate, add_transaction_module: bool): + self.tested_module = ( + TransactionModule(tested_module, dependency_manager=DependencyContext.get()) + if add_transaction_module + else tested_module + ) self.add_transaction_module = add_transaction_module def elaborate(self, platform) -> HasElaborate: @@ -103,6 +110,8 @@ def elaborate(self, platform) -> HasElaborate: m.submodules.tested_module = self.tested_module + m.domains.sync_neg = ClockDomain(clk_edge="neg", local=True) + return m @@ -154,6 +163,7 @@ def __init__( super().__init__(test_module) self.add_clock(clk_period) + self.add_clock(clk_period, domain="sync_neg") if isinstance(tested_module, HasDebugSignals): extra_signals = tested_module.debug_signals @@ -192,6 +202,27 @@ def run(self) -> bool: class TestCaseWithSimulator(unittest.TestCase): + dependency_manager: DependencyManager + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.dependency_manager = DependencyManager() + + def wrap(f: Callable[[], None]): + @functools.wraps(f) + def wrapper(): + with DependencyContext(self.dependency_manager): + f() + + return wrapper + + for k in dir(self): + if k.startswith("test") or k == "setUp": + f = getattr(self, k) + if isinstance(f, Callable): + setattr(self, k, wrap(getattr(self, k))) + def add_class_mocks(self, sim: PysimSimulator) -> None: for key in dir(self): val = 
getattr(self, key) @@ -210,7 +241,7 @@ def add_all_mocks(self, sim: PysimSimulator, frame_locals: dict) -> None: @contextmanager def run_simulation(self, module: HasElaborate, max_cycles: float = 10e4, add_transaction_module=True): traces_file = None - if "__COREBLOCKS_DUMP_TRACES" in os.environ: + if "__TRANSACTRON_DUMP_TRACES" in os.environ: traces_file = unittest.TestCase.id(self) clk_period = 1e-6 @@ -222,12 +253,22 @@ def run_simulation(self, module: HasElaborate, max_cycles: float = 10e4, add_tra clk_period=clk_period, ) self.add_all_mocks(sim, sys._getframe(2).f_locals) + yield sim profile = None if "__TRANSACTRON_PROFILE" in os.environ and isinstance(sim.tested_module, TransactionModule): profile = Profile() - sim.add_sync_process(profiler_process(sim.tested_module.transactionManager, profile, clk_period)) + sim.add_sync_process( + profiler_process(sim.tested_module.manager.get_dependency(TransactionManagerKey()), profile) + ) + + def on_error(): + self.assertTrue(False, "Simulation finished due to an error") + + log_level = parse_logging_level(os.environ["__TRANSACTRON_LOG_LEVEL"]) + log_filter = os.environ["__TRANSACTRON_LOG_FILTER"] + sim.add_sync_process(make_logging_process(log_level, log_filter, on_error)) res = sim.run() diff --git a/test/common/input_generation.py b/transactron/testing/input_generation.py similarity index 100% rename from test/common/input_generation.py rename to transactron/testing/input_generation.py diff --git a/transactron/testing/logging.py b/transactron/testing/logging.py new file mode 100644 index 000000000..6a2ad0881 --- /dev/null +++ b/transactron/testing/logging.py @@ -0,0 +1,108 @@ +from collections.abc import Callable +from typing import Any +import logging + +from amaranth.sim import Passive, Tick +from transactron.lib import logging as tlog + + +__all__ = ["make_logging_process", "parse_logging_level"] + + +def parse_logging_level(str: str) -> tlog.LogLevel: + """Parse the log level from a string. 
+ + The level can be either a non-negative integer or a string representation + of one of the predefined levels. + + Raises an exception if the level cannot be parsed. + """ + str = str.upper() + names_mapping = logging.getLevelNamesMapping() + if str in names_mapping: + return names_mapping[str] + + # try convert to int + try: + return int(str) + except ValueError: + pass + + raise ValueError("Log level must be either {error, warn, info, debug} or a non-negative integer.") + + +_sim_cycle = 0 + + +class _LogFormatter(logging.Formatter): + """ + Log formatter to provide colors and to inject simulator times into + the log messages. Adapted from https://stackoverflow.com/a/56944256/3638629 + """ + + magenta = "\033[0;35m" + grey = "\033[0;34m" + blue = "\033[0;34m" + yellow = "\033[0;33m" + red = "\033[0;31m" + reset = "\033[0m" + + loglevel2colour = { + logging.DEBUG: grey + "{}" + reset, + logging.INFO: magenta + "{}" + reset, + logging.WARNING: yellow + "{}" + reset, + logging.ERROR: red + "{}" + reset, + } + + def format(self, record: logging.LogRecord): + level_name = self.loglevel2colour[record.levelno].format(record.levelname) + return f"{_sim_cycle} {level_name} {record.name} {record.getMessage()}" + + +def make_logging_process(level: tlog.LogLevel, namespace_regexp: str, on_error: Callable[[], Any]): + combined_trigger = tlog.get_trigger_bit(level, namespace_regexp) + records = tlog.get_log_records(level, namespace_regexp) + + root_logger = logging.getLogger() + ch = logging.StreamHandler() + formatter = _LogFormatter() + ch.setFormatter(formatter) + root_logger.handlers = [ch] + + def handle_logs(): + if not (yield combined_trigger): + return + + for record in records: + if not (yield record.trigger): + continue + + values: list[int] = [] + for field in record.fields: + values.append((yield field)) + + formatted_msg = record.format(*values) + + logger = root_logger.getChild(record.logger_name) + logger.log( + record.level, + "%s:%d] %s", + 
record.location[0], + record.location[1], + formatted_msg, + ) + + if record.level >= logging.ERROR: + on_error() + + def log_process(): + global _sim_cycle + + yield Passive() + while True: + yield Tick("sync_neg") + yield from handle_logs() + yield + _sim_cycle += 1 + + return log_process diff --git a/transactron/testing/profiler.py b/transactron/testing/profiler.py new file mode 100644 index 000000000..18451112c --- /dev/null +++ b/transactron/testing/profiler.py @@ -0,0 +1,36 @@ +from amaranth.sim import * +from transactron.core import MethodMap, TransactionManager +from transactron.profiler import CycleProfile, MethodSamples, Profile, ProfileData, ProfileSamples, TransactionSamples +from .functions import TestGen + +__all__ = ["profiler_process"] + + +def profiler_process(transaction_manager: TransactionManager, profile: Profile): + def process() -> TestGen: + profile_data, get_id = ProfileData.make(transaction_manager) + method_map = MethodMap(transaction_manager.transactions) + profile.transactions_and_methods = profile_data.transactions_and_methods + + yield Passive() + while True: + yield Tick("sync_neg") + + samples = ProfileSamples() + + for transaction in method_map.transactions: + samples.transactions[get_id(transaction)] = TransactionSamples( + bool((yield transaction.request)), + bool((yield transaction.runnable)), + bool((yield transaction.grant)), + ) + + for method in method_map.methods: + samples.methods[get_id(method)] = MethodSamples(bool((yield method.run))) + + cprof = CycleProfile.make(samples, profile_data) + profile.cycles.append(cprof) + + yield + + return process diff --git a/test/common/sugar.py b/transactron/testing/sugar.py similarity index 100% rename from test/common/sugar.py rename to transactron/testing/sugar.py diff --git a/test/common/testbenchio.py b/transactron/testing/testbenchio.py similarity index 100% rename from test/common/testbenchio.py rename to transactron/testing/testbenchio.py diff --git a/transactron/tracing.py 
b/transactron/tracing.py index 036044aed..f418915cb 100644 --- a/transactron/tracing.py +++ b/transactron/tracing.py @@ -4,15 +4,19 @@ import warnings -from amaranth.hdl.ir import Elaboratable, Fragment, Instance -from amaranth.hdl.xfrm import FragmentTransformer -from amaranth.hdl import dsl, ir, mem, xfrm +from amaranth.hdl import Elaboratable, Fragment, Instance +from amaranth.hdl._xfrm import FragmentTransformer +from amaranth.hdl import _dsl, _ir, _mem, _xfrm from transactron.utils import HasElaborate from . import core # generic tuple because of aggressive monkey-patching -modules_with_fragment: tuple = core, ir, dsl, mem, xfrm +modules_with_fragment: tuple = core, _ir, _dsl, _mem, _xfrm +# List of Fragment subclasses which should be patched to inherit from TracingFragment. +# The first element of the tuple is a subclass name to patch, and the second element +# of the tuple is tuple with modules in which the patched subclass should be installed. +fragment_subclasses_to_patch = [("MemoryInstance", (_mem, _xfrm))] DIAGNOSTICS = False orig_on_fragment = FragmentTransformer.on_fragment @@ -22,13 +26,34 @@ class TracingEnabler: def __enter__(self): self.orig_fragment_get = Fragment.get self.orig_on_fragment = FragmentTransformer.on_fragment - self.orig_fragment_class = ir.Fragment - self.orig_instance_class = ir.Instance + self.orig_fragment_class = _ir.Fragment + self.orig_instance_class = _ir.Instance + self.orig_patched_fragment_subclasses = [] Fragment.get = TracingFragment.get FragmentTransformer.on_fragment = TracingFragmentTransformer.on_fragment for mod in modules_with_fragment: mod.Fragment = TracingFragment mod.Instance = TracingInstance + for class_name, modules in fragment_subclasses_to_patch: + orig_fragment_subclass = getattr(modules[0], class_name) + # `type` is used to declare new class dynamicaly. There is passed `orig_fragment_subclass` as a first + # base class to allow `super()` to work. 
Calls to `super` without arguments are syntax sugar and are + expanded at the compile/interpretation (not execution!) phase to the `super(OriginalClass, self)`, + so they are hardcoded at execution time to look for the original class + (see: https://docs.python.org/3/library/functions.html#super). + This means that OriginalClass has to be in `__mro__` of the newly created class, because otherwise a + TypeError will be raised (see: https://stackoverflow.com/a/40819403). Adding OriginalClass to the + bases of the patched class allows us to fix the TypeError. Everything works correctly because `super` + starts its search of `__mro__` from the class right after the first argument. In our case the first + checked class will be `TracingFragment` as we want. + newclass = type( + class_name, + (orig_fragment_subclass, TracingFragment, ), + dict(orig_fragment_subclass.__dict__) + ) + for mod in modules: + setattr(mod, class_name, newclass) + self.orig_patched_fragment_subclasses.append((class_name, orig_fragment_subclass, modules)) def __exit__(self, tp, val, tb): Fragment.get = self.orig_fragment_get @@ -36,6 +61,9 @@ def __exit__(self, tp, val, tb): for mod in modules_with_fragment: mod.Fragment = self.orig_fragment_class mod.Instance = self.orig_instance_class + for class_name, orig_fragment_subclass, modules in self.orig_patched_fragment_subclasses: + for mod in modules: + setattr(mod, class_name, orig_fragment_subclass) class TracingFragmentTransformer(FragmentTransformer): diff --git a/transactron/utils/__init__.py b/transactron/utils/__init__.py index cfe28ff2a..ebf845b7d 100644 --- a/transactron/utils/__init__.py +++ b/transactron/utils/__init__.py @@ -6,3 +6,4 @@ from .transactron_helpers import * # noqa: F401 from .dependencies import * # noqa: F401 from .depcache import * # noqa: F401 +from .idgen import * # noqa: F401 diff --git a/transactron/utils/_typing.py b/transactron/utils/_typing.py index 519431f4a..555c41736 100644 --- a/transactron/utils/_typing.py +++
b/transactron/utils/_typing.py @@ -13,28 +13,36 @@ runtime_checkable, Union, Any, + TYPE_CHECKING, ) -from collections.abc import Iterable, Mapping, Sequence +from collections.abc import Iterable, Iterator, Mapping from contextlib import AbstractContextManager from enum import Enum from amaranth import * -from amaranth.lib.data import View -from amaranth.hdl.ast import ShapeCastable, Statement, ValueCastable -from amaranth.hdl.dsl import _ModuleBuilderSubmodules, _ModuleBuilderDomainSet, _ModuleBuilderDomain, FSM -from amaranth.hdl.rec import Direction, Layout +from amaranth.lib.data import StructLayout, View +from amaranth.lib.wiring import Flow, Member +from amaranth.hdl import ShapeCastable, ValueCastable + +if TYPE_CHECKING: + from amaranth.hdl._ast import Statement + from amaranth.hdl._dsl import _ModuleBuilderSubmodules, _ModuleBuilderDomainSet, _ModuleBuilderDomain + import amaranth.hdl._dsl __all__ = [ "FragmentLike", "ValueLike", + "ShapeLike", "StatementLike", - "LayoutLike", "SimpleLayout", "SwitchKey", + "SrcLoc", "MethodLayout", + "MethodStruct", "SrcLoc", "SignalBundle", "LayoutListField", "LayoutList", + "LayoutIterable", "RecordIntDict", "RecordIntDictRet", "RecordValueDict", @@ -51,12 +59,8 @@ FragmentLike: TypeAlias = Fragment | Elaboratable ValueLike: TypeAlias = Value | int | Enum | ValueCastable ShapeLike: TypeAlias = Shape | ShapeCastable | int | range | type[Enum] -StatementLike: TypeAlias = Statement | Iterable["StatementLike"] -LayoutLike: TypeAlias = ( - Layout | Sequence[tuple[str, "ShapeLike | LayoutLike"] | tuple[str, "ShapeLike | LayoutLike", Direction]] -) +StatementLike: TypeAlias = Union["Statement", Iterable["StatementLike"]] SwitchKey: TypeAlias = str | int | Enum -MethodLayout: TypeAlias = LayoutLike SrcLoc: TypeAlias = tuple[str, int] # Internal Coreblocks types @@ -64,6 +68,9 @@ LayoutListField: TypeAlias = tuple[str, "ShapeLike | LayoutList"] LayoutList: TypeAlias = list[LayoutListField] SimpleLayout = list[Tuple[str, 
Union[int, "SimpleLayout"]]] +LayoutIterable: TypeAlias = Iterable[LayoutListField] +MethodLayout: TypeAlias = StructLayout | LayoutIterable +MethodStruct: TypeAlias = "View[StructLayout]" RecordIntDict: TypeAlias = Mapping[str, Union[int, "RecordIntDict"]] RecordIntDictRet: TypeAlias = Mapping[str, Any] # full typing hard to work with @@ -78,17 +85,18 @@ GraphCC: TypeAlias = set[T] +# Protocols for Amaranth classes class _ModuleBuilderDomainsLike(Protocol): - def __getattr__(self, name: str) -> _ModuleBuilderDomain: + def __getattr__(self, name: str) -> "_ModuleBuilderDomain": ... - def __getitem__(self, name: str) -> _ModuleBuilderDomain: + def __getitem__(self, name: str) -> "_ModuleBuilderDomain": ... - def __setattr__(self, name: str, value: _ModuleBuilderDomain) -> None: + def __setattr__(self, name: str, value: "_ModuleBuilderDomain") -> None: ... - def __setitem__(self, name: str, value: _ModuleBuilderDomain) -> None: + def __setitem__(self, name: str, value: "_ModuleBuilderDomain") -> None: ... @@ -96,8 +104,8 @@ def __setitem__(self, name: str, value: _ModuleBuilderDomain) -> None: class ModuleLike(Protocol, Generic[_T_ModuleBuilderDomains]): - submodules: _ModuleBuilderSubmodules - domains: _ModuleBuilderDomainSet + submodules: "_ModuleBuilderSubmodules" + domains: "_ModuleBuilderDomainSet" d: _T_ModuleBuilderDomains def If(self, cond: ValueLike) -> AbstractContextManager[None]: # noqa: N802 @@ -120,7 +128,7 @@ def Default(self) -> AbstractContextManager[None]: # noqa: N802 def FSM( # noqa: N802 self, reset: Optional[str] = ..., domain: str = ..., name: str = ... - ) -> AbstractContextManager[FSM]: + ) -> AbstractContextManager["amaranth.hdl._dsl.FSM"]: ... def State(self, name: str) -> AbstractContextManager[None]: # noqa: N802 @@ -135,6 +143,74 @@ def next(self, name: str) -> None: ... +class AbstractSignatureMembers(Protocol): + def flip(self) -> "AbstractSignatureMembers": + ... + + def __eq__(self, other) -> bool: + ... 
+ + def __contains__(self, name: str) -> bool: + ... + + def __getitem__(self, name: str) -> Member: + ... + + def __setitem__(self, name: str, member: Member) -> NoReturn: + ... + + def __delitem__(self, name: str) -> NoReturn: + ... + + def __iter__(self) -> Iterator[str]: + ... + + def __len__(self) -> int: + ... + + def flatten(self, *, path: tuple[str | int, ...] = ...) -> Iterator[tuple[tuple[str | int, ...], Member]]: + ... + + def create(self, *, path: tuple[str | int, ...] = ..., src_loc_at: int = ...) -> dict[str, Any]: + ... + + def __repr__(self) -> str: + ... + + +class AbstractSignature(Protocol): + def flip(self) -> "AbstractSignature": + ... + + @property + def members(self) -> AbstractSignatureMembers: + ... + + def __eq__(self, other) -> bool: + ... + + def flatten(self, obj) -> Iterator[tuple[tuple[str | int, ...], Flow, ValueLike]]: + ... + + def is_compliant(self, obj, *, reasons: Optional[list[str]] = ..., path: tuple[str, ...] = ...) -> bool: + ... + + def create( + self, *, path: tuple[str | int, ...] = ..., src_loc_at: int = ... + ) -> "AbstractInterface[AbstractSignature]": + ... + + def __repr__(self) -> str: + ... + + +_T_AbstractSignature = TypeVar("_T_AbstractSignature", bound=AbstractSignature) + + +class AbstractInterface(Protocol, Generic[_T_AbstractSignature]): + signature: _T_AbstractSignature + + class HasElaborate(Protocol): def elaborate(self, platform) -> "HasElaborate": ... 
diff --git a/transactron/utils/amaranth_ext/elaboratables.py b/transactron/utils/amaranth_ext/elaboratables.py index 22aa9aad4..8f95b4667 100644 --- a/transactron/utils/amaranth_ext/elaboratables.py +++ b/transactron/utils/amaranth_ext/elaboratables.py @@ -45,7 +45,7 @@ def OneHotSwitch(m: ModuleLike, test: Value): @contextmanager def case(n: Optional[int] = None): if n is None: - with m.Case(): + with m.Default(): yield else: # find the index of the least significant bit set diff --git a/transactron/utils/amaranth_ext/functions.py b/transactron/utils/amaranth_ext/functions.py index fba8bb8ed..d09c7b53b 100644 --- a/transactron/utils/amaranth_ext/functions.py +++ b/transactron/utils/amaranth_ext/functions.py @@ -1,5 +1,5 @@ from amaranth import * -from amaranth.utils import bits_for, log2_int +from amaranth.utils import bits_for, exact_log2 from amaranth.lib import data from collections.abc import Iterable, Mapping from transactron.utils._typing import SignalBundle @@ -54,7 +54,7 @@ def iter(s: Value, step: int) -> Value: return result try: - xlen_log = log2_int(len(s)) + xlen_log = exact_log2(len(s)) except ValueError: raise NotImplementedError("CountLeadingZeros - only sizes aligned to power of 2 are supperted") @@ -71,7 +71,7 @@ def iter(s: Value, step: int) -> Value: def count_trailing_zeros(s: Value) -> Value: try: - log2_int(len(s)) + exact_log2(len(s)) except ValueError: raise NotImplementedError("CountTrailingZeros - only sizes aligned to power of 2 are supperted") diff --git a/transactron/utils/assign.py b/transactron/utils/assign.py index b28ac738f..0be471e80 100644 --- a/transactron/utils/assign.py +++ b/transactron/utils/assign.py @@ -1,11 +1,14 @@ from enum import Enum -from typing import Optional, TypeAlias, cast +from typing import Optional, TypeAlias, cast, TYPE_CHECKING from collections.abc import Iterable, Mapping from amaranth import * -from amaranth.hdl.ast import Assign, ArrayProxy +from amaranth.hdl._ast import ArrayProxy from amaranth.lib 
import data from ._typing import ValueLike +if TYPE_CHECKING: + from amaranth.hdl._ast import Assign + __all__ = [ "AssignType", "assign", @@ -31,30 +34,28 @@ def flatten_elems(proxy: ArrayProxy): yield elem elems = list(flatten_elems(proxy)) - if elems and all(isinstance(el, Record) for el in elems): - return set.intersection(*[set(cast(Record, el).fields) for el in elems]) + if elems and all(isinstance(el, data.View) for el in elems): + return set.intersection(*[set(cast(data.View, el).shape().members.keys()) for el in elems]) def assign_arg_fields(val: AssignArg) -> Optional[set[str]]: if isinstance(val, ArrayProxy): return arrayproxy_fields(val) - elif isinstance(val, Record): - return set(val.fields) elif isinstance(val, data.View): layout = val.shape() if isinstance(layout, data.StructLayout): - return set(k for k, _ in layout) + return set(k for k in layout.members) elif isinstance(val, dict): return set(val.keys()) def assign( lhs: AssignArg, rhs: AssignArg, *, fields: AssignFields = AssignType.RHS, lhs_strict=False, rhs_strict=False -) -> Iterable[Assign]: - """Safe record assignment. +) -> Iterable["Assign"]: + """Safe structured assignment. This function recursively generates assignment statements for - field-containing structures. This includes: Amaranth `Record`\\s, + field-containing structures. This includes: Amaranth `View`\\s using `StructLayout`, Python `dict`\\s. In case of mismatching fields or bit widths, error is raised. @@ -65,16 +66,16 @@ def assign( The bit width check is performed if: - - Any of `lhs` or `rhs` is a `Record` or `View`. + - Any of `lhs` or `rhs` is a `View`. - Both `lhs` and `rhs` have an explicitly defined shape (e.g. are a - `Signal`, a field of a `Record` or a `View`). + `Signal`, a field of a `View`). Parameters ---------- - lhs : Record or View or Value-castable or dict - Record, signal or dict being assigned. - rhs : Record or View or Value-castable or dict - Record, signal or dict containing assigned values. 
+ lhs : View or Value-castable or dict + View, signal or dict being assigned. + rhs : View or Value-castable or dict + View, signal or dict containing assigned values. fields : AssignType or Iterable or Mapping, optional Determines which fields will be assigned. Possible values: @@ -84,12 +85,12 @@ def assign( All fields in `rhs` are assigned. If one of them is not present in `lhs`, an exception is raised. AssignType.ALL - Assume that both records have the same layouts. All fields present + Assume that both structures have the same layouts. All fields present in `lhs` or `rhs` are assigned. Mapping Keys are field names, values follow the format for `fields`. Iterable - Items are field names. For subrecords, AssignType.ALL is assumed. + Items are field names. For subfields, AssignType.ALL is assumed. Returns ------- @@ -106,18 +107,8 @@ def assign( if lhs_fields is not None and rhs_fields is not None: # asserts for type checking - assert ( - isinstance(lhs, Record) - or isinstance(lhs, ArrayProxy) - or isinstance(lhs, Mapping) - or isinstance(lhs, data.View) - ) - assert ( - isinstance(rhs, Record) - or isinstance(rhs, ArrayProxy) - or isinstance(rhs, Mapping) - or isinstance(rhs, data.View) - ) + assert isinstance(lhs, ArrayProxy) or isinstance(lhs, Mapping) or isinstance(lhs, data.View) + assert isinstance(rhs, ArrayProxy) or isinstance(rhs, Mapping) or isinstance(rhs, data.View) if fields is AssignType.COMMON: names = lhs_fields & rhs_fields @@ -152,7 +143,7 @@ def assign( ) else: if not isinstance(fields, AssignType): - raise ValueError("Fields on assigning non-records") + raise ValueError("Fields on assigning non-structures lhs: {} rhs: {}".format(lhs, rhs)) if not isinstance(lhs, ValueLike) or not isinstance(rhs, ValueLike): raise TypeError("Unsupported assignment lhs: {} rhs: {}".format(lhs, rhs)) @@ -163,9 +154,7 @@ def has_explicit_shape(val: ValueLike): return isinstance(val, Signal) or isinstance(val, ArrayProxy) if ( - isinstance(lhs, Record) - or 
isinstance(rhs, Record) - or isinstance(lhs, data.View) + isinstance(lhs, data.View) or isinstance(rhs, data.View) or (lhs_strict or has_explicit_shape(lhs)) and (rhs_strict or has_explicit_shape(rhs)) diff --git a/transactron/utils/data_repr.py b/transactron/utils/data_repr.py index c18952d17..0974ba4f0 100644 --- a/transactron/utils/data_repr.py +++ b/transactron/utils/data_repr.py @@ -1,7 +1,8 @@ from collections.abc import Iterable, Mapping -from ._typing import LayoutList, SimpleLayout +from ._typing import LayoutList, SimpleLayout, ShapeLike from typing import Any, Sized from statistics import fmean +from amaranth.lib.data import StructLayout __all__ = [ @@ -18,8 +19,8 @@ ] -def layout_subset(layout: LayoutList, *, fields: set[str]) -> LayoutList: - return [item for item in layout if item[0] in fields] +def layout_subset(layout: StructLayout, *, fields: set[str]) -> StructLayout: + return StructLayout({item: value for item, value in layout.members.items() if item in fields}) def make_hashable(val): @@ -77,7 +78,7 @@ def bits_from_int(num: int, lower: int, length: int): return (num >> lower) & ((1 << (length)) - 1) -def data_layout(val: int) -> SimpleLayout: +def data_layout(val: ShapeLike) -> SimpleLayout: return [("data", val)] diff --git a/transactron/utils/dependencies.py b/transactron/utils/dependencies.py index 2aa0c73df..e66a7a23b 100644 --- a/transactron/utils/dependencies.py +++ b/transactron/utils/dependencies.py @@ -7,6 +7,7 @@ __all__ = [ "DependencyManager", "DependencyKey", + "DependencyContext", "SimpleKey", "ListKey" ] @@ -61,9 +62,19 @@ class SimpleKey(Generic[T], DependencyKey[T, T]): Simple dependency keys are used when there is an one-to-one relation between keys and dependencies. If more than one dependency is added to a simple key, an error is raised. + + Parameters + ---------- + default_value: T + Specifies the default value returned when no dependencies are added. To + enable it `empty_valid` must be True. 
""" + default_value: T + def combine(self, data: list[T]) -> T: + if len(data) == 0: + return self.default_value if len(data) != 1: raise RuntimeError(f"Key {self} assigned {len(data)} values, expected 1") return data[0] @@ -116,3 +127,24 @@ def get_dependency(self, key: DependencyKey[Any, U]) -> U: self.locked_dependencies.add(key) return key.combine(self.dependencies[key]) + + +class DependencyContext: + stack: list[DependencyManager] = [] + + def __init__(self, manager: DependencyManager): + self.manager = manager + + def __enter__(self): + self.stack.append(self.manager) + return self + + def __exit__(self, exc_type, exc_value, exc_tb): + top = self.stack.pop() + assert self.manager is top + + @classmethod + def get(cls) -> DependencyManager: + if not cls.stack: + raise RuntimeError("DependencyContext stack is empty") + return cls.stack[-1] diff --git a/transactron/utils/gen.py b/transactron/utils/gen.py new file mode 100644 index 000000000..daf462ce7 --- /dev/null +++ b/transactron/utils/gen.py @@ -0,0 +1,246 @@ +from dataclasses import dataclass, field +from dataclasses_json import dataclass_json +from typing import TypeAlias + +from amaranth import * +from amaranth.back import verilog +from amaranth.hdl import Fragment + +from transactron.core import TransactionManager, MethodMap, TransactionManagerKey +from transactron.lib.metrics import HardwareMetricsManager +from transactron.lib import logging +from transactron.utils.dependencies import DependencyContext +from transactron.utils.idgen import IdGenerator +from transactron.profiler import ProfileData + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from amaranth.hdl._ast import SignalDict + + +__all__ = [ + "MetricLocation", + "GeneratedLog", + "GenerationInfo", + "generate_verilog", +] + +SignalHandle: TypeAlias = list[str] +"""The location of a signal is a list of Verilog identifiers that denote a path +consisting of module names (and the signal name at the end) leading +to the signal wire.""" + 
+ +@dataclass_json +@dataclass +class MetricLocation: + """Information about the location of a metric in the generated Verilog code. + + Attributes + ---------- + regs : dict[str, SignalHandle] + The location of each register of that metric. + """ + + regs: dict[str, SignalHandle] = field(default_factory=dict) + + +@dataclass_json +@dataclass +class TransactionSignalsLocation: + """Information about transaction control signals in the generated Verilog code. + + Attributes + ---------- + request: list[str] + The location of the ``request`` signal. + runnable: list[str] + The location of the ``runnable`` signal. + grant: list[str] + The location of the ``grant`` signal. + """ + + request: list[str] + runnable: list[str] + grant: list[str] + + +@dataclass_json +@dataclass +class MethodSignalsLocation: + """Information about method control signals in the generated Verilog code. + + Attributes + ---------- + run: list[str] + The location of the ``run`` signal. + """ + + run: list[str] + + +@dataclass_json +@dataclass +class GeneratedLog(logging.LogRecordInfo): + """Information about a log record in the generated Verilog code. + + Attributes + ---------- + trigger_location : SignalHandle + The location of the trigger signal. + fields_location : list[SignalHandle] + Locations of the log fields. + """ + + trigger_location: SignalHandle + fields_location: list[SignalHandle] + + +@dataclass_json +@dataclass +class GenerationInfo: + """Various information about the generated circuit. + + Attributes + ---------- + metrics_location : dict[str, MetricLocation] + Mapping from a metric name to an object storing Verilog locations + of its registers. + logs : list[GeneratedLog] + Locations and metadata for all log records.
+ """ + + metrics_location: dict[str, MetricLocation] + transaction_signals_location: dict[int, TransactionSignalsLocation] + method_signals_location: dict[int, MethodSignalsLocation] + profile_data: ProfileData + logs: list[GeneratedLog] + + def encode(self, file_name: str): + """ + Encodes the generation information as JSON and saves it to a file. + """ + with open(file_name, "w") as fp: + fp.write(self.to_json()) # type: ignore + + @staticmethod + def decode(file_name: str) -> "GenerationInfo": + """ + Loads the generation information from a JSON file. + """ + with open(file_name, "r") as fp: + return GenerationInfo.from_json(fp.read()) # type: ignore + + +def escape_verilog_identifier(identifier: str) -> str: + """ + Escapes a Verilog identifier according to the language standard. + + From IEEE Std 1364-2001 (IEEE Standard VerilogĀ® Hardware Description Language) + + "2.7.1 Escaped identifiers + + Escaped identifiers shall start with the backslash character and end with white + space (space, tab, newline). They provide a means of including any of the printable ASCII + characters in an identifier (the decimal values 33 through 126, or 21 through 7E in hexadecimal)." + """ + + # The standard says how to escape a identifier, but not when. So this is + # a non-exhaustive list of characters that Yosys escapes (it is used + # by Amaranth when generating Verilog code). + characters_to_escape = [".", "$", "-"] + + for char in characters_to_escape: + if char in identifier: + return f"\\{identifier} " + + return identifier + + +def get_signal_location(signal: Signal, name_map: "SignalDict") -> SignalHandle: + raw_location = name_map[signal] + return raw_location + + +def collect_metric_locations(name_map: "SignalDict") -> dict[str, MetricLocation]: + metrics_location: dict[str, MetricLocation] = {} + + # Collect information about the location of metric registers in the generated code. 
+ metrics_manager = HardwareMetricsManager() + for metric_name, metric in metrics_manager.get_metrics().items(): + metric_loc = MetricLocation() + for reg_name in metric.regs: + metric_loc.regs[reg_name] = get_signal_location( + metrics_manager.get_register_value(metric_name, reg_name), name_map + ) + + metrics_location[metric_name] = metric_loc + + return metrics_location + + +def collect_transaction_method_signals( + transaction_manager: TransactionManager, name_map: "SignalDict" +) -> tuple[dict[int, TransactionSignalsLocation], dict[int, MethodSignalsLocation]]: + transaction_signals_location: dict[int, TransactionSignalsLocation] = {} + method_signals_location: dict[int, MethodSignalsLocation] = {} + + method_map = MethodMap(transaction_manager.transactions) + get_id = IdGenerator() + + for transaction in method_map.transactions: + request_loc = get_signal_location(transaction.request, name_map) + runnable_loc = get_signal_location(transaction.runnable, name_map) + grant_loc = get_signal_location(transaction.grant, name_map) + transaction_signals_location[get_id(transaction)] = TransactionSignalsLocation( + request_loc, runnable_loc, grant_loc + ) + + for method in method_map.methods: + run_loc = get_signal_location(method.run, name_map) + method_signals_location[get_id(method)] = MethodSignalsLocation(run_loc) + + return (transaction_signals_location, method_signals_location) + + +def collect_logs(name_map: "SignalDict") -> list[GeneratedLog]: + logs: list[GeneratedLog] = [] + + # Get all records. 
+ for record in logging.get_log_records(0): + trigger_loc = get_signal_location(record.trigger, name_map) + fields_loc = [get_signal_location(field, name_map) for field in record.fields] + log = GeneratedLog( + logger_name=record.logger_name, + level=record.level, + format_str=record.format_str, + location=record.location, + trigger_location=trigger_loc, + fields_location=fields_loc, + ) + logs.append(log) + + return logs + + +def generate_verilog( + top_module: Elaboratable, ports: list[Signal], top_name: str = "top" +) -> tuple[str, GenerationInfo]: + fragment = Fragment.get(top_module, platform=None).prepare(ports=ports) + verilog_text, name_map = verilog.convert_fragment(fragment, name=top_name, emit_src=True, strip_internal_attrs=True) + + transaction_manager = DependencyContext.get().get_dependency(TransactionManagerKey()) + transaction_signals, method_signals = collect_transaction_method_signals( + transaction_manager, name_map # type: ignore + ) + profile_data, _ = ProfileData.make(transaction_manager) + gen_info = GenerationInfo( + metrics_location=collect_metric_locations(name_map), # type: ignore + transaction_signals_location=transaction_signals, + method_signals_location=method_signals, + profile_data=profile_data, + logs=collect_logs(name_map), + ) + + return verilog_text, gen_info diff --git a/transactron/utils/idgen.py b/transactron/utils/idgen.py new file mode 100644 index 000000000..459f3160e --- /dev/null +++ b/transactron/utils/idgen.py @@ -0,0 +1,15 @@ +__all__ = ["IdGenerator"] + + +class IdGenerator: + def __init__(self): + self.id_map = dict[int, int]() + self.id_seq = 0 + + def __call__(self, obj): + try: + return self.id_map[id(obj)] + except KeyError: + self.id_seq += 1 + self.id_map[id(obj)] = self.id_seq + return self.id_seq diff --git a/transactron/utils/transactron_helpers.py b/transactron/utils/transactron_helpers.py index 80bee6ffd..048a2bb61 100644 --- a/transactron/utils/transactron_helpers.py +++ 
b/transactron/utils/transactron_helpers.py @@ -2,11 +2,12 @@ from contextlib import contextmanager from typing import Optional, Any, Concatenate, TypeGuard, TypeVar from collections.abc import Callable, Mapping -from ._typing import ROGraph, GraphCC, SrcLoc +from ._typing import ROGraph, GraphCC, SrcLoc, MethodLayout, MethodStruct, ShapeLike, LayoutList, LayoutListField from inspect import Parameter, signature from itertools import count from amaranth import * from amaranth import tracer +from amaranth.lib.data import StructLayout __all__ = [ @@ -16,6 +17,7 @@ "method_def_helper", "mock_def_helper", "get_src_loc", + "from_method_layout", ] T = TypeVar("T") @@ -89,8 +91,9 @@ def mock_def_helper(tb, func: Callable[..., T], arg: Mapping[str, Any]) -> T: return def_helper(f"mock definition for {tb}", func, Mapping[str, Any], arg, **arg) -def method_def_helper(method, func: Callable[..., T], arg: Record) -> T: - return def_helper(f"method definition for {method}", func, Record, arg, **arg.fields) +def method_def_helper(method, func: Callable[..., T], arg: MethodStruct) -> T: + kwargs = {k: arg[k] for k in arg.shape().members} + return def_helper(f"method definition for {method}", func, MethodStruct, arg, **kwargs) def get_caller_class_name(default: Optional[str] = None) -> tuple[Optional[Elaboratable], str]: @@ -121,3 +124,21 @@ def silence_mustuse(elaboratable: Elaboratable): def get_src_loc(src_loc: int | SrcLoc) -> SrcLoc: return tracer.get_src_loc(1 + src_loc) if isinstance(src_loc, int) else src_loc + + +def from_layout_field(shape: ShapeLike | LayoutList) -> ShapeLike: + if isinstance(shape, list): + return from_method_layout(shape) + else: + return shape + + +def make_layout(*fields: LayoutListField): + return from_method_layout(fields) + + +def from_method_layout(layout: MethodLayout) -> StructLayout: + if isinstance(layout, StructLayout): + return layout + else: + return StructLayout({k: from_layout_field(v) for k, v in layout})