diff --git a/coreblocks/params/instr.py b/coreblocks/params/instr.py index 6230dc25c..7c12658fc 100644 --- a/coreblocks/params/instr.py +++ b/coreblocks/params/instr.py @@ -32,6 +32,26 @@ @dataclass(kw_only=True) class Field: + """Information about a field in a RISC-V instruction. + + Attributes + ---------- + base: int | list[int] + A bit position (or a list of positions) where this field (or parts of the field) + would map in the instruction. + size: int | list[int] + Size (or sizes of the parts) of the field + signed: bool + Whether this field encodes a signed value. + offset: int + How many bits of this field should be skipped when encoding the instruction. + For example, the immediate of the jump instruction always skips the least + significant bit. This only affects encoding procedures, so externally (for example + when creating an instance of a instruction) full-size values should be always used. + static_value: Optional[Value] + Whether the field should have a static value for a given type of an instruction. + """ + base: int | list[int] size: int | list[int] @@ -47,8 +67,8 @@ def bases(self) -> list[int]: def sizes(self) -> list[int]: return [self.size] if isinstance(self.size, int) else self.size - def width(self) -> int: - return sum(self.sizes()) + def shape(self) -> Shape: + return Shape(width=sum(self.sizes()) + self.offset, signed=self.signed) def __set_name__(self, owner, name): self._name = name @@ -57,13 +77,13 @@ def __get__(self, obj, objtype=None) -> Value: if self.static_value is not None: return self.static_value - return obj.__dict__.get(self._name, C(0, Shape(self.width(), self.signed))) + return obj.__dict__.get(self._name, C(0, self.shape())) def __set__(self, obj, value) -> None: if self.static_value is not None: raise AttributeError("Can't overwrite the static value of a field.") - expected_shape = Shape(width=sum(self.sizes()) + self.offset, signed=self.signed) + expected_shape = self.shape() field_val: Value = C(0) if isinstance(value, Enum): @@ -85,6 +105,19 @@ def __set__(self, obj, value) -> None: obj.__dict__[self._name] = field_val + def get_parts(self, value: Value) -> list[Value]: + base = self.bases() + size = self.sizes() + offset = self.offset + + ret: list[Value] = [] + for i in range(len(base)): + ret.append(value[offset : offset + size[i]]) + offset += size[i] + + return ret + + def _get_fields(cls: type) -> list[Field]: fields = [cls.__dict__[member] for member in vars(cls) if isinstance(cls.__dict__[member], Field)] field_ids = set([id(field) for field in fields]) @@ -114,14 +147,7 @@ def as_value(self) -> Value: for field in _get_fields(type(self)): value = field.__get__(self, type(self)) - - base = field.bases() - size = field.sizes() - - offset = field.offset - for i in range(len(base)): - parts.append((base[i], value[offset : offset + size[i]])) - offset += size[i] + parts += zip(field.bases(), field.get_parts(value)) parts.sort() return Cat([part[1] for part in parts])