Skip to content

Commit

Permalink
Refactor RISC-V instruction models (#631)
Browse files Browse the repository at this point in the history
  • Loading branch information
Jakub Urbańczyk authored Apr 1, 2024
1 parent 30b55cd commit 8ec353d
Show file tree
Hide file tree
Showing 6 changed files with 295 additions and 156 deletions.
1 change: 1 addition & 0 deletions coreblocks/frontend/decoder/isa.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ class Opcode(IntEnum, shape=5):
JALR = 0b11001
JAL = 0b11011
SYSTEM = 0b11100
RESERVED = 0b11111


class Funct3(IntEnum, shape=3):
Expand Down
2 changes: 1 addition & 1 deletion coreblocks/frontend/decoder/rvc.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ def _quadrant_2(self) -> list[DecodedInstr]:
shamt = Cat(self.instr_in[2:7], self.instr_in[12])
ldsp_imm = Cat(C(0, 3), self.instr_in[5:7], self.instr_in[12], self.instr_in[2:5], C(0, 3))
lwsp_imm = Cat(C(0, 2), self.instr_in[4:7], self.instr_in[12], self.instr_in[2:4], C(0, 4))
sdsp_imm = Cat(C(0, 3), self.instr_in[10:13], self.instr_in[7:10], C(0, 2))
sdsp_imm = Cat(C(0, 3), self.instr_in[10:13], self.instr_in[7:10], C(0, 3))
swsp_imm = Cat(C(0, 2), self.instr_in[9:13], self.instr_in[7:9], C(0, 4))

slli = (
Expand Down
333 changes: 204 additions & 129 deletions coreblocks/params/instr.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,24 @@
from abc import abstractmethod, ABC
"""
Based on riscv-python-model by Stefan Wallentowitz
https://github.com/wallento/riscv-python-model
"""

from dataclasses import dataclass
from abc import ABC
from enum import Enum
from typing import Optional

from amaranth.hdl import ValueCastable
from amaranth import *

from transactron.utils import ValueLike, int_to_signed
from transactron.utils import ValueLike
from coreblocks.params.isa_params import *
from coreblocks.frontend.decoder.isa import *


__all__ = [
"RISCVInstr",
"RTypeInstr",
"ITypeInstr",
"STypeInstr",
Expand All @@ -20,154 +30,219 @@
]


@dataclass(kw_only=True)
class Field:
"""Information about a field in a RISC-V instruction.
Attributes
----------
base: int | list[int]
A bit position (or a list of positions) where this field (or parts of the field)
would map in the instruction.
size: int | list[int]
Size (or sizes of the parts) of the field
signed: bool
Whether this field encodes a signed value.
offset: int
How many bits of this field should be skipped when encoding the instruction.
For example, the immediate of the jump instruction always skips the least
significant bit. This only affects encoding procedures, so externally (for example
when creating an instance of a instruction) full-size values should be always used.
static_value: Optional[Value]
Whether the field should have a static value for a given type of an instruction.
"""

base: int | list[int]
size: int | list[int]

signed: bool = False
offset: int = 0
static_value: Optional[Value] = None

_name: str = ""

def bases(self) -> list[int]:
return [self.base] if isinstance(self.base, int) else self.base

def sizes(self) -> list[int]:
return [self.size] if isinstance(self.size, int) else self.size

def shape(self) -> Shape:
return Shape(width=sum(self.sizes()) + self.offset, signed=self.signed)

def __set_name__(self, owner, name):
self._name = name

def __get__(self, obj, objtype=None) -> Value:
if self.static_value is not None:
return self.static_value

return obj.__dict__.get(self._name, C(0, self.shape()))

def __set__(self, obj, value) -> None:
if self.static_value is not None:
raise AttributeError("Can't overwrite the static value of a field.")

expected_shape = self.shape()

field_val: Value = C(0)
if isinstance(value, Enum):
field_val = Const(value.value, expected_shape)
elif isinstance(value, int):
field_val = Const(value, expected_shape)
else:
field_val = Value.cast(value)

if field_val.shape().width != expected_shape.width:
raise AttributeError(
f"Expected width of the value: {expected_shape.width}, given: {field_val.shape().width}"
)
if field_val.shape().signed and not expected_shape.signed:
raise AttributeError(
f"Expected signedness of the value: {expected_shape.signed}, given: {field_val.shape().signed}"
)

obj.__dict__[self._name] = field_val

def get_parts(self, value: Value) -> list[Value]:
base = self.bases()
size = self.sizes()
offset = self.offset

ret: list[Value] = []
for i in range(len(base)):
ret.append(value[offset : offset + size[i]])
offset += size[i]

return ret


def _get_fields(cls: type) -> list[Field]:
fields = [cls.__dict__[member] for member in vars(cls) if isinstance(cls.__dict__[member], Field)]
field_ids = set([id(field) for field in fields])
for base in cls.__bases__:
for field in _get_fields(base):
if id(field) in field_ids:
continue
fields.append(field)
field_ids.add(id(field))

return fields


class RISCVInstr(ABC, ValueCastable):
@abstractmethod
def pack(self) -> Value:
pass
opcode = Field(base=0, size=7)

def __init__(self, opcode: Opcode):
self.opcode = Cat(C(0b11, 2), opcode)

def encode(self) -> int:
const = Const.cast(self.as_value())
return const.value # type: ignore

@ValueCastable.lowermethod
def as_value(self):
return self.pack()
def as_value(self) -> Value:
parts: list[tuple[int, Value]] = []

for field in _get_fields(type(self)):
value = field.__get__(self, type(self))
parts += zip(field.bases(), field.get_parts(value))

parts.sort()
return Cat([part[1] for part in parts])

def shape(self):
def shape(self) -> Shape:
return self.as_value().shape()


class RTypeInstr(RISCVInstr):
class InstructionFunct3Type(RISCVInstr):
funct3 = Field(base=12, size=3)


class InstructionFunct7Type(RISCVInstr):
funct7 = Field(base=25, size=7)


class RTypeInstr(InstructionFunct3Type, InstructionFunct7Type):
rd = Field(base=7, size=5)
rs1 = Field(base=15, size=5)
rs2 = Field(base=20, size=5)

def __init__(
self,
opcode: ValueLike,
rd: ValueLike,
funct3: ValueLike,
rs1: ValueLike,
rs2: ValueLike,
funct7: ValueLike,
self, opcode: Opcode, funct3: ValueLike, funct7: ValueLike, rd: ValueLike, rs1: ValueLike, rs2: ValueLike
):
self.opcode = Value.cast(opcode)
self.rd = Value.cast(rd)
self.funct3 = Value.cast(funct3)
self.rs1 = Value.cast(rs1)
self.rs2 = Value.cast(rs2)
self.funct7 = Value.cast(funct7)

def pack(self) -> Value:
return Cat(C(0b11, 2), self.opcode, self.rd, self.funct3, self.rs1, self.rs2, self.funct7)

@staticmethod
def encode(opcode: int, rd: int, funct3: int, rs1: int, rs2: int, funct7: int):
return int(f"{funct7:07b}{rs2:05b}{rs1:05b}{funct3:03b}{rd:05b}{opcode:05b}11", 2)


class ITypeInstr(RISCVInstr):
def __init__(self, opcode: ValueLike, rd: ValueLike, funct3: ValueLike, rs1: ValueLike, imm: ValueLike):
self.opcode = Value.cast(opcode)
self.rd = Value.cast(rd)
self.funct3 = Value.cast(funct3)
self.rs1 = Value.cast(rs1)
self.imm = Value.cast(imm)

def pack(self) -> Value:
return Cat(C(0b11, 2), self.opcode, self.rd, self.funct3, self.rs1, self.imm)

@staticmethod
def encode(opcode: int, rd: int, funct3: int, rs1: int, imm: int):
imm = int_to_signed(imm, 12)
return int(f"{imm:012b}{rs1:05b}{funct3:03b}{rd:05b}{opcode:05b}11", 2)


class STypeInstr(RISCVInstr):
def __init__(self, opcode: ValueLike, imm: ValueLike, funct3: ValueLike, rs1: ValueLike, rs2: ValueLike):
self.opcode = Value.cast(opcode)
self.imm = Value.cast(imm)
self.funct3 = Value.cast(funct3)
self.rs1 = Value.cast(rs1)
self.rs2 = Value.cast(rs2)

def pack(self) -> Value:
return Cat(C(0b11, 2), self.opcode, self.imm[0:5], self.funct3, self.rs1, self.rs2, self.imm[5:12])

@staticmethod
def encode(opcode: int, imm: int, funct3: int, rs1: int, rs2: int):
imm = int_to_signed(imm, 12)
imm_str = f"{imm:012b}"
return int(f"{imm_str[5:12]:07b}{rs2:05b}{rs1:05b}{funct3:03b}{imm_str[0:5]:05b}{opcode:05b}11", 2)


class BTypeInstr(RISCVInstr):
def __init__(self, opcode: ValueLike, imm: ValueLike, funct3: ValueLike, rs1: ValueLike, rs2: ValueLike):
self.opcode = Value.cast(opcode)
self.imm = Value.cast(imm)
self.funct3 = Value.cast(funct3)
self.rs1 = Value.cast(rs1)
self.rs2 = Value.cast(rs2)

def pack(self) -> Value:
return Cat(
C(0b11, 2),
self.opcode,
self.imm[11],
self.imm[1:5],
self.funct3,
self.rs1,
self.rs2,
self.imm[5:11],
self.imm[12],
)
super().__init__(opcode)
self.funct3 = funct3
self.funct7 = funct7
self.rd = rd
self.rs1 = rs1
self.rs2 = rs2

@staticmethod
def encode(opcode: int, imm: int, funct3: int, rs1: int, rs2: int):
imm = int_to_signed(imm, 13)
imm_str = f"{imm:013b}"
return int(
f"{imm_str[12]:01b}{imm_str[5:11]:06b}{rs2:05b}{rs1:05b}{funct3:03b}{imm_str[1:5]:04b}"
+ f"{imm_str[11]:01b}{opcode:05b}11",
2,
)

class ITypeInstr(InstructionFunct3Type):
rd = Field(base=7, size=5)
rs1 = Field(base=15, size=5)
imm = Field(base=20, size=12, signed=True)

def __init__(self, opcode: Opcode, funct3: ValueLike, rd: ValueLike, rs1: ValueLike, imm: ValueLike):
super().__init__(opcode)
self.funct3 = funct3
self.rd = rd
self.rs1 = rs1
self.imm = imm

class UTypeInstr(RISCVInstr):
def __init__(self, opcode: ValueLike, rd: ValueLike, imm: ValueLike):
self.opcode = Value.cast(opcode)
self.rd = Value.cast(rd)
self.imm = Value.cast(imm)

def pack(self) -> Value:
return Cat(C(0b11, 2), self.opcode, self.rd, self.imm[12:])
class STypeInstr(InstructionFunct3Type):
rs1 = Field(base=15, size=5)
rs2 = Field(base=20, size=5)
imm = Field(base=[7, 25], size=[5, 7], signed=True)

@staticmethod
def encode(opcode: int, rd: int, imm: int):
imm = int_to_signed(imm, 20)
return int(f"{imm:020b}{rd:05b}{opcode:05b}11", 2)
def __init__(self, opcode: Opcode, funct3: ValueLike, rs1: ValueLike, rs2: ValueLike, imm: ValueLike):
super().__init__(opcode)
self.funct3 = funct3
self.rs1 = rs1
self.rs2 = rs2
self.imm = imm


class BTypeInstr(InstructionFunct3Type):
rs1 = Field(base=15, size=5)
rs2 = Field(base=20, size=5)
imm = Field(base=[8, 25, 7, 31], size=[4, 6, 1, 1], offset=1, signed=True)

def __init__(self, opcode: Opcode, funct3: ValueLike, rs1: ValueLike, rs2: ValueLike, imm: ValueLike):
super().__init__(opcode)
self.funct3 = funct3
self.rs1 = rs1
self.rs2 = rs2
self.imm = imm


class UTypeInstr(RISCVInstr):
rd = Field(base=7, size=5)
imm = Field(base=12, size=20, offset=12, signed=True)

def __init__(self, opcode: Opcode, rd: ValueLike, imm: ValueLike):
super().__init__(opcode)
self.rd = rd
self.imm = imm


class JTypeInstr(RISCVInstr):
def __init__(self, opcode: ValueLike, rd: ValueLike, imm: ValueLike):
self.opcode = Value.cast(opcode)
self.rd = Value.cast(rd)
self.imm = Value.cast(imm)

def pack(self) -> Value:
return Cat(C(0b11, 2), self.opcode, self.rd, self.imm[12:20], self.imm[11], self.imm[1:11], self.imm[20])

@staticmethod
def encode(opcode: int, rd: int, imm: int):
imm = int_to_signed(imm, 21)
imm_str = f"{imm:021b}"
return int(
f"{imm_str[20]:01b}{imm_str[1:11]:010b}{imm_str[11]:01b}{imm_str[12:20]:08b}{rd:05b}{opcode:05b}11", 2
)
rd = Field(base=7, size=5)
imm = Field(base=[21, 20, 12, 31], size=[10, 1, 8, 1], offset=1, signed=True)

def __init__(self, opcode: Opcode, rd: ValueLike, imm: ValueLike):
super().__init__(opcode)
self.rd = rd
self.imm = imm

class IllegalInstr(RISCVInstr):
def __init__(self):
pass

def pack(self) -> Value:
return C(1).replicate(32) # Instructions with all bits set to 1 are reserved to be illegal.
class IllegalInstr(RISCVInstr):
illegal = Field(base=7, size=25, static_value=Cat(1).replicate(25))

@staticmethod
def encode(opcode: int, rd: int, imm: int):
return int("1" * 32, 2)
def __init__(self):
super().__init__(opcode=Opcode.RESERVED)


class EBreakInstr(ITypeInstr):
Expand Down
Loading

0 comments on commit 8ec353d

Please sign in to comment.