Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add pytest - step 1 #553

Merged
merged 11 commits into from
Jan 8, 2024
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion coreblocks/frontend/fetch.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def elaborate(self, platform) -> TModule:
m.d.av_comb += stalled.eq(stalled_unsafe | stalled_exception)

with Transaction().body(m, request=~stalled):
aligned_pc = Cat(Repl(0, 2), cache_req_pc[2:])
aligned_pc = Cat(C(0, 2), cache_req_pc[2:])
self.icache.issue_req(m, addr=aligned_pc)
req_limiter.acquire(m)

Expand Down
8 changes: 4 additions & 4 deletions coreblocks/frontend/icache.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def _(addr: Value) -> None:
addr=addr >> log2_int(self.params.word_width_bytes),
data=0,
we=0,
sel=Repl(1, self.wb_master.wb_params.data_width // self.wb_master.wb_params.granularity),
sel=C(1).replicate(self.wb_master.wb_params.data_width // self.wb_master.wb_params.granularity),
)

@def_method(m, self.accept_res)
Expand Down Expand Up @@ -275,7 +275,7 @@ def _() -> None:

with m.If(fsm.ongoing("FLUSH")):
m.d.comb += [
self.mem.way_wr_en.eq(Repl(1, self.params.num_of_ways)),
self.mem.way_wr_en.eq(C(1).replicate(self.params.num_of_ways)),
self.mem.tag_wr_index.eq(flush_index),
self.mem.tag_wr_data.valid.eq(0),
self.mem.tag_wr_data.tag.eq(0),
Expand Down Expand Up @@ -393,7 +393,7 @@ def elaborate(self, platform):
addr=Cat(address["word_counter"], address["refill_address"]),
data=0,
we=0,
sel=Repl(1, self.wb_master.wb_params.data_width // self.wb_master.wb_params.granularity),
sel=C(1).replicate(self.wb_master.wb_params.data_width // self.wb_master.wb_params.granularity),
)

@def_method(m, self.start_refill, ready=~refill_active)
Expand Down Expand Up @@ -421,7 +421,7 @@ def _():
address_fwd.write(m, word_counter=next_word_counter, refill_address=refill_address)

return {
"addr": Cat(Repl(0, log2_int(self.params.word_width_bytes)), word_counter, refill_address),
"addr": Cat(C(0, log2_int(self.params.word_width_bytes)), word_counter, refill_address),
"data": fetched.data,
"error": fetched.err,
"last": last,
Expand Down
34 changes: 17 additions & 17 deletions coreblocks/frontend/rvc.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,10 @@ def _quadrant_0(self) -> list[DecodedInstr]:
rd = self.decompr_reg(self.instr_in[2:5])

addi4spn_imm = Cat(
Repl(0, 2), self.instr_in[6], self.instr_in[5], self.instr_in[11:13], self.instr_in[7:11], Repl(0, 2)
C(0, 2), self.instr_in[6], self.instr_in[5], self.instr_in[11:13], self.instr_in[7:11], C(0, 2)
)
lsd_imm = Cat(Repl(0, 3), self.instr_in[10:13], self.instr_in[5:7], Repl(0, 4))
lsw_imm = Cat(Repl(0, 2), self.instr_in[6], self.instr_in[10:13], self.instr_in[5], Repl(0, 5))
lsd_imm = Cat(C(0, 3), self.instr_in[10:13], self.instr_in[5:7], C(0, 4))
lsw_imm = Cat(C(0, 2), self.instr_in[6], self.instr_in[10:13], self.instr_in[5], C(0, 5))

addi4spn = (
ITypeInstr(opcode=Opcode.OP_IMM, rd=rd, funct3=Funct3.ADD, rs1=Registers.SP, imm=addi4spn_imm),
Expand Down Expand Up @@ -98,34 +98,34 @@ def _quadrant_1(self) -> list[DecodedInstr]:
rs2 = self.decompr_reg(self.instr_in[2:5])
rd = self.instr_in[7:12]

addi_imm = Cat(self.instr_in[2:7], Repl(self.instr_in[12], 7))
addi_imm = Cat(self.instr_in[2:7], self.instr_in[12].replicate(7))
addi16sp_imm = Cat(
Repl(0, 4),
C(0, 4),
self.instr_in[6],
self.instr_in[2],
self.instr_in[5],
self.instr_in[3:5],
Repl(self.instr_in[12], 3),
self.instr_in[12].replicate(3),
)
lui_imm = Cat(Repl(0, 12), self.instr_in[2:7], Repl(self.instr_in[12], 15))
lui_imm = Cat(C(0, 12), self.instr_in[2:7], self.instr_in[12].replicate(15))
j_imm = Cat(
Repl(0, 1),
C(0, 1),
self.instr_in[3:6],
self.instr_in[11],
self.instr_in[2],
self.instr_in[7],
self.instr_in[6],
self.instr_in[9:11],
self.instr_in[8],
Repl(self.instr_in[12], 10),
self.instr_in[12].replicate(10),
)
b_imm = Cat(
Repl(0, 1),
C(0, 1),
self.instr_in[3:5],
self.instr_in[10:12],
self.instr_in[2],
self.instr_in[5:7],
Repl(self.instr_in[12], 5),
self.instr_in[12].replicate(5),
)
shamt = Cat(self.instr_in[2:7], self.instr_in[12])

Expand Down Expand Up @@ -206,10 +206,10 @@ def _quadrant_2(self) -> list[DecodedInstr]:
rs2 = self.instr_in[2:7]

shamt = Cat(self.instr_in[2:7], self.instr_in[12])
ldsp_imm = Cat(Repl(0, 3), self.instr_in[5:7], self.instr_in[12], self.instr_in[2:5], Repl(0, 3))
lwsp_imm = Cat(Repl(0, 2), self.instr_in[4:7], self.instr_in[12], self.instr_in[2:4], Repl(0, 4))
sdsp_imm = Cat(Repl(0, 3), self.instr_in[10:13], self.instr_in[7:10], Repl(0, 2))
swsp_imm = Cat(Repl(0, 2), self.instr_in[9:13], self.instr_in[7:9], Repl(0, 4))
ldsp_imm = Cat(C(0, 3), self.instr_in[5:7], self.instr_in[12], self.instr_in[2:5], C(0, 3))
lwsp_imm = Cat(C(0, 2), self.instr_in[4:7], self.instr_in[12], self.instr_in[2:4], C(0, 4))
sdsp_imm = Cat(C(0, 3), self.instr_in[10:13], self.instr_in[7:10], C(0, 2))
swsp_imm = Cat(C(0, 2), self.instr_in[9:13], self.instr_in[7:9], C(0, 4))

slli = (
RTypeInstr(
Expand Down Expand Up @@ -249,10 +249,10 @@ def _quadrant_2(self) -> list[DecodedInstr]:
sdsp = STypeInstr(opcode=Opcode.STORE, imm=sdsp_imm, funct3=Funct3.D, rs1=Registers.SP, rs2=rs2)

jr = (
ITypeInstr(opcode=Opcode.JALR, rd=Registers.ZERO, funct3=Funct3.JALR, rs1=rd_rs1, imm=Repl(0, 12)),
ITypeInstr(opcode=Opcode.JALR, rd=Registers.ZERO, funct3=Funct3.JALR, rs1=rd_rs1, imm=C(0, 12)),
rd_rs1.any(),
)
jalr = ITypeInstr(opcode=Opcode.JALR, rd=Registers.RA, funct3=Funct3.JALR, rs1=rd_rs1, imm=Repl(0, 12))
jalr = ITypeInstr(opcode=Opcode.JALR, rd=Registers.RA, funct3=Funct3.JALR, rs1=rd_rs1, imm=C(0, 12))

ebreak = EBreakInstr()

Expand Down
6 changes: 3 additions & 3 deletions coreblocks/fu/alu.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,15 +186,15 @@ def elaborate(self, platform):
m.d.comb += clz.in_sig.eq(self.in1[::-1])
m.d.comb += self.out.eq(clz.out_sig)
with OneHotCase(AluFn.Fn.SEXTH):
m.d.comb += self.out.eq(Cat(self.in1[0:16], Repl(self.in1[15], xlen - 16)))
m.d.comb += self.out.eq(Cat(self.in1[0:16], self.in1[15].replicate(xlen - 16)))
with OneHotCase(AluFn.Fn.SEXTB):
m.d.comb += self.out.eq(Cat(self.in1[0:8], Repl(self.in1[7], xlen - 8)))
m.d.comb += self.out.eq(Cat(self.in1[0:8], self.in1[7].replicate(xlen - 8)))
with OneHotCase(AluFn.Fn.ZEXTH):
m.d.comb += self.out.eq(Cat(self.in1[0:16], C(0, shape=unsigned(xlen - 16))))
with OneHotCase(AluFn.Fn.ORCB):

def _or(s: Value) -> Value:
return Repl(s.any(), 8)
return s.any().replicate(8)

for i in range(xlen // 8):
m.d.comb += self.out[i * 8 : (i + 1) * 8].eq(_or(self.in1[i * 8 : (i + 1) * 8]))
Expand Down
2 changes: 1 addition & 1 deletion coreblocks/fu/shift_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def elaborate(self, platform):
with OneHotCase(ShiftUnitFn.Fn.SRL):
m.d.comb += self.out.eq(self.in1 >> self.in2[0:xlen_log])
with OneHotCase(ShiftUnitFn.Fn.SRA):
m.d.comb += self.out.eq(Cat(self.in1, Repl(self.in1[xlen - 1], xlen)) >> self.in2[0:xlen_log])
m.d.comb += self.out.eq(Cat(self.in1, self.in1[xlen - 1].replicate(xlen)) >> self.in2[0:xlen_log])

if self.zbb_enable:
with OneHotCase(ShiftUnitFn.Fn.ROL):
Expand Down
14 changes: 7 additions & 7 deletions coreblocks/params/instr.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def __init__(
self.funct7 = Value.cast(funct7)

def pack(self) -> Value:
return Cat(Repl(1, 2), self.opcode, self.rd, self.funct3, self.rs1, self.rs2, self.funct7)
return Cat(C(0b11, 2), self.opcode, self.rd, self.funct3, self.rs1, self.rs2, self.funct7)


class ITypeInstr(RISCVInstr):
Expand All @@ -62,7 +62,7 @@ def __init__(self, opcode: ValueLike, rd: ValueLike, funct3: ValueLike, rs1: Val
self.imm = Value.cast(imm)

def pack(self) -> Value:
return Cat(Repl(1, 2), self.opcode, self.rd, self.funct3, self.rs1, self.imm)
return Cat(C(0b11, 2), self.opcode, self.rd, self.funct3, self.rs1, self.imm)


class STypeInstr(RISCVInstr):
Expand All @@ -74,7 +74,7 @@ def __init__(self, opcode: ValueLike, imm: ValueLike, funct3: ValueLike, rs1: Va
self.rs2 = Value.cast(rs2)

def pack(self) -> Value:
return Cat(Repl(1, 2), self.opcode, self.imm[0:5], self.funct3, self.rs1, self.rs2, self.imm[5:12])
return Cat(C(0b11, 2), self.opcode, self.imm[0:5], self.funct3, self.rs1, self.rs2, self.imm[5:12])


class BTypeInstr(RISCVInstr):
Expand All @@ -87,7 +87,7 @@ def __init__(self, opcode: ValueLike, imm: ValueLike, funct3: ValueLike, rs1: Va

def pack(self) -> Value:
return Cat(
Repl(1, 2),
C(0b11, 2),
self.opcode,
self.imm[11],
self.imm[1:5],
Expand All @@ -106,7 +106,7 @@ def __init__(self, opcode: ValueLike, rd: ValueLike, imm: ValueLike):
self.imm = Value.cast(imm)

def pack(self) -> Value:
return Cat(Repl(1, 2), self.opcode, self.rd, self.imm[12:])
return Cat(C(0b11, 2), self.opcode, self.rd, self.imm[12:])


class JTypeInstr(RISCVInstr):
Expand All @@ -116,15 +116,15 @@ def __init__(self, opcode: ValueLike, rd: ValueLike, imm: ValueLike):
self.imm = Value.cast(imm)

def pack(self) -> Value:
return Cat(Repl(1, 2), self.opcode, self.rd, self.imm[12:20], self.imm[11], self.imm[1:11], self.imm[20])
return Cat(C(0b11, 2), self.opcode, self.rd, self.imm[12:20], self.imm[11], self.imm[1:11], self.imm[20])


class IllegalInstr(RISCVInstr):
def __init__(self):
pass

def pack(self) -> Value:
return Repl(1, 32) # Instructions with all bits set to 1 are reserved to be illegal.
return C(1).replicate(32) # Instructions with all bits set to 1 are reserved to be illegal.


class EBreakInstr(ITypeInstr):
Expand Down
3 changes: 1 addition & 2 deletions coreblocks/peripherals/wishbone.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
from amaranth import *
from amaranth.hdl.rec import DIR_FANIN, DIR_FANOUT
from amaranth.lib.scheduler import RoundRobin
from functools import reduce
from typing import List
import operator

from transactron import Method, def_method, TModule
from transactron.core import Transaction
from transactron.lib import AdapterTrans, BasicFifo
from transactron.utils import OneHotSwitchDynamic, assign
from transactron.utils import OneHotSwitchDynamic, assign, RoundRobin
from transactron.lib.connectors import Forwarder


Expand Down
4 changes: 2 additions & 2 deletions scripts/core_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@

from coreblocks.params.genparams import GenParams # noqa: E402
from transactron.graph import TracingFragment # noqa: E402
from test.test_core import TestElaboratable # noqa: E402
from test.test_core import CoreTestElaboratable # noqa: E402
from coreblocks.params.configurations import basic_core_config # noqa: E402
from transactron.core import TransactionModule # noqa: E402

gp = GenParams(basic_core_config)
elaboratable = TestElaboratable(gp)
elaboratable = CoreTestElaboratable(gp)
tm = TransactionModule(elaboratable)
fragment = TracingFragment.get(tm, platform=None).prepare()

Expand Down
20 changes: 20 additions & 0 deletions stubs/amaranth/hdl/ast.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,26 @@ class Value(metaclass=ABCMeta):
"""Rotate right by constant amount."""
...

def replicate(self, count : int) -> Value:
"""Replication.

A ``Value`` is replicated (repeated) several times to be used
on the RHS of assignments::

len(v.replicate(n)) == len(v) * n

Parameters
----------
count : int
Number of replications.

Returns
-------
Value, out
Replicated value.
"""
...

def eq(self, value: ValueLike) -> Assign:
"""Assignment.

Expand Down
4 changes: 2 additions & 2 deletions test/common/infrastructure.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def debug_signals(self):
return sigs


class TestModule(Elaboratable):
class _TestModule(Elaboratable):
def __init__(self, tested_module: HasElaborate, add_transaction_module):
self.tested_module = TransactionModule(tested_module) if add_transaction_module else tested_module
self.add_transaction_module = add_transaction_module
Expand Down Expand Up @@ -137,7 +137,7 @@ def _wrapping_function(self):

class PysimSimulator(Simulator):
def __init__(self, module: HasElaborate, max_cycles: float = 10e4, add_transaction_module=True, traces_file=None):
test_module = TestModule(module, add_transaction_module)
test_module = _TestModule(module, add_transaction_module)
tested_module = test_module.tested_module
super().__init__(test_module)

Expand Down
Loading