From a4ff06168e2b83bfbe06e332796d0b3e2d65e77d Mon Sep 17 00:00:00 2001 From: Jacob Urbanczyk Date: Wed, 27 Mar 2024 15:46:41 +0100 Subject: [PATCH 1/6] Refactor RISC-V instruction models --- coreblocks/frontend/decoder/rvc.py | 2 +- coreblocks/params/instr.py | 289 ++++++++++++++++------------- test/frontend/test_rvc.py | 46 +++-- test/params/test_instr.py | 63 +++++++ test/test_core.py | 6 +- 5 files changed, 248 insertions(+), 158 deletions(-) create mode 100644 test/params/test_instr.py diff --git a/coreblocks/frontend/decoder/rvc.py b/coreblocks/frontend/decoder/rvc.py index 4ff48c07d..2fe9d42ee 100644 --- a/coreblocks/frontend/decoder/rvc.py +++ b/coreblocks/frontend/decoder/rvc.py @@ -209,7 +209,7 @@ def _quadrant_2(self) -> list[DecodedInstr]: shamt = Cat(self.instr_in[2:7], self.instr_in[12]) ldsp_imm = Cat(C(0, 3), self.instr_in[5:7], self.instr_in[12], self.instr_in[2:5], C(0, 3)) lwsp_imm = Cat(C(0, 2), self.instr_in[4:7], self.instr_in[12], self.instr_in[2:4], C(0, 4)) - sdsp_imm = Cat(C(0, 3), self.instr_in[10:13], self.instr_in[7:10], C(0, 2)) + sdsp_imm = Cat(C(0, 3), self.instr_in[10:13], self.instr_in[7:10], C(0, 3)) swsp_imm = Cat(C(0, 2), self.instr_in[9:13], self.instr_in[7:9], C(0, 4)) slli = ( diff --git a/coreblocks/params/instr.py b/coreblocks/params/instr.py index 370d25b84..08cf8322c 100644 --- a/coreblocks/params/instr.py +++ b/coreblocks/params/instr.py @@ -1,14 +1,24 @@ -from abc import abstractmethod, ABC +""" + +Based on riscv-python-model by Stefan Wallentowitz +https://github.com/wallento/riscv-python-model +""" + +from dataclasses import dataclass +from abc import ABC +from enum import Enum +from typing import Optional from amaranth.hdl import ValueCastable from amaranth import * -from transactron.utils import ValueLike, int_to_signed +from transactron.utils import ValueLike from coreblocks.params.isa_params import * from coreblocks.frontend.decoder.isa import * __all__ = [ + "RISCVInstr", "RTypeInstr", "ITypeInstr", "STypeInstr", @@ -20,154 +30,171 @@ ] +@dataclass(frozen=True, kw_only=True) +class Field: + name: str + base: int | list[int] + size: int | list[int] + + signed: bool = False + offset: int = 0 + static_value: Optional[Value] = None + + def get_base(self) -> list[int]: + if isinstance(self.base, int): + return [self.base] + return self.base + + def get_size(self) -> list[int]: + if isinstance(self.size, int): + return [self.size] + return self.size + + class RISCVInstr(ABC, ValueCastable): - @abstractmethod - def pack(self) -> Value: - pass + field_opcode = Field(name="opcode", base=0, size=7) + + def __init__(self, **kwargs): + for field in kwargs: + fname = "field_" + field + assert fname in dir(self), "Invalid field {} for {}".format(fname, self.__name__) + setattr(self, field, kwargs[field]) + + @classmethod + def get_fields(cls) -> list[Field]: + return [getattr(cls, member) for member in dir(cls) if member.startswith("field_")] + + def encode(self) -> int: + const = Const.cast(self.as_value()) + return const.value # type: ignore + + def __setattr__(self, key, value): + fname = "field_{}".format(key) + + if fname not in dir(self): + super().__setattr__(key, value) + return + + field = getattr(self, fname) + if field.static_value is not None: + raise AttributeError("Can't overwrite the static value of a field.") + + expected_shape = Shape(width=sum(field.get_size()) + field.offset, signed=field.signed) + + field_val: Value = C(0) + if isinstance(value, Enum): + field_val = Const(value.value, expected_shape) + elif isinstance(value, int): + field_val = Const(value, expected_shape) + else: + field_val = Value.cast(value) + + if field_val.shape().width != expected_shape.width: + raise AttributeError( + f"Expected width of the value: {expected_shape.width}, given: {field_val.shape().width}" + ) + if field_val.shape().signed and not expected_shape.signed: + raise AttributeError( + f"Expected signedness of the value: {expected_shape.signed}, given: {field_val.shape().signed}" + ) + + self.__dict__[key] = field_val @ValueCastable.lowermethod - def as_value(self): - return self.pack() + def as_value(self) -> Value: + parts: list[tuple[int, Value]] = [] + + for field in self.get_fields(): + value: Value = C(0) + if field.static_value is not None: + value = field.static_value + else: + value = getattr(self, field.name) + + base = field.get_base() + size = field.get_size() - def shape(self): + offset = field.offset + for i in range(len(base)): + parts.append((base[i], value[offset : offset + size[i]])) + offset += size[i] + + parts.sort() + return Cat([part[1] for part in parts]) + + def shape(self) -> Shape: return self.as_value().shape() -class RTypeInstr(RISCVInstr): - def __init__( - self, - opcode: ValueLike, - rd: ValueLike, - funct3: ValueLike, - rs1: ValueLike, - rs2: ValueLike, - funct7: ValueLike, - ): - self.opcode = Value.cast(opcode) - self.rd = Value.cast(rd) - self.funct3 = Value.cast(funct3) - self.rs1 = Value.cast(rs1) - self.rs2 = Value.cast(rs2) - self.funct7 = Value.cast(funct7) - - def pack(self) -> Value: - return Cat(C(0b11, 2), self.opcode, self.rd, self.funct3, self.rs1, self.rs2, self.funct7) - - @staticmethod - def encode(opcode: int, rd: int, funct3: int, rs1: int, rs2: int, funct7: int): - return int(f"{funct7:07b}{rs2:05b}{rs1:05b}{funct3:03b}{rd:05b}{opcode:05b}11", 2) - - -class ITypeInstr(RISCVInstr): - def __init__(self, opcode: ValueLike, rd: ValueLike, funct3: ValueLike, rs1: ValueLike, imm: ValueLike): - self.opcode = Value.cast(opcode) - self.rd = Value.cast(rd) - self.funct3 = Value.cast(funct3) - self.rs1 = Value.cast(rs1) - self.imm = Value.cast(imm) - - def pack(self) -> Value: - return Cat(C(0b11, 2), self.opcode, self.rd, self.funct3, self.rs1, self.imm) - - @staticmethod - def encode(opcode: int, rd: int, funct3: int, rs1: int, imm: int): - imm = int_to_signed(imm, 12) - return int(f"{imm:012b}{rs1:05b}{funct3:03b}{rd:05b}{opcode:05b}11", 2) - - -class STypeInstr(RISCVInstr): - def __init__(self, opcode: ValueLike, imm: ValueLike, funct3: ValueLike, rs1: ValueLike, rs2: ValueLike): - self.opcode = Value.cast(opcode) - self.imm = Value.cast(imm) - self.funct3 = Value.cast(funct3) - self.rs1 = Value.cast(rs1) - self.rs2 = Value.cast(rs2) - - def pack(self) -> Value: - return Cat(C(0b11, 2), self.opcode, self.imm[0:5], self.funct3, self.rs1, self.rs2, self.imm[5:12]) - - @staticmethod - def encode(opcode: int, imm: int, funct3: int, rs1: int, rs2: int): - imm = int_to_signed(imm, 12) - imm_str = f"{imm:012b}" - return int(f"{imm_str[5:12]:07b}{rs2:05b}{rs1:05b}{funct3:03b}{imm_str[0:5]:05b}{opcode:05b}11", 2) - - -class BTypeInstr(RISCVInstr): - def __init__(self, opcode: ValueLike, imm: ValueLike, funct3: ValueLike, rs1: ValueLike, rs2: ValueLike): - self.opcode = Value.cast(opcode) - self.imm = Value.cast(imm) - self.funct3 = Value.cast(funct3) - self.rs1 = Value.cast(rs1) - self.rs2 = Value.cast(rs2) - - def pack(self) -> Value: - return Cat( - C(0b11, 2), - self.opcode, - self.imm[11], - self.imm[1:5], - self.funct3, - self.rs1, - self.rs2, - self.imm[5:11], - self.imm[12], - ) +class InstructionFunct3Type(RISCVInstr): + field_funct3 = Field(name="funct3", base=12, size=3) - @staticmethod - def encode(opcode: int, imm: int, funct3: int, rs1: int, rs2: int): - imm = int_to_signed(imm, 13) - imm_str = f"{imm:013b}" - return int( - f"{imm_str[12]:01b}{imm_str[5:11]:06b}{rs2:05b}{rs1:05b}{funct3:03b}{imm_str[1:5]:04b}" - + f"{imm_str[11]:01b}{opcode:05b}11", - 2, - ) +class InstructionFunct5Type(RISCVInstr): + field_funct5 = Field(name="funct5", base=27, size=5) -class UTypeInstr(RISCVInstr): - def __init__(self, opcode: ValueLike, rd: ValueLike, imm: ValueLike): - self.opcode = Value.cast(opcode) - self.rd = Value.cast(rd) - self.imm = Value.cast(imm) - def pack(self) -> Value: - return Cat(C(0b11, 2), self.opcode, self.rd, self.imm[12:]) +class InstructionFunct7Type(RISCVInstr): + field_funct7 = Field(name="funct7", base=25, size=7) + + +class RTypeInstr(InstructionFunct3Type, InstructionFunct7Type): + field_rd = Field(name="rd", base=7, size=5) + field_rs1 = Field(name="rs1", base=15, size=5) + field_rs2 = Field(name="rs2", base=20, size=5) + + def __init__(self, opcode: ValueLike, **kwargs): + super().__init__(opcode=Cat(C(0b11, 2), opcode), **kwargs) - @staticmethod - def encode(opcode: int, rd: int, imm: int): - imm = int_to_signed(imm, 20) - return int(f"{imm:020b}{rd:05b}{opcode:05b}11", 2) + +class ITypeInstr(InstructionFunct3Type): + field_rd = Field(name="rd", base=7, size=5) + field_rs1 = Field(name="rs1", base=15, size=5) + field_imm = Field(name="imm", base=20, size=12, signed=True) + + def __init__(self, opcode: ValueLike, **kwargs): + super().__init__(opcode=Cat(C(0b11, 2), opcode), **kwargs) + + +class STypeInstr(InstructionFunct3Type): + field_rs1 = Field(name="rs1", base=15, size=5) + field_rs2 = Field(name="rs2", base=20, size=5) + field_imm = Field(name="imm", base=[7, 25], size=[5, 7], signed=True) + + def __init__(self, opcode: ValueLike, **kwargs): + super().__init__(opcode=Cat(C(0b11, 2), opcode), **kwargs) + + +class BTypeInstr(InstructionFunct3Type): + field_rs1 = Field(name="rs1", base=15, size=5) + field_rs2 = Field(name="rs2", base=20, size=5) + field_imm = Field(name="imm", base=[8, 25, 7, 31], size=[4, 6, 1, 1], offset=1, signed=True) + + def __init__(self, opcode: ValueLike, **kwargs): + super().__init__(opcode=Cat(C(0b11, 2), opcode), **kwargs) + + +class UTypeInstr(RISCVInstr): + field_rd = Field(name="rd", base=7, size=5) + field_imm = Field(name="imm", base=12, size=20, offset=12, signed=False) + + def __init__(self, opcode: ValueLike, **kwargs): + super().__init__(opcode=Cat(C(0b11, 2), opcode), **kwargs) class JTypeInstr(RISCVInstr): - def __init__(self, opcode: ValueLike, rd: ValueLike, imm: ValueLike): - self.opcode = Value.cast(opcode) - self.rd = Value.cast(rd) - self.imm = Value.cast(imm) - - def pack(self) -> Value: - return Cat(C(0b11, 2), self.opcode, self.rd, self.imm[12:20], self.imm[11], self.imm[1:11], self.imm[20]) - - @staticmethod - def encode(opcode: int, rd: int, imm: int): - imm = int_to_signed(imm, 21) - imm_str = f"{imm:021b}" - return int( - f"{imm_str[20]:01b}{imm_str[1:11]:010b}{imm_str[11]:01b}{imm_str[12:20]:08b}{rd:05b}{opcode:05b}11", 2 - ) + field_rd = Field(name="rd", base=7, size=5) + field_imm = Field(name="imm", base=[21, 20, 12, 31], size=[10, 1, 8, 1], offset=1, signed=True) + def __init__(self, opcode: ValueLike, **kwargs): + super().__init__(opcode=Cat(C(0b11, 2), opcode), **kwargs) -class IllegalInstr(RISCVInstr): - def __init__(self): - pass - def pack(self) -> Value: - return C(1).replicate(32) # Instructions with all bits set to 1 are reserved to be illegal. +class IllegalInstr(RISCVInstr): + field_illegal = Field(name="illegal", base=7, size=25, static_value=Cat(1).replicate(25)) - @staticmethod - def encode(opcode: int, rd: int, imm: int): - return int("1" * 32, 2) + def __init__(self): + super().__init__(opcode=0b1111111) class EBreakInstr(ITypeInstr): diff --git a/test/frontend/test_rvc.py b/test/frontend/test_rvc.py index 0b099f751..4f3e7cb4d 100644 --- a/test/frontend/test_rvc.py +++ b/test/frontend/test_rvc.py @@ -25,17 +25,17 @@ # c.addi x2, -28 ( 0x1111, - ITypeInstr(opcode=Opcode.OP_IMM, rd=Registers.X2, funct3=Funct3.ADD, rs1=Registers.X2, imm=C(-28, 12)), + ITypeInstr(opcode=Opcode.OP_IMM, rd=Registers.X2, funct3=Funct3.ADD, rs1=Registers.X2, imm=-28), ), # c.li x31, -7 ( 0x5FE5, - ITypeInstr(opcode=Opcode.OP_IMM, rd=Registers.X31, funct3=Funct3.ADD, rs1=Registers.ZERO, imm=C(-7, 12)), + ITypeInstr(opcode=Opcode.OP_IMM, rd=Registers.X31, funct3=Funct3.ADD, rs1=Registers.ZERO, imm=-7), ), # c.addi16sp 496 (0x617D, ITypeInstr(opcode=Opcode.OP_IMM, rd=Registers.SP, funct3=Funct3.ADD, rs1=Registers.SP, imm=496)), # c.lui x7, -3 - (0x73F5, UTypeInstr(opcode=Opcode.LUI, rd=Registers.X7, imm=C(-3, 20) << 12)), + (0x73F5, UTypeInstr(opcode=Opcode.LUI, rd=Registers.X7, imm=Cat(C(0, 12), C(-3, 20)))), # c.srli x10, 3 ( 0x810D, @@ -44,7 +44,7 @@ rd=Registers.X10, funct3=Funct3.SR, rs1=Registers.X10, - rs2=C(3, 5), + rs2=Registers.X3, funct7=Funct7.SL, ), ), @@ -56,7 +56,7 @@ rd=Registers.X12, funct3=Funct3.SR, rs1=Registers.X12, - rs2=C(8, 5), + rs2=Registers.X8, funct7=Funct7.SA, ), ), @@ -111,16 +111,16 @@ ), ), # c.j 2012 - (0xAFF1, JTypeInstr(opcode=Opcode.JAL, rd=Registers.ZERO, imm=C(2012, 21))), + (0xAFF1, JTypeInstr(opcode=Opcode.JAL, rd=Registers.ZERO, imm=2012)), # c.beqz x8, -6 ( 0xDC6D, - BTypeInstr(opcode=Opcode.BRANCH, imm=C(-6, 13), funct3=Funct3.BEQ, rs1=Registers.X8, rs2=Registers.ZERO), + BTypeInstr(opcode=Opcode.BRANCH, imm=-6, funct3=Funct3.BEQ, rs1=Registers.X8, rs2=Registers.ZERO), ), # c.bnez x15, 20 ( 0xEB91, - BTypeInstr(opcode=Opcode.BRANCH, imm=C(20, 13), funct3=Funct3.BNE, rs1=Registers.X15, rs2=Registers.ZERO), + BTypeInstr(opcode=Opcode.BRANCH, imm=20, funct3=Funct3.BNE, rs1=Registers.X15, rs2=Registers.ZERO), ), # c.slli x13, 31 ( @@ -130,18 +130,16 @@ rd=Registers.X13, funct3=Funct3.SLL, rs1=Registers.X13, - rs2=C(31, 5), + rs2=Registers.X31, funct7=Funct7.SL, ), ), # c.lwsp x2, 4 - (0x4112, ITypeInstr(opcode=Opcode.LOAD, rd=Registers.X2, funct3=Funct3.W, rs1=Registers.SP, imm=C(4, 12))), + (0x4112, ITypeInstr(opcode=Opcode.LOAD, rd=Registers.X2, funct3=Funct3.W, rs1=Registers.SP, imm=4)), # c.jr x30 ( 0x8F02, - ITypeInstr( - opcode=Opcode.JALR, rd=Registers.ZERO, funct3=Funct3.JALR, rs1=Registers.X30, imm=C(0).replicate(12) - ), + ITypeInstr(opcode=Opcode.JALR, rd=Registers.ZERO, funct3=Funct3.JALR, rs1=Registers.X30, imm=0), ), # c.mv x2, x26 ( @@ -170,7 +168,7 @@ ), ), # c.swsp x31, 20 - (0xCA7E, STypeInstr(opcode=Opcode.STORE, imm=C(20, 12), funct3=Funct3.W, rs1=Registers.SP, rs2=Registers.X31)), + (0xCA7E, STypeInstr(opcode=Opcode.STORE, imm=20, funct3=Funct3.W, rs1=Registers.SP, rs2=Registers.X31)), ] RV32_TESTS = [ @@ -179,9 +177,9 @@ # c.sd x14, 0(x13) (0xE298, IllegalInstr()), # c.jal 40 - (0x2025, JTypeInstr(opcode=Opcode.JAL, rd=Registers.RA, imm=C(40, 21))), + (0x2025, JTypeInstr(opcode=Opcode.JAL, rd=Registers.RA, imm=40)), # c.jal -412 - (0x3595, JTypeInstr(opcode=Opcode.JAL, rd=Registers.RA, imm=C(-412, 21))), + (0x3595, JTypeInstr(opcode=Opcode.JAL, rd=Registers.RA, imm=-412)), # c.srli x10, 32 (0x9101, IllegalInstr()), # c.srai x12, 40 @@ -196,13 +194,13 @@ RV64_TESTS = [ # c.ld x8, 8(x9) - (0x6480, ITypeInstr(opcode=Opcode.LOAD, rd=Registers.X8, funct3=Funct3.D, rs1=Registers.X9, imm=C(8, 12))), + (0x6480, ITypeInstr(opcode=Opcode.LOAD, rd=Registers.X8, funct3=Funct3.D, rs1=Registers.X9, imm=8)), # c.sd x14, 0(x13) - (0xE298, STypeInstr(opcode=Opcode.STORE, imm=C(0, 12), funct3=Funct3.D, rs1=Registers.X13, rs2=Registers.X14)), + (0xE298, STypeInstr(opcode=Opcode.STORE, imm=0, funct3=Funct3.D, rs1=Registers.X13, rs2=Registers.X14)), # c.addiw x13, -12, ( 0x36D1, - ITypeInstr(opcode=Opcode.OP_IMM_32, rd=Registers.X13, funct3=Funct3.ADD, rs1=Registers.X13, imm=C(-12, 12)), + ITypeInstr(opcode=Opcode.OP_IMM_32, rd=Registers.X13, funct3=Funct3.ADD, rs1=Registers.X13, imm=-12), ), # c.srli x10, 32 ( @@ -212,7 +210,7 @@ rd=Registers.X10, funct3=Funct3.SR, rs1=Registers.X10, - rs2=C(0, 5), + rs2=Registers.X0, funct7=Funct7.SL | 1, ), ), @@ -224,7 +222,7 @@ rd=Registers.X12, funct3=Funct3.SR, rs1=Registers.X12, - rs2=C(8, 5), + rs2=Registers.X8, funct7=Funct7.SA | 1, ), ), @@ -260,14 +258,14 @@ rd=Registers.X13, funct3=Funct3.SLL, rs1=Registers.X13, - rs2=C(31, 5), + rs2=Registers.X31, funct7=Funct7.SL | 1, ), ), # c.ldsp x29, 40 - (0x7EA2, ITypeInstr(opcode=Opcode.LOAD, rd=Registers.X29, funct3=Funct3.D, rs1=Registers.SP, imm=C(40, 12))), + (0x7EA2, ITypeInstr(opcode=Opcode.LOAD, rd=Registers.X29, funct3=Funct3.D, rs1=Registers.SP, imm=40)), # c.sdsp x4, 8 - (0xE412, STypeInstr(opcode=Opcode.STORE, imm=C(8, 12), funct3=Funct3.D, rs1=Registers.SP, rs2=Registers.X4)), + (0xE412, STypeInstr(opcode=Opcode.STORE, imm=8, funct3=Funct3.D, rs1=Registers.SP, rs2=Registers.X4)), ] diff --git a/test/params/test_instr.py b/test/params/test_instr.py new file mode 100644 index 000000000..0ed97e19c --- /dev/null +++ b/test/params/test_instr.py @@ -0,0 +1,63 @@ +import unittest +from typing import Sequence + +from amaranth import * + +from coreblocks.params.instr import * +from coreblocks.frontend.decoder.isa import * + + +class InstructionTest(unittest.TestCase): + def do_run(self, test_cases: Sequence[tuple[RISCVInstr, int]]): + for instr, raw_instr in test_cases: + self.assertEqual(instr.encode(), raw_instr) + + def test_r_type(self): + test_cases = [ + (RTypeInstr(opcode=Opcode.OP, rd=21, funct3=Funct3.AND, rs1=10, rs2=31, funct7=Funct7.AND), 0x1F57AB3), + ] + + self.do_run(test_cases) + + def test_i_type(self): + test_cases = [ + (ITypeInstr(opcode=Opcode.LOAD_FP, rd=22, funct3=Funct3.D, rs1=10, imm=2047), 0x7FF53B07), + (ITypeInstr(opcode=Opcode.LOAD_FP, rd=22, funct3=Funct3.D, rs1=10, imm=-2048), 0x80053B07), + ] + + self.do_run(test_cases) + + def test_s_type(self): + test_cases = [ + (STypeInstr(opcode=Opcode.STORE_FP, imm=2047, funct3=Funct3.D, rs1=31, rs2=0), 0x7E0FBFA7), + (STypeInstr(opcode=Opcode.STORE_FP, imm=-2048, funct3=Funct3.D, rs1=5, rs2=13), 0x80D2B027), + ] + + self.do_run(test_cases) + + def test_b_type(self): + test_cases = [ + (BTypeInstr(opcode=Opcode.BRANCH, imm=4094, funct3=Funct3.BNE, rs1=10, rs2=0), 0x7E051FE3), + (BTypeInstr(opcode=Opcode.BRANCH, imm=-4096, funct3=Funct3.BEQ, rs1=31, rs2=4), 0x804F8063), + ] + + self.do_run(test_cases) + + def test_u_type(self): + test_cases = [ + (UTypeInstr(opcode=Opcode.LUI, rd=10, imm=3102 << 12), 0xC1E537), + (UTypeInstr(opcode=Opcode.LUI, rd=31, imm=1048575 << 12), 0xFFFFFFB7), + ] + + self.do_run(test_cases) + + def test_j_type(self): + test_cases = [ + (JTypeInstr(opcode=Opcode.JAL, rd=0, imm=0), 0x6F), + (JTypeInstr(opcode=Opcode.JAL, rd=0, imm=2), 0x20006F), + (JTypeInstr(opcode=Opcode.JAL, rd=10, imm=1048572), 0x7FDFF56F), + (JTypeInstr(opcode=Opcode.JAL, rd=3, imm=-230), 0xF1BFF1EF), + (JTypeInstr(opcode=Opcode.JAL, rd=15, imm=-1048576), 0x800007EF), + ] + + self.do_run(test_cases) diff --git a/test/test_core.py b/test/test_core.py index dbb8692f8..7c080c658 100644 --- a/test/test_core.py +++ b/test/test_core.py @@ -79,8 +79,10 @@ def push_register_load_imm(self, reg_id, val): if val & 0x800: lui_imm = (lui_imm + 1) & (0xFFFFF) - yield from self.push_instr(UTypeInstr.encode(Opcode.LUI, reg_id, lui_imm)) - yield from self.push_instr(ITypeInstr.encode(Opcode.OP_IMM, reg_id, Funct3.ADD, reg_id, addi_imm)) + yield from self.push_instr(UTypeInstr(opcode=Opcode.LUI, rd=reg_id, imm=lui_imm).encode()) + yield from self.push_instr( + ITypeInstr(opcode=Opcode.OP_IMM, rd=reg_id, funct3=Funct3.ADD, rs1=reg_id, imm=addi_imm).encode() + ) class TestCoreAsmSourceBase(TestCoreBase): From 4b9ff0631fb441f55ae5874cdb5e829481d24538 Mon Sep 17 00:00:00 2001 From: Jacob Urbanczyk Date: Wed, 27 Mar 2024 16:49:34 +0100 Subject: [PATCH 2/6] Fix test_core --- coreblocks/params/instr.py | 2 +- test/test_core.py | 6 +----- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/coreblocks/params/instr.py b/coreblocks/params/instr.py index 08cf8322c..e66e90219 100644 --- a/coreblocks/params/instr.py +++ b/coreblocks/params/instr.py @@ -176,7 +176,7 @@ def __init__(self, opcode: ValueLike, **kwargs): class UTypeInstr(RISCVInstr): field_rd = Field(name="rd", base=7, size=5) - field_imm = Field(name="imm", base=12, size=20, offset=12, signed=False) + field_imm = Field(name="imm", base=12, size=20, offset=12, signed=True) def __init__(self, opcode: ValueLike, **kwargs): super().__init__(opcode=Cat(C(0b11, 2), opcode), **kwargs) diff --git a/test/test_core.py b/test/test_core.py index 7c080c658..33419a0d5 100644 --- a/test/test_core.py +++ b/test/test_core.py @@ -74,12 +74,8 @@ def push_instr(self, opcode): def push_register_load_imm(self, reg_id, val): addi_imm = signed_to_int(val & 0xFFF, 12) - lui_imm = (val & 0xFFFFF000) >> 12 - # handle addi sign extension, see: https://stackoverflow.com/a/59546567 - if val & 0x800: - lui_imm = (lui_imm + 1) & (0xFFFFF) - yield from self.push_instr(UTypeInstr(opcode=Opcode.LUI, rd=reg_id, imm=lui_imm).encode()) + yield from self.push_instr(UTypeInstr(opcode=Opcode.LUI, rd=reg_id, imm=val).encode()) yield from self.push_instr( ITypeInstr(opcode=Opcode.OP_IMM, rd=reg_id, funct3=Funct3.ADD, rs1=reg_id, imm=addi_imm).encode() ) From e3336f9d329f9a71af52bd5d9e567f5bd36113d7 Mon Sep 17 00:00:00 2001 From: Jacob Urbanczyk Date: Thu, 28 Mar 2024 11:59:16 +0100 Subject: [PATCH 3/6] fix the fix of test_core.py --- test/test_core.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/test/test_core.py b/test/test_core.py index 33419a0d5..dcf66970b 100644 --- a/test/test_core.py +++ b/test/test_core.py @@ -74,8 +74,12 @@ def push_instr(self, opcode): def push_register_load_imm(self, reg_id, val): addi_imm = signed_to_int(val & 0xFFF, 12) + lui_imm = (val & 0xFFFFF000) >> 12 + # handle addi sign extension, see: https://stackoverflow.com/a/59546567 + if val & 0x800: + lui_imm = (lui_imm + 1) & (0xFFFFF) - yield from self.push_instr(UTypeInstr(opcode=Opcode.LUI, rd=reg_id, imm=val).encode()) + yield from self.push_instr(UTypeInstr(opcode=Opcode.LUI, rd=reg_id, imm=lui_imm << 12).encode()) yield from self.push_instr( ITypeInstr(opcode=Opcode.OP_IMM, rd=reg_id, funct3=Funct3.ADD, rs1=reg_id, imm=addi_imm).encode() ) From 120831140c04c1cb27006a7fab2f87c7b13a8261 Mon Sep 17 00:00:00 2001 From: Jacob Urbanczyk Date: Thu, 28 Mar 2024 16:04:11 +0100 Subject: [PATCH 4/6] Use descriptors --- coreblocks/frontend/decoder/isa.py | 1 + coreblocks/params/instr.py | 179 ++++++++++++++++------------- 2 files changed, 102 insertions(+), 78 deletions(-) diff --git a/coreblocks/frontend/decoder/isa.py b/coreblocks/frontend/decoder/isa.py index 229d65c9b..10bb72854 100644 --- a/coreblocks/frontend/decoder/isa.py +++ b/coreblocks/frontend/decoder/isa.py @@ -40,6 +40,7 @@ class Opcode(IntEnum, shape=5): JALR = 0b11001 JAL = 0b11011 SYSTEM = 0b11100 + RESERVED = 0b11111 class Funct3(IntEnum, shape=3): diff --git a/coreblocks/params/instr.py b/coreblocks/params/instr.py index e66e90219..6230dc25c 100644 --- a/coreblocks/params/instr.py +++ b/coreblocks/params/instr.py @@ -30,9 +30,8 @@ ] -@dataclass(frozen=True, kw_only=True) +@dataclass(kw_only=True) class Field: - name: str base: int | list[int] size: int | list[int] @@ -40,46 +39,31 @@ class Field: offset: int = 0 static_value: Optional[Value] = None - def get_base(self) -> list[int]: - if isinstance(self.base, int): - return [self.base] - return self.base + _name: str = "" - def get_size(self) -> list[int]: - if isinstance(self.size, int): - return [self.size] - return self.size + def bases(self) -> list[int]: + return [self.base] if isinstance(self.base, int) else self.base + def sizes(self) -> list[int]: + return [self.size] if isinstance(self.size, int) else self.size -class RISCVInstr(ABC, ValueCastable): - field_opcode = Field(name="opcode", base=0, size=7) - - def __init__(self, **kwargs): - for field in kwargs: - fname = "field_" + field - assert fname in dir(self), "Invalid field {} for {}".format(fname, self.__name__) - setattr(self, field, kwargs[field]) - - @classmethod - def get_fields(cls) -> list[Field]: - return [getattr(cls, member) for member in dir(cls) if member.startswith("field_")] + def width(self) -> int: + return sum(self.sizes()) - def encode(self) -> int: - const = Const.cast(self.as_value()) - return const.value # type: ignore + def __set_name__(self, owner, name): + self._name = name - def __setattr__(self, key, value): - fname = "field_{}".format(key) + def __get__(self, obj, objtype=None) -> Value: + if self.static_value is not None: + return self.static_value - if fname not in dir(self): - super().__setattr__(key, value) - return + return obj.__dict__.get(self._name, C(0, Shape(self.width(), self.signed))) - field = getattr(self, fname) - if field.static_value is not None: + def __set__(self, obj, value) -> None: + if self.static_value is not None: raise AttributeError("Can't overwrite the static value of a field.") - expected_shape = Shape(width=sum(field.get_size()) + field.offset, signed=field.signed) + expected_shape = Shape(width=sum(self.sizes()) + self.offset, signed=self.signed) field_val: Value = C(0) if isinstance(value, Enum): @@ -98,21 +82,41 @@ def __setattr__(self, key, value): f"Expected signedness of the value: {expected_shape.signed}, given: {field_val.shape().signed}" ) - self.__dict__[key] = field_val + obj.__dict__[self._name] = field_val + + +def _get_fields(cls: type) -> list[Field]: + fields = [cls.__dict__[member] for member in vars(cls) if isinstance(cls.__dict__[member], Field)] + field_ids = set([id(field) for field in fields]) + for base in cls.__bases__: + for field in _get_fields(base): + if id(field) in field_ids: + continue + fields.append(field) + field_ids.add(id(field)) + + return fields + + +class RISCVInstr(ABC, ValueCastable): + opcode = Field(base=0, size=7) + + def __init__(self, opcode: Opcode): + self.opcode = Cat(C(0b11, 2), opcode) + + def encode(self) -> int: + const = Const.cast(self.as_value()) + return const.value # type: ignore @ValueCastable.lowermethod def as_value(self) -> Value: parts: list[tuple[int, Value]] = [] - for field in self.get_fields(): - value: Value = C(0) - if field.static_value is not None: - value = field.static_value - else: - value = getattr(self, field.name) + for field in _get_fields(type(self)): + value = field.__get__(self, type(self)) - base = field.get_base() - size = field.get_size() + base = field.bases() + size = field.sizes() offset = field.offset for i in range(len(base)): @@ -127,74 +131,93 @@ def shape(self) -> Shape: class InstructionFunct3Type(RISCVInstr): - field_funct3 = Field(name="funct3", base=12, size=3) - - -class InstructionFunct5Type(RISCVInstr): - field_funct5 = Field(name="funct5", base=27, size=5) + funct3 = Field(base=12, size=3) class InstructionFunct7Type(RISCVInstr): - field_funct7 = Field(name="funct7", base=25, size=7) + funct7 = Field(base=25, size=7) class RTypeInstr(InstructionFunct3Type, InstructionFunct7Type): - field_rd = Field(name="rd", base=7, size=5) - field_rs1 = Field(name="rs1", base=15, size=5) - field_rs2 = Field(name="rs2", base=20, size=5) + rd = Field(base=7, size=5) + rs1 = Field(base=15, size=5) + rs2 = Field(base=20, size=5) - def __init__(self, opcode: ValueLike, **kwargs): - super().__init__(opcode=Cat(C(0b11, 2), opcode), **kwargs) + def __init__( + self, opcode: Opcode, funct3: ValueLike, funct7: ValueLike, rd: ValueLike, rs1: ValueLike, rs2: ValueLike + ): + super().__init__(opcode) + self.funct3 = funct3 + self.funct7 = funct7 + self.rd = rd + self.rs1 = rs1 + self.rs2 = rs2 class ITypeInstr(InstructionFunct3Type): - field_rd = Field(name="rd", base=7, size=5) - field_rs1 = Field(name="rs1", base=15, size=5) - field_imm = Field(name="imm", base=20, size=12, signed=True) + rd = Field(base=7, size=5) + rs1 = Field(base=15, size=5) + imm = Field(base=20, size=12, signed=True) - def __init__(self, opcode: ValueLike, **kwargs): - super().__init__(opcode=Cat(C(0b11, 2), opcode), **kwargs) + def __init__(self, opcode: Opcode, funct3: ValueLike, rd: ValueLike, rs1: ValueLike, imm: ValueLike): + super().__init__(opcode) + self.funct3 = funct3 + self.rd = rd + self.rs1 = rs1 + self.imm = imm class STypeInstr(InstructionFunct3Type): - field_rs1 = Field(name="rs1", base=15, size=5) - field_rs2 = Field(name="rs2", base=20, size=5) - field_imm = Field(name="imm", base=[7, 25], size=[5, 7], signed=True) + rs1 = Field(base=15, size=5) + rs2 = Field(base=20, size=5) + imm = Field(base=[7, 25], size=[5, 7], signed=True) - def __init__(self, opcode: ValueLike, **kwargs): - super().__init__(opcode=Cat(C(0b11, 2), opcode), **kwargs) + def __init__(self, opcode: Opcode, funct3: ValueLike, rs1: ValueLike, rs2: ValueLike, imm: ValueLike): + super().__init__(opcode) + self.funct3 = funct3 + self.rs1 = rs1 + self.rs2 = rs2 + self.imm = imm class BTypeInstr(InstructionFunct3Type): - field_rs1 = Field(name="rs1", base=15, size=5) - field_rs2 = Field(name="rs2", base=20, size=5) - field_imm = Field(name="imm", base=[8, 25, 7, 31], size=[4, 6, 1, 1], offset=1, signed=True) + rs1 = Field(base=15, size=5) + rs2 = Field(base=20, size=5) + imm = Field(base=[8, 25, 7, 31], size=[4, 6, 1, 1], offset=1, signed=True) - def __init__(self, opcode: ValueLike, **kwargs): - super().__init__(opcode=Cat(C(0b11, 2), opcode), **kwargs) + def __init__(self, opcode: Opcode, funct3: ValueLike, rs1: ValueLike, rs2: ValueLike, imm: ValueLike): + super().__init__(opcode) + self.funct3 = funct3 + self.rs1 = rs1 + self.rs2 = rs2 + self.imm = imm class UTypeInstr(RISCVInstr): - field_rd = Field(name="rd", base=7, size=5) - field_imm = Field(name="imm", base=12, size=20, offset=12, signed=True) + rd = Field(base=7, size=5) + imm = Field(base=12, size=20, offset=12, signed=True) - def __init__(self, opcode: ValueLike, **kwargs): - super().__init__(opcode=Cat(C(0b11, 2), opcode), **kwargs) + def __init__(self, opcode: Opcode, rd: ValueLike, imm: ValueLike): + super().__init__(opcode) + self.rd = rd + self.imm = imm class JTypeInstr(RISCVInstr): - field_rd = Field(name="rd", base=7, size=5) - field_imm = Field(name="imm", base=[21, 20, 12, 31], size=[10, 1, 8, 1], offset=1, signed=True) + rd = Field(base=7, size=5) + imm = Field(base=[21, 20, 12, 31], size=[10, 1, 8, 1], offset=1, signed=True) - def __init__(self, opcode: ValueLike, **kwargs): - super().__init__(opcode=Cat(C(0b11, 2), opcode), **kwargs) + def __init__(self, opcode: Opcode, rd: ValueLike, imm: ValueLike): + super().__init__(opcode) + self.rd = rd + self.imm = imm class IllegalInstr(RISCVInstr): - field_illegal = Field(name="illegal", base=7, size=25, static_value=Cat(1).replicate(25)) + illegal = Field(base=7, size=25, static_value=Cat(1).replicate(25)) def __init__(self): - super().__init__(opcode=0b1111111) + super().__init__(opcode=Opcode.RESERVED) class EBreakInstr(ITypeInstr): From 9162196e125edc2a80ab27577b72f1a1b5e88723 Mon Sep 17 00:00:00 2001 From: Jacob Urbanczyk Date: Sun, 31 Mar 2024 17:04:58 +0200 Subject: [PATCH 5/6] PR review --- coreblocks/params/instr.py | 50 +++++++++++++++++++++++++++++--------- 1 file changed, 38 insertions(+), 12 deletions(-) diff --git a/coreblocks/params/instr.py b/coreblocks/params/instr.py index 6230dc25c..7c12658fc 100644 --- a/coreblocks/params/instr.py +++ b/coreblocks/params/instr.py @@ -32,6 +32,26 @@ @dataclass(kw_only=True) class Field: + """Information about a field in a RISC-V instruction. + + Attributes + ---------- + base: int | list[int] + A bit position (or a list of positions) where this field (or parts of the field) + would map in the instruction. + size: int | list[int] + Size (or sizes of the parts) of the field + signed: bool + Whether this field encodes a signed value. + offset: int + How many bits of this field should be skipped when encoding the instruction. + For example, the immediate of the jump instruction always skips the least + significant bit. This only affects encoding procedures, so externally (for example + when creating an instance of a instruction) full-size values should be always used. + static_value: Optional[Value] + Whether the field should have a static value for a given type of an instruction. + """ + base: int | list[int] size: int | list[int] @@ -47,8 +67,8 @@ def bases(self) -> list[int]: def sizes(self) -> list[int]: return [self.size] if isinstance(self.size, int) else self.size - def width(self) -> int: - return sum(self.sizes()) + def shape(self) -> Shape: + return Shape(width=sum(self.sizes()) + self.offset, signed=self.signed) def __set_name__(self, owner, name): self._name = name @@ -57,13 +77,13 @@ def __get__(self, obj, objtype=None) -> Value: if self.static_value is not None: return self.static_value - return obj.__dict__.get(self._name, C(0, Shape(self.width(), self.signed))) + return obj.__dict__.get(self._name, C(0, self.shape())) def __set__(self, obj, value) -> None: if self.static_value is not None: raise AttributeError("Can't overwrite the static value of a field.") - expected_shape = Shape(width=sum(self.sizes()) + self.offset, signed=self.signed) + expected_shape = self.shape() field_val: Value = C(0) if isinstance(value, Enum): @@ -85,6 +105,19 @@ def __set__(self, obj, value) -> None: obj.__dict__[self._name] = field_val + def get_parts(self, value: Value) -> list[Value]: + base = self.bases() + size = self.sizes() + offset = self.offset + + ret: list[Value] = [] + for i in range(len(base)): + ret.append(value[offset : offset + size[i]]) + offset += size[i] + + return ret + + def _get_fields(cls: type) -> list[Field]: fields = [cls.__dict__[member] for member in vars(cls) if isinstance(cls.__dict__[member], Field)] field_ids = set([id(field) for field in fields]) @@ -114,14 +147,7 @@ def as_value(self) -> Value: for field in _get_fields(type(self)): value = field.__get__(self, type(self)) - - base = field.bases() - size = field.sizes() - - offset = field.offset - for i in range(len(base)): - parts.append((base[i], value[offset : offset + size[i]])) - offset += size[i] + parts += zip(field.bases(), field.get_parts(value)) parts.sort() return Cat([part[1] for part in parts]) From 7a185018aec6e9313f50c9cd8dd72456177f2e22 Mon Sep 17 00:00:00 2001 From: Jacob Urbanczyk Date: Sun, 31 Mar 2024 17:12:32 +0200 Subject: [PATCH 6/6] Formatting --- coreblocks/params/instr.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/coreblocks/params/instr.py b/coreblocks/params/instr.py index 7c12658fc..f3755b25d 100644 --- a/coreblocks/params/instr.py +++ b/coreblocks/params/instr.py @@ -104,7 +104,6 @@ def __set__(self, obj, value) -> None: obj.__dict__[self._name] = field_val - def get_parts(self, value: Value) -> list[Value]: base = self.bases() size = self.sizes() @@ -114,7 +113,7 @@ def get_parts(self, value: Value) -> list[Value]: for i in range(len(base)): ret.append(value[offset : offset + size[i]]) offset += size[i] - + return ret