Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor RISC-V instruction models #631

Merged
merged 7 commits into from
Apr 1, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions coreblocks/frontend/decoder/isa.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ class Opcode(IntEnum, shape=5):
JALR = 0b11001
JAL = 0b11011
SYSTEM = 0b11100
RESERVED = 0b11111


class Funct3(IntEnum, shape=3):
Expand Down
2 changes: 1 addition & 1 deletion coreblocks/frontend/decoder/rvc.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ def _quadrant_2(self) -> list[DecodedInstr]:
shamt = Cat(self.instr_in[2:7], self.instr_in[12])
ldsp_imm = Cat(C(0, 3), self.instr_in[5:7], self.instr_in[12], self.instr_in[2:5], C(0, 3))
lwsp_imm = Cat(C(0, 2), self.instr_in[4:7], self.instr_in[12], self.instr_in[2:4], C(0, 4))
sdsp_imm = Cat(C(0, 3), self.instr_in[10:13], self.instr_in[7:10], C(0, 2))
sdsp_imm = Cat(C(0, 3), self.instr_in[10:13], self.instr_in[7:10], C(0, 3))
swsp_imm = Cat(C(0, 2), self.instr_in[9:13], self.instr_in[7:9], C(0, 4))

slli = (
Expand Down
308 changes: 179 additions & 129 deletions coreblocks/params/instr.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,24 @@
from abc import abstractmethod, ABC
"""

Based on riscv-python-model by Stefan Wallentowitz
https://github.com/wallento/riscv-python-model
"""

from dataclasses import dataclass
from abc import ABC
from enum import Enum
from typing import Optional

from amaranth.hdl import ValueCastable
from amaranth import *

from transactron.utils import ValueLike, int_to_signed
from transactron.utils import ValueLike
from coreblocks.params.isa_params import *
from coreblocks.frontend.decoder.isa import *


__all__ = [
"RISCVInstr",
"RTypeInstr",
"ITypeInstr",
"STypeInstr",
Expand All @@ -20,154 +30,194 @@
]


@dataclass(kw_only=True)
class Field:
base: int | list[int]
size: int | list[int]

signed: bool = False
offset: int = 0
xThaid marked this conversation as resolved.
Show resolved Hide resolved
static_value: Optional[Value] = None

_name: str = ""

def bases(self) -> list[int]:
return [self.base] if isinstance(self.base, int) else self.base

def sizes(self) -> list[int]:
return [self.size] if isinstance(self.size, int) else self.size

def width(self) -> int:
return sum(self.sizes())

def __set_name__(self, owner, name):
self._name = name

def __get__(self, obj, objtype=None) -> Value:
if self.static_value is not None:
return self.static_value

return obj.__dict__.get(self._name, C(0, Shape(self.width(), self.signed)))
xThaid marked this conversation as resolved.
Show resolved Hide resolved

def __set__(self, obj, value) -> None:
xThaid marked this conversation as resolved.
Show resolved Hide resolved
if self.static_value is not None:
raise AttributeError("Can't overwrite the static value of a field.")

expected_shape = Shape(width=sum(self.sizes()) + self.offset, signed=self.signed)

field_val: Value = C(0)
if isinstance(value, Enum):
field_val = Const(value.value, expected_shape)
elif isinstance(value, int):
field_val = Const(value, expected_shape)
else:
field_val = Value.cast(value)

if field_val.shape().width != expected_shape.width:
raise AttributeError(
f"Expected width of the value: {expected_shape.width}, given: {field_val.shape().width}"
)
if field_val.shape().signed and not expected_shape.signed:
raise AttributeError(
f"Expected signedness of the value: {expected_shape.signed}, given: {field_val.shape().signed}"
)

obj.__dict__[self._name] = field_val


def _get_fields(cls: type) -> list[Field]:
fields = [cls.__dict__[member] for member in vars(cls) if isinstance(cls.__dict__[member], Field)]
field_ids = set([id(field) for field in fields])
for base in cls.__bases__:
for field in _get_fields(base):
if id(field) in field_ids:
continue
fields.append(field)
field_ids.add(id(field))

return fields


class RISCVInstr(ABC, ValueCastable):
@abstractmethod
def pack(self) -> Value:
pass
opcode = Field(base=0, size=7)

def __init__(self, opcode: Opcode):
self.opcode = Cat(C(0b11, 2), opcode)

def encode(self) -> int:
const = Const.cast(self.as_value())
return const.value # type: ignore

@ValueCastable.lowermethod
def as_value(self):
return self.pack()
def as_value(self) -> Value:
parts: list[tuple[int, Value]] = []

for field in _get_fields(type(self)):
value = field.__get__(self, type(self))

base = field.bases()
size = field.sizes()

def shape(self):
offset = field.offset
for i in range(len(base)):
parts.append((base[i], value[offset : offset + size[i]]))
xThaid marked this conversation as resolved.
Show resolved Hide resolved
offset += size[i]

parts.sort()
return Cat([part[1] for part in parts])

def shape(self) -> Shape:
return self.as_value().shape()


class RTypeInstr(RISCVInstr):
class InstructionFunct3Type(RISCVInstr):
funct3 = Field(base=12, size=3)


class InstructionFunct7Type(RISCVInstr):
funct7 = Field(base=25, size=7)


class RTypeInstr(InstructionFunct3Type, InstructionFunct7Type):
rd = Field(base=7, size=5)
rs1 = Field(base=15, size=5)
rs2 = Field(base=20, size=5)

def __init__(
self,
opcode: ValueLike,
rd: ValueLike,
funct3: ValueLike,
rs1: ValueLike,
rs2: ValueLike,
funct7: ValueLike,
self, opcode: Opcode, funct3: ValueLike, funct7: ValueLike, rd: ValueLike, rs1: ValueLike, rs2: ValueLike
):
self.opcode = Value.cast(opcode)
self.rd = Value.cast(rd)
self.funct3 = Value.cast(funct3)
self.rs1 = Value.cast(rs1)
self.rs2 = Value.cast(rs2)
self.funct7 = Value.cast(funct7)

def pack(self) -> Value:
return Cat(C(0b11, 2), self.opcode, self.rd, self.funct3, self.rs1, self.rs2, self.funct7)

@staticmethod
def encode(opcode: int, rd: int, funct3: int, rs1: int, rs2: int, funct7: int):
return int(f"{funct7:07b}{rs2:05b}{rs1:05b}{funct3:03b}{rd:05b}{opcode:05b}11", 2)


class ITypeInstr(RISCVInstr):
def __init__(self, opcode: ValueLike, rd: ValueLike, funct3: ValueLike, rs1: ValueLike, imm: ValueLike):
self.opcode = Value.cast(opcode)
self.rd = Value.cast(rd)
self.funct3 = Value.cast(funct3)
self.rs1 = Value.cast(rs1)
self.imm = Value.cast(imm)

def pack(self) -> Value:
return Cat(C(0b11, 2), self.opcode, self.rd, self.funct3, self.rs1, self.imm)

@staticmethod
def encode(opcode: int, rd: int, funct3: int, rs1: int, imm: int):
imm = int_to_signed(imm, 12)
return int(f"{imm:012b}{rs1:05b}{funct3:03b}{rd:05b}{opcode:05b}11", 2)


class STypeInstr(RISCVInstr):
def __init__(self, opcode: ValueLike, imm: ValueLike, funct3: ValueLike, rs1: ValueLike, rs2: ValueLike):
self.opcode = Value.cast(opcode)
self.imm = Value.cast(imm)
self.funct3 = Value.cast(funct3)
self.rs1 = Value.cast(rs1)
self.rs2 = Value.cast(rs2)

def pack(self) -> Value:
return Cat(C(0b11, 2), self.opcode, self.imm[0:5], self.funct3, self.rs1, self.rs2, self.imm[5:12])

@staticmethod
def encode(opcode: int, imm: int, funct3: int, rs1: int, rs2: int):
imm = int_to_signed(imm, 12)
imm_str = f"{imm:012b}"
return int(f"{imm_str[5:12]:07b}{rs2:05b}{rs1:05b}{funct3:03b}{imm_str[0:5]:05b}{opcode:05b}11", 2)


class BTypeInstr(RISCVInstr):
def __init__(self, opcode: ValueLike, imm: ValueLike, funct3: ValueLike, rs1: ValueLike, rs2: ValueLike):
self.opcode = Value.cast(opcode)
self.imm = Value.cast(imm)
self.funct3 = Value.cast(funct3)
self.rs1 = Value.cast(rs1)
self.rs2 = Value.cast(rs2)

def pack(self) -> Value:
return Cat(
C(0b11, 2),
self.opcode,
self.imm[11],
self.imm[1:5],
self.funct3,
self.rs1,
self.rs2,
self.imm[5:11],
self.imm[12],
)
super().__init__(opcode)
self.funct3 = funct3
self.funct7 = funct7
self.rd = rd
self.rs1 = rs1
self.rs2 = rs2

@staticmethod
def encode(opcode: int, imm: int, funct3: int, rs1: int, rs2: int):
imm = int_to_signed(imm, 13)
imm_str = f"{imm:013b}"
return int(
f"{imm_str[12]:01b}{imm_str[5:11]:06b}{rs2:05b}{rs1:05b}{funct3:03b}{imm_str[1:5]:04b}"
+ f"{imm_str[11]:01b}{opcode:05b}11",
2,
)

class ITypeInstr(InstructionFunct3Type):
rd = Field(base=7, size=5)
rs1 = Field(base=15, size=5)
imm = Field(base=20, size=12, signed=True)

class UTypeInstr(RISCVInstr):
def __init__(self, opcode: ValueLike, rd: ValueLike, imm: ValueLike):
self.opcode = Value.cast(opcode)
self.rd = Value.cast(rd)
self.imm = Value.cast(imm)
def __init__(self, opcode: Opcode, funct3: ValueLike, rd: ValueLike, rs1: ValueLike, imm: ValueLike):
super().__init__(opcode)
self.funct3 = funct3
self.rd = rd
self.rs1 = rs1
self.imm = imm


class STypeInstr(InstructionFunct3Type):
rs1 = Field(base=15, size=5)
rs2 = Field(base=20, size=5)
imm = Field(base=[7, 25], size=[5, 7], signed=True)

def pack(self) -> Value:
return Cat(C(0b11, 2), self.opcode, self.rd, self.imm[12:])
def __init__(self, opcode: Opcode, funct3: ValueLike, rs1: ValueLike, rs2: ValueLike, imm: ValueLike):
super().__init__(opcode)
self.funct3 = funct3
self.rs1 = rs1
self.rs2 = rs2
self.imm = imm

@staticmethod
def encode(opcode: int, rd: int, imm: int):
imm = int_to_signed(imm, 20)
return int(f"{imm:020b}{rd:05b}{opcode:05b}11", 2)

class BTypeInstr(InstructionFunct3Type):
rs1 = Field(base=15, size=5)
rs2 = Field(base=20, size=5)
imm = Field(base=[8, 25, 7, 31], size=[4, 6, 1, 1], offset=1, signed=True)

def __init__(self, opcode: Opcode, funct3: ValueLike, rs1: ValueLike, rs2: ValueLike, imm: ValueLike):
super().__init__(opcode)
self.funct3 = funct3
self.rs1 = rs1
self.rs2 = rs2
self.imm = imm


class UTypeInstr(RISCVInstr):
rd = Field(base=7, size=5)
imm = Field(base=12, size=20, offset=12, signed=True)

def __init__(self, opcode: Opcode, rd: ValueLike, imm: ValueLike):
super().__init__(opcode)
self.rd = rd
self.imm = imm


class JTypeInstr(RISCVInstr):
def __init__(self, opcode: ValueLike, rd: ValueLike, imm: ValueLike):
self.opcode = Value.cast(opcode)
self.rd = Value.cast(rd)
self.imm = Value.cast(imm)

def pack(self) -> Value:
return Cat(C(0b11, 2), self.opcode, self.rd, self.imm[12:20], self.imm[11], self.imm[1:11], self.imm[20])

@staticmethod
def encode(opcode: int, rd: int, imm: int):
imm = int_to_signed(imm, 21)
imm_str = f"{imm:021b}"
return int(
f"{imm_str[20]:01b}{imm_str[1:11]:010b}{imm_str[11]:01b}{imm_str[12:20]:08b}{rd:05b}{opcode:05b}11", 2
)
rd = Field(base=7, size=5)
imm = Field(base=[21, 20, 12, 31], size=[10, 1, 8, 1], offset=1, signed=True)

def __init__(self, opcode: Opcode, rd: ValueLike, imm: ValueLike):
super().__init__(opcode)
self.rd = rd
self.imm = imm

class IllegalInstr(RISCVInstr):
def __init__(self):
pass

def pack(self) -> Value:
return C(1).replicate(32) # Instructions with all bits set to 1 are reserved to be illegal.
class IllegalInstr(RISCVInstr):
illegal = Field(base=7, size=25, static_value=Cat(1).replicate(25))

@staticmethod
def encode(opcode: int, rd: int, imm: int):
return int("1" * 32, 2)
def __init__(self):
super().__init__(opcode=Opcode.RESERVED)


class EBreakInstr(ITypeInstr):
Expand Down
Loading